Commit d5ec468

[V0] [Distill] support distill in V0 for wan (#444)
1 parent e55fa6e commit d5ec468

File tree: 12 files changed, +1424 −54 lines

fastvideo/data_preprocess/preprocess_text_embeddings.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -68,7 +68,8 @@ def main(args):
     train_dataset = T5dataset(latents_json_path, args.vae_debug)
     text_encoder = load_text_encoder(args.model_type, args.model_path, device=device)
     vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
-    vae.enable_tiling()
+    if args.model_type != "wan":
+        vae.enable_tiling()
     sampler = DistributedSampler(train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True)
     train_dataloader = DataLoader(
         train_dataset,
```

fastvideo/data_preprocess/preprocess_vae_latents.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -33,7 +33,8 @@ def main(args):
     if not dist.is_initialized():
         dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
     vae, autocast_type, fps = load_vae(args.model_type, args.model_path)
-    vae.enable_tiling()
+    if args.model_type != "wan":
+        vae.enable_tiling()
     os.makedirs(args.output_dir, exist_ok=True)
     os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True)
```
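Both preprocessing scripts gain the same guard: tiled VAE decoding stays on for every model type except wan. As a minimal sketch, the repeated check could be factored into a shared helper; `maybe_enable_tiling` is a hypothetical name, not part of this commit:

```python
def maybe_enable_tiling(vae, model_type: str) -> None:
    # Tiling bounds peak decoder memory on long/high-res videos; the wan
    # path skips it here, mirroring the guard added in both scripts above.
    if model_type != "wan" and hasattr(vae, "enable_tiling"):
        vae.enable_tiling()
```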

fastvideo/distill.py

Lines changed: 71 additions & 37 deletions
```diff
@@ -12,6 +12,7 @@
 import wandb
 from accelerate.utils import set_seed
 from diffusers import FlowMatchEulerDiscreteScheduler
+from fastvideo.distill.solver import PCMFMScheduler
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from peft import LoraConfig
```
```diff
@@ -23,7 +24,7 @@
 
 from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
 from fastvideo.distill.solver import EulerSolver, extract_into_tensor
-from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
+from fastvideo.utils.latents_utils import normalize_dit_input
 from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
 from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
 from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
```
```diff
@@ -123,13 +124,21 @@ def distill_one_step(
     noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
     # Predict the noise residual
     with torch.autocast("cuda", dtype=torch.bfloat16):
-        teacher_kwargs = {
-            "hidden_states": noisy_model_input,
-            "encoder_hidden_states": encoder_hidden_states,
-            "timestep": timesteps,
-            "encoder_attention_mask": encoder_attention_mask,  # B, L
-            "return_dict": False,
-        }
+        if args.model_type == "wan":
+            teacher_kwargs = {
+                "hidden_states": noisy_model_input,
+                "encoder_hidden_states": encoder_hidden_states,
+                "timestep": timesteps,
+                "return_dict": True,
+            }
+        else:
+            teacher_kwargs = {
+                "hidden_states": noisy_model_input,
+                "encoder_hidden_states": encoder_hidden_states,
+                "timestep": timesteps,
+                "encoder_attention_mask": encoder_attention_mask,  # B, L
+                "return_dict": False,
+            }
         if hunyuan_teacher_disable_cfg:
             teacher_kwargs["guidance"] = torch.tensor([1000.0],
                                                       device=noisy_model_input.device,
```
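The wan branch drops `encoder_attention_mask` and requests a dict-style return; the identical fork appears three more times in the next hunk (conditional teacher, unconditional teacher, and target prediction). A sketch of how the branching could be factored into one helper — `build_transformer_kwargs` is hypothetical, not part of this commit:

```python
def build_transformer_kwargs(model_type, hidden_states, encoder_hidden_states,
                             timestep, encoder_attention_mask=None):
    # wan's transformer takes no attention mask and returns a ModelOutput;
    # the other models take the mask (B, L) and return a plain tuple.
    kwargs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timestep,
        "return_dict": model_type == "wan",
    }
    if model_type != "wan":
        kwargs["encoder_attention_mask"] = encoder_attention_mask
    return kwargs

# Usage would mirror the diff: prediction = transformer(**kwargs)[0]
```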
```diff
@@ -141,47 +150,70 @@ def distill_one_step(
         with torch.no_grad():
             w = distill_cfg
             with torch.autocast("cuda", dtype=torch.bfloat16):
-                cond_teacher_output = teacher_transformer(
-                    noisy_model_input,
-                    encoder_hidden_states,
-                    timesteps,
-                    encoder_attention_mask,  # B, L
-                    return_dict=False,
-                )[0].float()
+                if args.model_type == "wan":
+                    cond_teacher_kwargs = {
+                        "hidden_states": noisy_model_input,
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps,
+                        "return_dict": True,
+                    }
+                else:
+                    cond_teacher_kwargs = {
+                        "hidden_states": noisy_model_input,
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps,
+                        "encoder_attention_mask": encoder_attention_mask,  # B, L
+                        "return_dict": False,
+                    }
+                cond_teacher_output = teacher_transformer(**cond_teacher_kwargs)[0].float()
             if not_apply_cfg_solver:
                 uncond_teacher_output = cond_teacher_output
             else:
                 # Get teacher model prediction on noisy_latents and unconditional embedding
                 with torch.autocast("cuda", dtype=torch.bfloat16):
-                    uncond_teacher_output = teacher_transformer(
-                        noisy_model_input,
-                        uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
-                        timesteps,
-                        uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
-                        return_dict=False,
-                    )[0].float()
+                    if args.model_type == "wan":
+                        uncond_teacher_kwargs = {
+                            "hidden_states": noisy_model_input,
+                            "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
+                            "timestep": timesteps,
+                            "return_dict": True,
+                        }
+                    else:
+                        uncond_teacher_kwargs = {
+                            "hidden_states": noisy_model_input,
+                            "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
+                            "timestep": timesteps,
+                            "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
+                            "return_dict": False,
+                        }
+
+                    uncond_teacher_output = teacher_transformer(**uncond_teacher_kwargs)[0].float()
+
             teacher_output = uncond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
             x_prev = solver.euler_step(noisy_model_input, teacher_output, index)
 
         # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
         with torch.no_grad():
             with torch.autocast("cuda", dtype=torch.bfloat16):
+                if args.model_type == "wan":
+                    target_pred_kwargs = {
+                        "hidden_states": x_prev.float(),
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps_prev,
+                        "return_dict": True,
+                    }
+                else:
+                    target_pred_kwargs = {
+                        "hidden_states": x_prev.float(),
+                        "encoder_hidden_states": encoder_hidden_states,
+                        "timestep": timesteps_prev,
+                        "encoder_attention_mask": encoder_attention_mask,
+                        "return_dict": False,
+                    }
                 if ema_transformer is not None:
-                    target_pred = ema_transformer(
-                        x_prev.float(),
-                        encoder_hidden_states,
-                        timesteps_prev,
-                        encoder_attention_mask,  # B, L
-                        return_dict=False,
-                    )[0]
+                    target_pred = ema_transformer(**target_pred_kwargs)[0]
                 else:
-                    target_pred = transformer(
-                        x_prev.float(),
-                        encoder_hidden_states,
-                        timesteps_prev,
-                        encoder_attention_mask,  # B, L
-                        return_dict=False,
-                    )[0]
+                    target_pred = transformer(**target_pred_kwargs)[0]
 
             target, end_index = solver.euler_style_multiphase_pred(x_prev, target_pred, index, multiphase, True)
```
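One detail worth noting in the wan branches: they set `return_dict=True` yet still index the result with `[0]`. That works because diffusers model outputs derive from `BaseOutput`, which supports integer indexing via `to_tuple()`. A minimal demonstration, assuming a reasonably recent diffusers release (`DemoOutput` is a stand-in class, not from this codebase):

```python
from dataclasses import dataclass

import torch
from diffusers.utils import BaseOutput

@dataclass
class DemoOutput(BaseOutput):
    sample: torch.Tensor  # stand-in for the transformer's prediction

out = DemoOutput(sample=torch.zeros(1))
assert out[0] is out.sample         # int index falls back to to_tuple()
assert out["sample"] is out.sample  # str index hits the field by name
```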

```diff
@@ -319,7 +351,9 @@ def main(args):
     teacher_transformer.requires_grad_(False)
     if args.use_ema:
         ema_transformer.requires_grad_(False)
-    noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift)
+
+    noise_scheduler = FlowMatchEulerDiscreteScheduler()
+
     if args.scheduler_type == "pcm_linear_quadratic":
         linear_steps = int(noise_scheduler.config.num_train_timesteps * args.linear_range)
         sigmas = linear_quadratic_schedule(
```
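`FlowMatchEulerDiscreteScheduler` loses its `shift=args.shift` argument here, and `PCMFMScheduler` is newly imported at the top of the file; where the PCM scheduler is actually instantiated is not visible in the hunks shown. Purely as a hypothetical sketch of such wiring (the flag value and constructor arguments are assumptions, not confirmed by this diff):

```python
# Hypothetical wiring, not part of the visible hunks.
if args.scheduler_type == "pcm":                       # flag value assumed
    noise_scheduler = PCMFMScheduler()                 # real class; args unknown here
else:
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
```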

fastvideo/distill_adv.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@
 from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
 from fastvideo.distill.discriminator import Discriminator
 from fastvideo.distill.solver import EulerSolver, extract_into_tensor
-from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
+from fastvideo.utils.latents_utils import normalize_dit_input
 from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
 from fastvideo.utils.checkpoint import (resume_lora_optimizer, resume_training_generator_discriminator, save_checkpoint,
                                         save_lora_checkpoint)
```
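Both training scripts switch `normalize_dit_input` from the mochi-specific module to the shared `fastvideo.utils.latents_utils`, consistent with the helper now dispatching over model types rather than being mochi-only. Call sites need only the new import path; a sketch, with the call signature assumed from typical usage rather than shown in this diff:

```python
# Old import (mochi-specific), replaced in both distill.py and distill_adv.py:
# from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.utils.latents_utils import normalize_dit_input

model_input = normalize_dit_input(model_type, latents)  # signature assumed
```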
