@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
2020 assert config .get ('pretrained_model_name_or_path' ), "pretrained_model_name_or_path required"
2121
2222 model_cmd = ["--model_name" , config .get ('model_name' ),
23- "--pretrained_model_name_or_path" , config .get ('pretrained_model_name_or_path' )]
23+ "--pretrained_model_name_or_path" , config .get ('pretrained_model_name_or_path' ),
24+ "--text_encoder_dtype" , config .get ('text_encoder_dtype' ),
25+ "--text_encoder_2_dtype" , config .get ('text_encoder_2_dtype' ),
26+ "--text_encoder_3_dtype" , config .get ('text_encoder_3_dtype' ),
27+ "--vae_dtype" , config .get ('vae_dtype' )]
28+
29+ if config .get ('layerwise_upcasting_modules' ) != 'none' :
30+ model_cmd += ["--layerwise_upcasting_modules" , config .get ('layerwise_upcasting_modules' ),
31+ "--layerwise_upcasting_storage_dtype" , config .get ('layerwise_upcasting_storage_dtype' ),
32+ "--layerwise_upcasting_skip_modules_pattern" , config .get ('layerwise_upcasting_skip_modules_pattern' )]
2433
2534 dataset_cmd = ["--data_root" , config .get ('data_root' ),
2635 "--video_column" , config .get ('video_column' ),
@@ -36,6 +45,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
3645 "--text_encoder_2_dtype" , config .get ('text_encoder_2_dtype' ),
3746 "--text_encoder_3_dtype" , config .get ('text_encoder_3_dtype' ),
3847 "--vae_dtype" , config .get ('vae_dtype' ),
48+ "--transformer_dtype" , config .get ('transformer_dtype' ),
3949 '--precompute_conditions' if config .get ('precompute_conditions' ) else '' ]
4050 if config .get ('dataset_file' ):
4151 dataset_cmd += ["--dataset_file" , config .get ('dataset_file' )]
@@ -47,7 +57,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
4757
4858 training_cmd = ["--training_type" , config .get ('training_type' ),
4959 "--seed" , config .get ('seed' ),
50- "--mixed_precision" , config .get ('mixed_precision' ),
5160 "--batch_size" , config .get ('batch_size' ),
5261 "--train_steps" , config .get ('train_steps' ),
5362 "--rank" , config .get ('rank' ),
0 commit comments