 import wandb
 from accelerate.utils import set_seed
 from diffusers import FlowMatchEulerDiscreteScheduler
+from fastvideo.distill.solver import PCMFMScheduler
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from peft import LoraConfig
 
 from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
 from fastvideo.distill.solver import EulerSolver, extract_into_tensor
-from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
+from fastvideo.utils.latents_utils import normalize_dit_input
 from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
 from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
 from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
@@ -123,13 +124,21 @@ def distill_one_step(
         noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
         # Predict the noise residual
         with torch.autocast("cuda", dtype=torch.bfloat16):
-            teacher_kwargs = {
-                "hidden_states": noisy_model_input,
-                "encoder_hidden_states": encoder_hidden_states,
-                "timestep": timesteps,
-                "encoder_attention_mask": encoder_attention_mask,  # B, L
-                "return_dict": False,
-            }
+            if args.model_type == "wan":
+                teacher_kwargs = {
+                    "hidden_states": noisy_model_input,
+                    "encoder_hidden_states": encoder_hidden_states,
+                    "timestep": timesteps,
+                    "return_dict": True,
+                }
+            else:
+                teacher_kwargs = {
+                    "hidden_states": noisy_model_input,
+                    "encoder_hidden_states": encoder_hidden_states,
+                    "timestep": timesteps,
+                    "encoder_attention_mask": encoder_attention_mask,  # B, L
+                    "return_dict": False,
+                }
             if hunyuan_teacher_disable_cfg:
                 teacher_kwargs["guidance"] = torch.tensor([1000.0],
                                                           device=noisy_model_input.device,
@@ -141,47 +150,70 @@ def distill_one_step(
         with torch.no_grad():
             w = distill_cfg
             with torch.autocast("cuda", dtype=torch.bfloat16):
-                cond_teacher_output = teacher_transformer(
-                    noisy_model_input,
-                    encoder_hidden_states,
-                    timesteps,
-                    encoder_attention_mask,  # B, L
-                    return_dict=False,
-                )[0].float()
+                if args.model_type == "wan":
+                    cond_teacher_kwargs = {
+                        "hidden_states": noisy_model_input,
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps,
+                        "return_dict": True,
+                    }
+                else:
+                    cond_teacher_kwargs = {
+                        "hidden_states": noisy_model_input,
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps,
+                        "encoder_attention_mask": encoder_attention_mask,  # B, L
+                        "return_dict": False,
+                    }
+                cond_teacher_output = teacher_transformer(**cond_teacher_kwargs)[0].float()
             if not_apply_cfg_solver:
                 uncond_teacher_output = cond_teacher_output
             else:
                 # Get teacher model prediction on noisy_latents and unconditional embedding
                 with torch.autocast("cuda", dtype=torch.bfloat16):
-                    uncond_teacher_output = teacher_transformer(
-                        noisy_model_input,
-                        uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
-                        timesteps,
-                        uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
-                        return_dict=False,
-                    )[0].float()
+                    if args.model_type == "wan":
+                        uncond_teacher_kwargs = {
+                            "hidden_states": noisy_model_input,
+                            "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
+                            "timestep": timesteps,
+                            "return_dict": True,
+                        }
+                    else:
+                        uncond_teacher_kwargs = {
+                            "hidden_states": noisy_model_input,
+                            "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
+                            "timestep": timesteps,
+                            "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
+                            "return_dict": False,
+                        }
+
+                    uncond_teacher_output = teacher_transformer(**uncond_teacher_kwargs)[0].float()
+
             teacher_output = uncond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
             x_prev = solver.euler_step(noisy_model_input, teacher_output, index)
 
         # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
         with torch.no_grad():
             with torch.autocast("cuda", dtype=torch.bfloat16):
+                if args.model_type == "wan":
+                    target_pred_kwargs = {
+                        "hidden_states": x_prev.float(),
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps_prev,
+                        "return_dict": True,
+                    }
+                else:
+                    target_pred_kwargs = {
+                        "hidden_states": x_prev.float(),
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps_prev,
+                        "encoder_attention_mask": encoder_attention_mask,
+                        "return_dict": False,
+                    }
                 if ema_transformer is not None:
-                    target_pred = ema_transformer(
-                        x_prev.float(),
-                        encoder_hidden_states,
-                        timesteps_prev,
-                        encoder_attention_mask,  # B, L
-                        return_dict=False,
-                    )[0]
+                    target_pred = ema_transformer(**target_pred_kwargs)[0]
                 else:
-                    target_pred = transformer(
-                        x_prev.float(),
-                        encoder_hidden_states,
-                        timesteps_prev,
-                        encoder_attention_mask,  # B, L
-                        return_dict=False,
-                    )[0]
+                    target_pred = transformer(**target_pred_kwargs)[0]
 
         target, end_index = solver.euler_style_multiphase_pred(x_prev, target_pred, index, multiphase, True)
 
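Note: the same `args.model_type` branch now appears four times in this function (teacher call, conditional teacher, unconditional teacher, and target prediction). A minimal sketch of how that dispatch could be collapsed into one helper follows; `build_transformer_kwargs` and its argument names are hypothetical and not part of this diff, only the wan-vs-default behavior is taken from the hunks above.

def build_transformer_kwargs(model_type, hidden_states, encoder_hidden_states,
                             timestep, encoder_attention_mask=None):
    # Sketch only: mirrors the repeated model_type dispatch in distill_one_step.
    kwargs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timestep,
    }
    if model_type == "wan":
        # Wan path: no attention mask, dict-style return (as in the hunks above).
        kwargs["return_dict"] = True
    else:
        # Mochi/Hunyuan path: keep the B x L attention mask and tuple return.
        kwargs["encoder_attention_mask"] = encoder_attention_mask  # B, L
        kwargs["return_dict"] = False
    return kwargs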
@@ -319,7 +351,9 @@ def main(args):
     teacher_transformer.requires_grad_(False)
     if args.use_ema:
         ema_transformer.requires_grad_(False)
-    noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift)
+
+    noise_scheduler = FlowMatchEulerDiscreteScheduler()
+
     if args.scheduler_type == "pcm_linear_quadratic":
         linear_steps = int(noise_scheduler.config.num_train_timesteps * args.linear_range)
         sigmas = linear_quadratic_schedule(