 import time
 from collections import deque
 from copy import deepcopy
+from typing import Dict

 import torch
 from tqdm.auto import tqdm
@@ -23,7 +24,8 @@
 logger = init_logger(__name__)


-def get_norm(model_pred, norms, gradient_accumulation_steps):
+def get_norm(model_pred: torch.Tensor, norms: Dict[str, float],
+             gradient_accumulation_steps: int) -> None:
     """Calculate and aggregate model prediction norms."""
     fro_norm = (
         torch.linalg.matrix_norm(model_pred, ord="fro") /  # codespell:ignore
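
As an aside on the annotated helper above, here is a minimal, self-contained sketch of the same norm-accumulation pattern. The `accumulate_fro_norm` name and the `"fro_norm"` key are illustrative, not taken from the repository; the only assumption is that `torch.linalg.matrix_norm(ord="fro")` reduces over the last two dimensions.

# Hedged sketch (illustrative names, not the repository implementation):
# accumulate a gradient-accumulation-scaled Frobenius norm of a prediction.
from typing import Dict

import torch


def accumulate_fro_norm(model_pred: torch.Tensor, norms: Dict[str, float],
                        gradient_accumulation_steps: int) -> None:
    # matrix_norm(ord="fro") reduces over the last two dims; average the rest.
    fro = (torch.linalg.matrix_norm(model_pred, ord="fro") /
           gradient_accumulation_steps)
    norms["fro_norm"] = norms.get("fro_norm", 0.0) + fro.mean().item()


norms: Dict[str, float] = {}
accumulate_fro_norm(torch.randn(2, 16, 16), norms, gradient_accumulation_steps=4)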
@@ -66,7 +68,10 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs):
         args_copy.mode = Mode.INFERENCE
         args_copy.vae_config.load_encoder = False
         validation_pipeline = WanValidationPipeline.from_pretrained(
-            fastvideo_args.model_path, args=args_copy)
+            fastvideo_args.model_path,
+            args=None,
+            mode=Mode.INFERENCE,
+            loaded_modules={"transformer": self.get_module("transformer")})

         self.validation_pipeline = validation_pipeline

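The intent of the new `loaded_modules` argument appears to be reusing the already-instantiated student transformer rather than loading a second copy from disk. Below is a minimal sketch of that reuse pattern with a hypothetical `TinyPipeline`; it illustrates the general idea only and is not the FastVideo API.

# Hedged sketch: hand an existing module to a pipeline factory so it is
# not reloaded from a checkpoint. All names here are hypothetical.
from typing import Dict, Optional

import torch.nn as nn


class TinyPipeline:
    def __init__(self, modules: Dict[str, nn.Module]):
        self.modules = modules

    @classmethod
    def from_pretrained(cls, model_path: str,
                        loaded_modules: Optional[Dict[str, nn.Module]] = None
                        ) -> "TinyPipeline":
        modules: Dict[str, nn.Module] = dict(loaded_modules or {})
        if "transformer" not in modules:
            # Placeholder for an actual checkpoint load from model_path.
            modules["transformer"] = nn.Linear(4, 4)
        return cls(modules)


student = nn.Linear(4, 4)
pipe = TinyPipeline.from_pretrained("ignored/path",
                                    loaded_modules={"transformer": student})
assert pipe.modules["transformer"] is student  # no duplicate copy was loaded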
@@ -95,11 +100,7 @@ def distill_one_step(
         pred_decay_weight,
         pred_decay_type,
         hunyuan_teacher_disable_cfg,
-        weighting_scheme,
-        logit_mean,
-        logit_std,
-        mode_scale,
-    ):
+    ) -> tuple[float, float, Dict[str, float]]:
         """Perform one step of distillation training."""
         total_loss = 0.0
         optimizer.zero_grad()
@@ -170,17 +171,16 @@ def distill_one_step(
                 noisy_model_input, model_pred, indices, multiphase)

             # Get teacher model prediction
-            with torch.no_grad():
-                with torch.autocast("cuda", dtype=torch.bfloat16):
-                    with set_forward_context(current_timestep=timesteps,
-                                             attn_metadata=None):
-                        cond_teacher_output = teacher_transformer(
-                            noisy_model_input,
-                            encoder_hidden_states,
-                            timesteps,
-                            encoder_attention_mask,
-                            return_dict=False,
-                        )[0].float()
+            with torch.no_grad(), torch.autocast(
+                    "cuda", dtype=torch.bfloat16), set_forward_context(
+                        current_timestep=timesteps, attn_metadata=None):
+                cond_teacher_output = teacher_transformer(
+                    noisy_model_input,
+                    encoder_hidden_states,
+                    timesteps,
+                    encoder_attention_mask,
+                    return_dict=False,
+                )[0].float()

                 if not_apply_cfg_solver:
                     uncond_teacher_output = cond_teacher_output
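For reference, the flattened `with` above is behaviourally identical to the nested form it replaces: Python enters the managers left to right and exits them in reverse order. A small stdlib-only sketch (the `tag` helper is illustrative, not part of the codebase):

# Hedged sketch: one `with` statement with several context managers is
# equivalent to nesting them.
from contextlib import contextmanager


@contextmanager
def tag(name):
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)


# Nested form.
with tag("no_grad"):
    with tag("autocast"):
        with tag("forward_context"):
            pass

# Flattened form used by the diff; prints the same enter/exit order.
with tag("no_grad"), tag("autocast"), tag("forward_context"):
    pass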
@@ -313,31 +313,30 @@ def forward(
         uncond_prompt_embed = self.uncond_prompt_embed
         uncond_prompt_mask = self.uncond_prompt_mask

-        # Train!
+        assert self.training_args.sp_size is not None
+        assert self.training_args.gradient_accumulation_steps is not None
         total_batch_size = (self.world_size *
                             self.training_args.gradient_accumulation_steps /
                             self.training_args.sp_size *
                             self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
-        logger.info(f"  Resume training from step {init_steps}")
-        logger.info(
-            f"  Instantaneous batch size per device = {self.training_args.train_batch_size}"
-        )
+        logger.info("  Resume training from step %s", init_steps)
+        logger.info("  Instantaneous batch size per device = %s",
+                    self.training_args.train_batch_size)
         logger.info(
-            f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
-        )
+            "  Total train batch size (w. data & sequence parallel, accumulation) = %s",
+            total_batch_size)
+        logger.info("  Gradient Accumulation steps = %s",
+                    self.training_args.gradient_accumulation_steps)
+        logger.info("  Total optimization steps = %s",
+                    self.training_args.max_train_steps)
         logger.info(
-            f"  Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
-        )
-        logger.info(
-            f"  Total optimization steps = {self.training_args.max_train_steps}"
-        )
-        logger.info(
-            f"  Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9}"
-        )
-        logger.info(
-            f"  Master weight dtype: {self.transformer.parameters().__next__().dtype}"
-        )
+            "  Total training parameters per FSDP shard = %s B",
+            sum(p.numel()
+                for p in self.transformer.parameters() if p.requires_grad) /
+            1e9)
+        logger.info("  Master weight dtype: %s",
+                    self.transformer.parameters().__next__().dtype)

         # Potentially load in the weights and states from a previous save
         if self.training_args.resume_from_checkpoint:
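The switch from f-strings to %-style arguments follows the standard `logging` idiom: the format string and its arguments are stored on the record and only rendered if a handler actually emits it. A small sketch with the stdlib logger (FastVideo's `init_logger` is assumed to wrap it):

# Hedged sketch: lazy %-style formatting vs. eager f-string formatting.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

expensive = list(range(10))

# Eager: the f-string is built even though INFO records are filtered out.
logger.info(f"  Total optimization steps = {len(expensive)}")

# Lazy: the format string and args are stored; rendering happens only if a
# handler actually emits the record.
logger.info("  Total optimization steps = %s", len(expensive))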
@@ -352,13 +351,14 @@ def forward(
         )

         loader_iter = iter(train_dataloader)
-        step_times = deque(maxlen=100)
+        step_times: deque[float] = deque(maxlen=100)

         # Skip steps if resuming
         for i in range(init_steps):
             next(loader_iter)

-        def get_num_phases(multi_phased_distill_schedule, step):
+        def get_num_phases(multi_phased_distill_schedule: str,
+                           step: int) -> int:
             # step-phase,step-phase
             multi_phases = multi_phased_distill_schedule.split(",")
             phase = multi_phases[-1].split("-")[-1]
@@ -400,10 +400,6 @@ def get_num_phases(multi_phased_distill_schedule, step):
                     self.training_args.pred_decay_weight,
                     self.training_args.pred_decay_type,
                     self.training_args.hunyuan_teacher_disable_cfg,
-                    self.training_args.weighting_scheme,
-                    self.training_args.logit_mean,
-                    self.training_args.logit_std,
-                    self.training_args.mode_scale,
                 )

                 step_time = time.perf_counter() - start_time
@@ -462,7 +458,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 self.sp_group.barrier()

             if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
-                self.log_validation(self.transformer, self.training_args, step)
+                self._log_validation(self.transformer, self.training_args, step)

         # Final checkpoint
         if self.training_args.use_lora:
@@ -476,7 +472,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
             cleanup_dist_env_and_memory()


-def main(args):
+def main(args) -> None:
     logger.info("Starting distillation pipeline...")

     pipeline = WanDistillationPipeline.from_pretrained(
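
Finally, on the `step_times: deque[float]` annotation earlier in the diff: a `maxlen`-bounded deque is a simple rolling window for per-step timings. A quick sketch of how such a window yields a moving average (the sleep is a stand-in for a training step, not the pipeline's code):

# Hedged sketch: rolling average of recent step durations.
import time
from collections import deque

step_times: deque[float] = deque(maxlen=100)

for _ in range(5):
    start_time = time.perf_counter()
    time.sleep(0.01)  # stand-in for one training step
    step_times.append(time.perf_counter() - start_time)

avg_step_time = sum(step_times) / len(step_times)
print(f"avg step time over last {len(step_times)} steps: {avg_step_time:.4f}s")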