 from copy import deepcopy

 import torch
-import wandb
 from tqdm.auto import tqdm

+import wandb
 from fastvideo.distill.solver import extract_into_tensor
 from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group
 from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.forward_context import set_forward_context
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
-from fastvideo.v1.training.training_utils import (
-    clip_grad_norm_while_handling_failing_dtensor_cases,
-    save_checkpoint, normalize_dit_input)
 from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline
-from fastvideo.v1.training.distillation_pipeline import DistillationPipeline, reshard_fsdp
+from fastvideo.v1.training.distillation_pipeline import (DistillationPipeline,
+                                                         reshard_fsdp)
+from fastvideo.v1.training.training_utils import (
+    clip_grad_norm_while_handling_failing_dtensor_cases, normalize_dit_input,
+    save_checkpoint)

 logger = init_logger(__name__)

+
 def get_norm(model_pred, norms, gradient_accumulation_steps):
     """Calculate and aggregate model prediction norms."""
     fro_norm = (
@@ -44,6 +46,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps):
     norms["absolute mean"] += absolute_mean.item()
     norms["absolute max"] += absolute_max.item()

+
 class WanDistillationPipeline(DistillationPipeline):
     """
     A distillation pipeline for Wan.
@@ -124,15 +127,14 @@ def distill_one_step(
             noise = torch.randn_like(latents)

             indices = torch.randint(0,
-                                  num_euler_timesteps, (batch_size, ),
-                                  device=latents.device).long()
+                                    num_euler_timesteps, (batch_size, ),
+                                    device=latents.device).long()

             if sp_size > 1:
                 self.sp_group.broadcast(indices, src=0)

             # Add noise according to flow matching
-            sigmas = extract_into_tensor(solver.sigmas, indices,
-                                         latents.shape)
+            sigmas = extract_into_tensor(solver.sigmas, indices, latents.shape)
             sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices,
                                               latents.shape)

@@ -186,16 +188,23 @@ def distill_one_step(
                     # Get teacher model prediction on unconditional embedding
                     with torch.autocast("cuda", dtype=torch.bfloat16):
                         input_kwargs = {
-                            "hidden_states": noisy_model_input,
-                            "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(
-                                    batch_size, -1, -1),
-                            "timestep": timesteps,
-                            "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1),
-                            "return_dict": False,
+                            "hidden_states":
+                            noisy_model_input,
+                            "encoder_hidden_states":
+                            uncond_prompt_embed.unsqueeze(0).expand(
+                                batch_size, -1, -1),
+                            "timestep":
+                            timesteps,
+                            "encoder_attention_mask":
+                            uncond_prompt_mask.unsqueeze(0).expand(
+                                batch_size, -1),
+                            "return_dict":
+                            False,
                         }
                         with set_forward_context(current_timestep=timesteps,
                                                  attn_metadata=None):
-                            uncond_teacher_output = teacher_transformer(**input_kwargs)[0]
+                            uncond_teacher_output = teacher_transformer(
+                                **input_kwargs)[0]
                 teacher_output = uncond_teacher_output + distill_cfg * (
                     cond_teacher_output - uncond_teacher_output)
                 x_prev = solver.euler_step(noisy_model_input, teacher_output,
@@ -305,19 +314,24 @@ def forward(
         uncond_prompt_mask = self.uncond_prompt_mask

         # Train!
-        total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
-                            self.training_args.sp_size * self.training_args.train_sp_batch_size)
+        total_batch_size = (self.world_size *
+                            self.training_args.gradient_accumulation_steps /
+                            self.training_args.sp_size *
+                            self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
         logger.info(f"  Resume training from step {init_steps}")
         logger.info(
-            f"  Instantaneous batch size per device = {self.training_args.train_batch_size}")
+            f"  Instantaneous batch size per device = {self.training_args.train_batch_size}"
+        )
         logger.info(
             f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
         )
         logger.info(
             f"  Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
         )
-        logger.info(f"  Total optimization steps = {self.training_args.max_train_steps}")
+        logger.info(
+            f"  Total optimization steps = {self.training_args.max_train_steps}"
+        )
         logger.info(
             f"  Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9}"
         )
@@ -354,12 +368,13 @@ def get_num_phases(multi_phased_distill_schedule, step):
                     return int(phase)
             return int(phase)

-        for step in range(init_steps + 1, self.training_args.max_train_steps + 1):
+        for step in range(init_steps + 1,
+                          self.training_args.max_train_steps + 1):
             start_time = time.perf_counter()

             assert self.training_args.multi_phased_distill_schedule is not None
-            num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule,
-                                        step)
+            num_phases = get_num_phases(
+                self.training_args.multi_phased_distill_schedule, step)
             try:
                 loss, grad_norm, pred_norm = self.distill_one_step(
                     self.transformer,
@@ -407,7 +422,6 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 step -= 1
                 continue

-
             if self.rank <= 0:
                 wandb.log(
                     {
@@ -441,10 +455,10 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 else:
                     if self.training_args.use_ema:
                         save_checkpoint(self.ema_transformer, self.rank,
-                                            self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                     else:
                         save_checkpoint(self.transformer, self.rank,
-                                            self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                 self.sp_group.barrier()

             if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
@@ -454,8 +468,9 @@ def get_num_phases(multi_phased_distill_schedule, step):
         if self.training_args.use_lora:
             raise NotImplementedError("LoRA is not supported now")
         else:
-            save_checkpoint(self.transformer, self.rank, self.training_args.output_dir,
-                               self.training_args.max_train_steps)
+            save_checkpoint(self.transformer, self.rank,
+                            self.training_args.output_dir,
+                            self.training_args.max_train_steps)

         if get_sp_group():
             cleanup_dist_env_and_memory()
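
Two technical notes on the code above.

The effective batch size logged in forward follows directly from the reformatted expression: with illustrative values world_size = 8, gradient_accumulation_steps = 4, sp_size = 2 and train_sp_batch_size = 1, total_batch_size = 8 * 4 / 2 * 1 = 16. Ranks in the same sequence-parallel group process the same sample (which is also why the timestep indices are broadcast within sp_group), so only world_size / sp_size groups contribute independent samples per accumulation step.

The teacher guidance in distill_one_step blends the conditional and unconditional teacher predictions with the distill_cfg weight before solver.euler_step advances the noisy latents. Below is a minimal sketch of that blend; the tensor shapes and guidance scale are made up for illustration, and only the final statement mirrors the pipeline code:

import torch

# Illustrative stand-ins for the teacher outputs on the same noisy input;
# real shapes come from the Wan video latents.
cond_teacher_output = torch.randn(1, 16, 4, 32, 32)
uncond_teacher_output = torch.randn(1, 16, 4, 32, 32)
distill_cfg = 5.0  # assumed guidance scale, normally taken from training args

# Classifier-free guidance: start from the unconditional prediction and move
# toward the conditional one, scaled by the guidance weight.
teacher_output = uncond_teacher_output + distill_cfg * (
    cond_teacher_output - uncond_teacher_output)

With distill_cfg = 0 the blend reduces to the unconditional prediction; larger values push it further toward the text-conditioned branch.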