|
| 1 | +from fastvideo.fastvideo_args import FastVideoArgs |
| 2 | +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase |
| 3 | +from fastvideo.pipelines.preprocess.preprocess_stages import ( |
| 4 | + TextTransformStage, VideoTransformStage) |
| 5 | +from fastvideo.pipelines.stages import (EncodingStage, ImageEncodingStage, |
| 6 | + TextEncodingStage) |
| 7 | +from fastvideo.pipelines.stages.image_encoding import ImageVAEEncodingStage |
| 8 | + |
| 9 | + |
class PreprocessPipelineI2V(ComposedPipelineBase):
    """Preprocessing pipeline combining text, video and image stages (I2V).

    Stages are registered in this order: text transform, prompt encoding,
    video transform, image encoding (only when both the image encoder and
    the image processor modules are not None), image VAE encoding, and
    video encoding.
    """

    _required_config_modules = [
        "image_encoder",
        "image_processor",
        "text_encoder",
        "tokenizer",
        "vae",
    ]

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """Register the I2V preprocessing stages on this pipeline.

        Args:
            fastvideo_args: Arguments whose ``preprocess_config`` supplies
                the text and video transform settings. The config must not
                be None (checked with an assertion).
        """
        assert fastvideo_args.preprocess_config is not None
        preprocess_cfg = fastvideo_args.preprocess_config

        # Text side: transform the text first, then encode the prompt.
        text_transform = TextTransformStage(
            cfg_uncondition_drop_rate=preprocess_cfg.training_cfg_rate,
            seed=preprocess_cfg.seed,
        )
        self.add_stage(stage_name="text_transform_stage",
                       stage=text_transform)

        prompt_encoding = TextEncodingStage(
            text_encoders=[self.get_module("text_encoder")],
            tokenizers=[self.get_module("tokenizer")],
        )
        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=prompt_encoding)

        # Video side: transform settings all come from the preprocess config.
        video_transform = VideoTransformStage(
            train_fps=preprocess_cfg.train_fps,
            num_frames=preprocess_cfg.num_frames,
            max_height=preprocess_cfg.max_height,
            max_width=preprocess_cfg.max_width,
            do_temporal_sample=preprocess_cfg.do_temporal_sample,
        )
        self.add_stage(stage_name="video_transform_stage",
                       stage=video_transform)

        # The image-encoding stage is skipped unless both modules are set.
        # The processor is only looked up when the encoder is present.
        has_image_encoder = self.get_module("image_encoder") is not None
        if (has_image_encoder
                and self.get_module("image_processor") is not None):
            image_encoding = ImageEncodingStage(
                image_encoder=self.get_module("image_encoder"),
                image_processor=self.get_module("image_processor"),
            )
            self.add_stage(stage_name="image_encoding_stage",
                           stage=image_encoding)

        # Both remaining stages are built from the same "vae" module.
        image_vae_encoding = ImageVAEEncodingStage(
            vae=self.get_module("vae"))
        self.add_stage(stage_name="image_vae_encoding_stage",
                       stage=image_vae_encoding)

        video_encoding = EncodingStage(vae=self.get_module("vae"))
        self.add_stage(stage_name="video_encoding_stage",
                       stage=video_encoding)
| 51 | + |
| 52 | + |
class PreprocessPipelineT2V(ComposedPipelineBase):
    """Preprocessing pipeline combining text and video stages (T2V).

    Stages are registered in this order: text transform, prompt encoding,
    video transform, and video encoding.
    """

    _required_config_modules = ["text_encoder", "tokenizer", "vae"]

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """Register the T2V preprocessing stages on this pipeline.

        Args:
            fastvideo_args: Arguments whose ``preprocess_config`` supplies
                the text and video transform settings. The config must not
                be None (checked with an assertion).
        """
        assert fastvideo_args.preprocess_config is not None
        preprocess_cfg = fastvideo_args.preprocess_config

        # Text side: transform the text first, then encode the prompt.
        text_transform = TextTransformStage(
            cfg_uncondition_drop_rate=preprocess_cfg.training_cfg_rate,
            seed=preprocess_cfg.seed,
        )
        self.add_stage(stage_name="text_transform_stage",
                       stage=text_transform)

        prompt_encoding = TextEncodingStage(
            text_encoders=[self.get_module("text_encoder")],
            tokenizers=[self.get_module("tokenizer")],
        )
        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=prompt_encoding)

        # Video side: transform settings all come from the preprocess config.
        video_transform = VideoTransformStage(
            train_fps=preprocess_cfg.train_fps,
            num_frames=preprocess_cfg.num_frames,
            max_height=preprocess_cfg.max_height,
            max_width=preprocess_cfg.max_width,
            do_temporal_sample=preprocess_cfg.do_temporal_sample,
        )
        self.add_stage(stage_name="video_transform_stage",
                       stage=video_transform)

        video_encoding = EncodingStage(vae=self.get_module("vae"))
        self.add_stage(stage_name="video_encoding_stage",
                       stage=video_encoding)
| 81 | + |
| 82 | + |
# Both pipeline classes defined in this module, collected under one name.
# NOTE(review): presumably read by the pipeline loader/registry - confirm.
EntryClass = [
    PreprocessPipelineI2V,
    PreprocessPipelineT2V,
]
0 commit comments