 from diffusers.training_utils import cast_training_params
 from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.torch_utils import is_compiled_module
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader
@@ -57,7 +56,14 @@
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
 from text_encoder import compute_prompt_embeddings  # isort:skip
-from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory  # isort:skip
+from utils import (
+    get_gradient_norm,
+    get_optimizer,
+    prepare_rotary_positional_embeddings,
+    print_memory,
+    reset_memory,
+    unwrap_model,
+)


 logger = get_logger(__name__)
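The inline `unwrap_model` helper removed further down (see the hunk around `save_model_hook`) now lives in `utils` and takes the `Accelerator` explicitly. A minimal sketch of the shared helper, assuming it was moved verbatim; the `utils.py` placement and exact signature are inferred from this diff, not shown in it:

```python
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(accelerator: Accelerator, model):
    # Strip the accelerate wrapper (DDP/DeepSpeed engine), then the
    # torch.compile wrapper if the module was compiled.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```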
@@ -155,7 +161,6 @@ def log_validation(
     pipe: CogVideoXImageToVideoPipeline,
     args: Dict[str, Any],
     pipeline_args: Dict[str, Any],
-    epoch,
     is_final_validation: bool = False,
 ):
     logger.info(
@@ -201,6 +206,64 @@ def log_validation(
     return videos


+def run_validation(
+    args: Dict[str, Any],
+    accelerator: Accelerator,
+    transformer,
+    scheduler,
+    model_config: Dict[str, Any],
+    weight_dtype: torch.dtype,
+) -> None:
+    accelerator.print("===== Memory before validation =====")
+    print_memory(accelerator.device)
+    torch.cuda.synchronize(accelerator.device)
+
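+    # Assemble the full image-to-video pipeline around the live training
+    # transformer; all other components are loaded from the base checkpoint.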
+    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        transformer=unwrap_model(accelerator, transformer),
+        scheduler=scheduler,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+
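+    # Optional memory savers: VAE slicing/tiling and model CPU offload trade
+    # inference speed for lower peak VRAM.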
+    if args.enable_slicing:
+        pipe.vae.enable_slicing()
+    if args.enable_tiling:
+        pipe.vae.enable_tiling()
+    if args.enable_model_cpu_offload:
+        pipe.enable_model_cpu_offload()
+
+    validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+    validation_images = args.validation_images.split(args.validation_prompt_separator)
+    for validation_image, validation_prompt in zip(validation_images, validation_prompts):
+        pipeline_args = {
+            "image": load_image(validation_image),
+            "prompt": validation_prompt,
+            "guidance_scale": args.guidance_scale,
+            "use_dynamic_cfg": args.use_dynamic_cfg,
+            "height": args.height,
+            "width": args.width,
+            "max_sequence_length": model_config.max_text_seq_length,
+        }
+
+        log_validation(
+            pipe=pipe,
+            args=args,
+            accelerator=accelerator,
+            pipeline_args=pipeline_args,
+        )
+
+    accelerator.print("===== Memory after validation =====")
+    print_memory(accelerator.device)
+    reset_memory(accelerator.device)
+
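+    # Tear the pipeline down and release cached GPU memory before training resumes.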
+    del pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize(accelerator.device)
+
+
 class CollateFunction:
     def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
         self.weight_dtype = weight_dtype
@@ -308,6 +371,12 @@ def main(args):
         variant=args.variant,
     )

+    # These changes must also be mirrored when running inference with the trained LoRA.
+    if args.ignore_learned_positional_embeddings:
+        del transformer.patch_embed.pos_embedding
+        transformer.patch_embed.use_learned_positional_embeddings = False
+        transformer.config.use_learned_positional_embeddings = False
+
     vae = AutoencoderKLCogVideoX.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
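The comment above notes that the positional-embedding changes must be mirrored at inference time. A hedged sketch of that inference-side setup (the model id, checkpoint path, and adapter name are illustrative assumptions, not part of this diff):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)

# Mirror the training-time change before loading the LoRA weights.
del pipe.transformer.patch_embed.pos_embedding
pipe.transformer.patch_embed.use_learned_positional_embeddings = False
pipe.transformer.config.use_learned_positional_embeddings = False

pipe.load_lora_weights("/path/to/lora/checkpoint", adapter_name="cogvideox-lora")
```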
@@ -373,19 +442,14 @@ def main(args):
     )
     transformer.add_adapter(transformer_lora_config)

-    def unwrap_model(model):
-        model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model
-        return model
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None

             for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    model = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    model = unwrap_model(accelerator, model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
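The hooks only take effect once registered with the `Accelerator`. The registration is outside the hunks shown here, but in diffusers-style training scripts it presumably follows the definitions:

```python
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```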
@@ -407,10 +471,10 @@ def load_model_hook(models, input_dir):
             while len(models) > 0:
                 model = models.pop()

-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    transformer_ = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    transformer_ = unwrap_model(accelerator, model)
                 else:
-                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+                    raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
         else:
             transformer_ = CogVideoXTransformer3DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"
@@ -776,6 +840,7 @@ def load_model_hook(models, input_dir):
                 progress_bar.update(1)
                 global_step += 1

+                # Checkpointing
                 if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
@@ -802,6 +867,13 @@ def load_model_hook(models, input_dir):
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

+                # Validation
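+                # Step-based trigger: runs every `validation_steps` optimizer steps
+                # (skipped entirely when `validation_steps` is None).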
+                should_run_validation = args.validation_prompt is not None and (
+                    args.validation_steps is not None and global_step % args.validation_steps == 0
+                )
+                if should_run_validation:
+                    run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
+
             last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
             logs = {"loss": loss.detach().item(), "lr": last_lr}
             # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555
@@ -819,61 +891,16 @@ def load_model_hook(models, input_dir):
                 break

         if accelerator.is_main_process:
-            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
-                accelerator.print("===== Memory before validation =====")
-                print_memory(accelerator.device)
-                torch.cuda.synchronize(accelerator.device)
-
-                pipe = CogVideoXImageToVideoPipeline.from_pretrained(
-                    args.pretrained_model_name_or_path,
-                    transformer=unwrap_model(transformer),
-                    scheduler=scheduler,
-                    revision=args.revision,
-                    variant=args.variant,
-                    torch_dtype=weight_dtype,
-                )
-
-                if args.enable_slicing:
-                    pipe.vae.enable_slicing()
-                if args.enable_tiling:
-                    pipe.vae.enable_tiling()
-                if args.enable_model_cpu_offload:
-                    pipe.enable_model_cpu_offload()
-
-                validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
-                validation_images = args.validation_images.split(args.validation_prompt_separator)
-                for validation_image, validation_prompt in zip(validation_images, validation_prompts):
-                    pipeline_args = {
-                        "image": load_image(validation_image),
-                        "prompt": validation_prompt,
-                        "guidance_scale": args.guidance_scale,
-                        "use_dynamic_cfg": args.use_dynamic_cfg,
-                        "height": args.height,
-                        "width": args.width,
-                        "max_sequence_length": model_config.max_text_seq_length,
-                    }
-
-                    log_validation(
-                        pipe=pipe,
-                        args=args,
-                        accelerator=accelerator,
-                        pipeline_args=pipeline_args,
-                        epoch=epoch,
-                    )
-
-                accelerator.print("===== Memory after validation =====")
-                print_memory(accelerator.device)
-                reset_memory(accelerator.device)
-
-                del pipe
-                gc.collect()
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize(accelerator.device)
+            should_run_validation = args.validation_prompt is not None and (
+                args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
+            )
+            if should_run_validation:
+                run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)

     accelerator.wait_for_everyone()

     if accelerator.is_main_process:
-        transformer = unwrap_model(transformer)
+        transformer = unwrap_model(accelerator, transformer)
         dtype = (
             torch.float16
             if args.mixed_precision == "fp16"
@@ -944,7 +971,6 @@ def load_model_hook(models, input_dir):
                 pipe=pipe,
                 args=args,
                 pipeline_args=pipeline_args,
-                epoch=epoch,
                 is_final_validation=True,
             )
             validation_outputs.extend(video)