 from copy import deepcopy
 
 import torch
-import wandb
 from tqdm.auto import tqdm
 
+import wandb
 from fastvideo.distill.solver import extract_into_tensor
 from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group
 from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.forward_context import set_forward_context
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
-from fastvideo.v1.training.training_utils import (
-    clip_grad_norm_while_handling_failing_dtensor_cases,
-    save_checkpoint, normalize_dit_input)
 from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline
-from fastvideo.v1.training.distillation_pipeline import DistillationPipeline, reshard_fsdp
+from fastvideo.v1.training.distillation_pipeline import (DistillationPipeline,
+                                                         reshard_fsdp)
+from fastvideo.v1.training.training_utils import (
+    clip_grad_norm_while_handling_failing_dtensor_cases, normalize_dit_input,
+    save_checkpoint)
 
 logger = init_logger(__name__)
 
+
 def get_norm(model_pred, norms, gradient_accumulation_steps):
     """Calculate and aggregate model prediction norms."""
     fro_norm = (
@@ -44,6 +46,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps):
     norms["absolute mean"] += absolute_mean.item()
     norms["absolute max"] += absolute_max.item()
 
+
 class WanDistillationPipeline(DistillationPipeline):
     """
     A distillation pipeline for Wan.
@@ -124,15 +127,14 @@ def distill_one_step(
         noise = torch.randn_like(latents)
 
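         # Pick a random Euler timestep index for every sample in the batch.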
         indices = torch.randint(0,
-                                    num_euler_timesteps, (batch_size, ),
-                                    device=latents.device).long()
+                                num_euler_timesteps, (batch_size, ),
+                                device=latents.device).long()
 
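         # With sequence parallelism, every rank must use the same timestep indices.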
         if sp_size > 1:
             self.sp_group.broadcast(indices, src=0)
 
         # Add noise according to flow matching
-        sigmas = extract_into_tensor(solver.sigmas, indices,
-                                     latents.shape)
+        sigmas = extract_into_tensor(solver.sigmas, indices, latents.shape)
         sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices,
                                           latents.shape)
@@ -186,16 +188,23 @@ def distill_one_step(
                 # Get teacher model prediction on unconditional embedding
                 with torch.autocast("cuda", dtype=torch.bfloat16):
                     input_kwargs = {
-                        "hidden_states": noisy_model_input,
-                        "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(
-                            batch_size, -1, -1),
-                        "timestep": timesteps,
-                        "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1),
-                        "return_dict": False,
+                        "hidden_states":
+                        noisy_model_input,
+                        "encoder_hidden_states":
+                        uncond_prompt_embed.unsqueeze(0).expand(
+                            batch_size, -1, -1),
+                        "timestep":
+                        timesteps,
+                        "encoder_attention_mask":
+                        uncond_prompt_mask.unsqueeze(0).expand(
+                            batch_size, -1),
+                        "return_dict":
+                        False,
                     }
                     with set_forward_context(current_timestep=timesteps,
                                              attn_metadata=None):
-                        uncond_teacher_output = teacher_transformer(**input_kwargs)[0]
+                        uncond_teacher_output = teacher_transformer(
+                            **input_kwargs)[0]
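             # Classifier-free guidance: blend the conditional and unconditional teacher outputs.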
             teacher_output = uncond_teacher_output + distill_cfg * (
                 cond_teacher_output - uncond_teacher_output)
             x_prev = solver.euler_step(noisy_model_input, teacher_output,
@@ -305,19 +314,24 @@ def forward(
         uncond_prompt_mask = self.uncond_prompt_mask
 
         # Train!
-        total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
-                            self.training_args.sp_size * self.training_args.train_sp_batch_size)
+        total_batch_size = (self.world_size *
+                            self.training_args.gradient_accumulation_steps /
+                            self.training_args.sp_size *
+                            self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
         logger.info(f" Resume training from step {init_steps}")
         logger.info(
-            f" Instantaneous batch size per device = {self.training_args.train_batch_size}")
+            f" Instantaneous batch size per device = {self.training_args.train_batch_size}"
+        )
         logger.info(
             f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
         )
         logger.info(
             f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
         )
-        logger.info(f" Total optimization steps = {self.training_args.max_train_steps}")
+        logger.info(
+            f" Total optimization steps = {self.training_args.max_train_steps}"
+        )
         logger.info(
             f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B"
         )
@@ -354,12 +368,13 @@ def get_num_phases(multi_phased_distill_schedule, step):
                     return int(phase)
             return int(phase)
 
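         # Main distillation loop: one student update per optimization step.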
-        for step in range(init_steps + 1, self.training_args.max_train_steps + 1):
+        for step in range(init_steps + 1,
+                          self.training_args.max_train_steps + 1):
             start_time = time.perf_counter()
 
             assert self.training_args.multi_phased_distill_schedule is not None
-            num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule,
-                                        step)
+            num_phases = get_num_phases(
+                self.training_args.multi_phased_distill_schedule, step)
             try:
                 loss, grad_norm, pred_norm = self.distill_one_step(
                     self.transformer,
@@ -407,7 +422,6 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 step -= 1
                 continue
 
-
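             # Only the main rank logs metrics to wandb.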
             if self.rank <= 0:
                 wandb.log(
                     {
@@ -441,10 +455,10 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 else:
                     if self.training_args.use_ema:
                         save_checkpoint(self.ema_transformer, self.rank,
-                                            self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                     else:
                         save_checkpoint(self.transformer, self.rank,
-                                            self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                 self.sp_group.barrier()
 
             if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
@@ -454,8 +468,9 @@ def get_num_phases(multi_phased_distill_schedule, step):
         if self.training_args.use_lora:
             raise NotImplementedError("LoRA is not supported now")
         else:
-            save_checkpoint(self.transformer, self.rank, self.training_args.output_dir,
-                            self.training_args.max_train_steps)
+            save_checkpoint(self.transformer, self.rank,
+                            self.training_args.output_dir,
+                            self.training_args.max_train_steps)
 
         if get_sp_group():
             cleanup_dist_env_and_memory()