99
1010class PreprocessPipelineI2V (ComposedPipelineBase ):
1111 _required_config_modules = [
12- "image_encoder" , "image_processor" , "text_encoder" , "tokenizer" , "vae"
12+ "image_encoder" ,
13+ "image_processor" ,
14+ "text_encoder" ,
15+ "tokenizer" ,
16+ "text_encoder_2" ,
17+ "tokenizer_2" ,
18+ "vae"
1319 ]
1420
1521 def create_pipeline_stages (self , fastvideo_args : FastVideoArgs ):
@@ -51,8 +57,13 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
5157
5258
5359class PreprocessPipelineT2V (ComposedPipelineBase ):
54- _required_config_modules = ["text_encoder" , "tokenizer" , "vae" ]
55-
60+ _required_config_modules = [
61+ "text_encoder" ,
62+ "tokenizer" ,
63+ "text_encoder_2" ,
64+ "tokenizer_2" ,
65+ "vae"
66+ ]
5667 def create_pipeline_stages (self , fastvideo_args : FastVideoArgs ):
5768 assert fastvideo_args .preprocess_config is not None
5869 self .add_stage (stage_name = "text_transform_stage" ,
@@ -61,10 +72,34 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
6172 preprocess_config .training_cfg_rate ,
6273 seed = fastvideo_args .preprocess_config .seed ,
6374 ))
75+ # llama_tokenizer_kwargs = {
76+ # "padding": "max_length",
77+ # "truncation": True,
78+ # "max_length": 256,
79+ # "return_tensors": "pt"
80+ # }
81+ # clip_tokenizer_kwargs = {
82+ # "padding": "max_length",
83+ # "truncation": True,
84+ # "max_length": 77,
85+ # "return_tensors": "pt"
86+ # }
87+ # if len(fastvideo_args.pipeline_config.text_encoder_configs) >= 2:
88+ # fastvideo_args.pipeline_config.text_encoder_configs[0].tokenizer_kwargs = llama_tokenizer_kwargs
89+ # fastvideo_args.pipeline_config.text_encoder_configs[1].tokenizer_kwargs = clip_tokenizer_kwargs
90+ text_encoders = [
91+ self .get_module ("text_encoder" ),
92+ self .get_module ("text_encoder_2" )
93+ ]
94+ tokenizers = [
95+ self .get_module ("tokenizer" ),
96+ self .get_module ("tokenizer_2" )
97+ ]
98+
6499 self .add_stage (stage_name = "prompt_encoding_stage" ,
65100 stage = TextEncodingStage (
66- text_encoders = [ self . get_module ( "text_encoder" )] ,
67- tokenizers = [ self . get_module ( "tokenizer" )] ,
101+ text_encoders = text_encoders ,
102+ tokenizers = tokenizers ,
68103 ))
69104 self .add_stage (
70105 stage_name = "video_transform_stage" ,
0 commit comments