@@ -28,13 +28,17 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
2828 "--id_token" , config .get ('id_token' ),
2929 "--video_resolution_buckets" ]
3030 dataset_cmd += config .get ('video_resolution_buckets' ).split (' ' )
31+ dataset_cmd += ["--image_resolution_buckets" ]
32+ dataset_cmd += config .get ('image_resolution_buckets' ).split (' ' )
3133 dataset_cmd += ["--caption_dropout_p" , config .get ('caption_dropout_p' ),
3234 "--caption_dropout_technique" , config .get ('caption_dropout_technique' ),
3335 "--text_encoder_dtype" , config .get ('text_encoder_dtype' ),
3436 "--text_encoder_2_dtype" , config .get ('text_encoder_2_dtype' ),
3537 "--text_encoder_3_dtype" , config .get ('text_encoder_3_dtype' ),
3638 "--vae_dtype" , config .get ('vae_dtype' ),
3739 '--precompute_conditions' if config .get ('precompute_conditions' ) else '' ]
40+ if config .get ('dataset_file' ):
41+ dataset_cmd += ["--dataset_file" , config .get ('dataset_file' )]
3842
3943 dataloader_cmd = ["--dataloader_num_workers" , config .get ('dataloader_num_workers' )]
4044
@@ -56,6 +60,8 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
5660 "--checkpointing_limit" , config .get ('checkpointing_limit' ),
5761 '--enable_slicing' if config .get ('enable_slicing' ) else '' ,
5862 '--enable_tiling' if config .get ('enable_tiling' ) else '' ]
63+ if config .get ('enable_model_cpu_offload' ):
64+ training_cmd += ["--enable_model_cpu_offload" ]
5965
6066 if config .get ('resume_from_checkpoint' ):
6167 training_cmd += ["--resume_from_checkpoint" , config .get ('resume_from_checkpoint' )]
0 commit comments