
Commit a87c7d7: pre-commit

1 parent: 2dea928

File tree

5 files changed: +78 -71 lines changed

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 16 additions & 7 deletions
@@ -9,6 +9,7 @@
 import os
 from abc import ABC, abstractmethod
 from copy import deepcopy
+from enum import Enum
 from typing import Any, Dict, List, Optional, Union, cast
 
 import torch
@@ -106,7 +107,7 @@ def __init__(self,
 
         self.initialize_pipeline(fastvideo_args)
 
-        if not fastvideo_args.training_mode:
+        if fastvideo_args.inference_mode:
             logger.info("Creating pipeline stages...")
             self.create_pipeline_stages(fastvideo_args)
 
@@ -119,7 +120,7 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
             "if log_validation is True, the pipeline must implement this method"
         )
 
-    def initialize_distillation_pipeline(self, fastvideo_args: FastVideoArgs):
+    def initialize_distillation_pipeline(self, training_args: TrainingArgs):
         raise NotImplementedError(
             "if distill_mode is True, the pipeline must implement this method")
 
@@ -162,29 +163,37 @@ def from_pretrained(cls,
         config_args = shallow_asdict(config)
         config_args.update(kwargs)
 
-        args.model_path = model_path
         # Handle both string mode and Mode enum values
-        mode_str = args.mode if isinstance(args.mode, str) else args.mode.value
+        mode_str: str | Enum = getattr(
+            args, 'mode', "inference") if args is not None else "inference"
+        if hasattr(mode_str, 'value'):
+            mode_str = mode_str.value
+        mode_str = str(mode_str)
 
         if mode_str == "inference":
-            fastvideo_args = FastVideoArgs.from_cli_args(args)
+            fastvideo_args = FastVideoArgs(model_path=model_path, **config_args)
+
+            fastvideo_args.model_path = model_path
             for key, value in config_args.items():
                 setattr(fastvideo_args, key, value)
-
         elif mode_str == "training" or mode_str == "distill":
             assert args is not None, "args must be provided for training mode"
             fastvideo_args = TrainingArgs.from_cli_args(args)
+            # TODO(will): fix this so that its not so ugly
+            fastvideo_args.model_path = model_path
             for key, value in config_args.items():
                 setattr(fastvideo_args, key, value)
 
             fastvideo_args.use_cpu_offload = False
-            # make sure we are in training mode
+            # make sure we are in training mode - note: inference_mode is read-only,
+            # so we don't set it directly here as it's determined by the mode
             # we hijack the precision to be the master weight type so that the
             # model is loaded with the correct precision. Subsequently we will
             # use FSDP2's MixedPrecisionPolicy to set the precision for the
             # fwd, bwd, and other operations' precision.
             # fastvideo_args.precision = fastvideo_args.master_weight_type
             assert fastvideo_args.master_weight_type == 'fp32', 'only fp32 is supported for training'
+            # assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training'
         else:
             raise ValueError(f"Invalid mode: {mode_str}")
 
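The reworked mode handling in `from_pretrained` now tolerates `args=None` and accepts both string and enum `mode` values. A minimal standalone sketch of that normalization (`_M` and the `SimpleNamespace` args object are stand-ins, not fastvideo's real types):

```python
# Sketch of the mode normalization introduced above; `_M` stands in for
# fastvideo's Mode enum and SimpleNamespace for the CLI args object.
from enum import Enum
from types import SimpleNamespace

class _M(Enum):
    DISTILL = "distill"

def normalize_mode(args) -> str:
    # Default to "inference" when args is None or lacks a `mode` attribute.
    mode = getattr(args, 'mode', "inference") if args is not None else "inference"
    if hasattr(mode, 'value'):      # unwrap enum members to their string value
        mode = mode.value
    return str(mode)

assert normalize_mode(None) == "inference"
assert normalize_mode(SimpleNamespace(mode="training")) == "training"
assert normalize_mode(SimpleNamespace(mode=_M.DISTILL)) == "distill"
```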

fastvideo/v1/training/distillation_pipeline.py

Lines changed: 17 additions & 15 deletions
@@ -1,6 +1,7 @@
 import gc
 import os
 from abc import ABC, abstractmethod
+from typing import List, Optional
 
 import imageio
 import numpy as np
@@ -26,7 +27,10 @@
 
 
 # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
-def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
+def linear_quadratic_schedule(
+        num_steps: int,
+        threshold_noise: float,
+        linear_steps: Optional[int] = None) -> List[float]:
     if linear_steps is None:
         linear_steps = num_steps // 2
     linear_sigma_schedule = [
@@ -48,7 +52,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
     return sigma_schedule
 
 
-def reshard_fsdp(model):
+def reshard_fsdp(model: torch.nn.Module) -> None:
     """Reshard FSDP model for EMA updates."""
     for m in FSDP.fsdp_modules(model):
         if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
@@ -60,6 +64,7 @@ class DistillationPipeline(ComposedPipelineBase, ABC):
     A pipeline for distillation training. All distillation pipelines should inherit from this class.
     """
     _required_config_modules = ["scheduler", "transformer"]
+    validation_pipeline: ComposedPipelineBase
 
     def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
         logger.info("Initializing distillation pipeline...")
@@ -104,6 +109,7 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
         assert noise_scheduler is not None
 
         # Initialize solver for distillation
+        sigmas: torch.Tensor | List[float] = []
        if fastvideo_args.scheduler_type == "pcm_linear_quadratic":
             linear_steps = int(noise_scheduler.config.num_train_timesteps *
                                fastvideo_args.linear_range)
@@ -112,10 +118,12 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
                 fastvideo_args.linear_quadratic_threshold,
                 linear_steps,
             )
-            sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
         else:
             sigmas = noise_scheduler.sigmas
 
+        if isinstance(sigmas, list):
+            sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
+
         self.solver = EulerSolver(
             sigmas.numpy(),
             noise_scheduler.config.num_train_timesteps,
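With the `torch.tensor` conversion hoisted out of the branch, both scheduler paths now hand the solver the same type. A self-contained sketch of the normalization (inputs are illustrative):

```python
# Either branch may leave `sigmas` as a List[float] or a torch.Tensor;
# the hoisted isinstance check converts the list case to float32 once.
from typing import List, Union
import torch

def to_sigma_tensor(sigmas: Union[torch.Tensor, List[float]]) -> torch.Tensor:
    if isinstance(sigmas, list):
        sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
    return sigmas

assert to_sigma_tensor([1.0, 0.5, 0.0]).dtype == torch.float32
assert to_sigma_tensor(torch.linspace(1.0, 0.0, 3)).shape == (3,)
```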
@@ -203,7 +211,8 @@ def distill_one_step(self, transformer, model_type, teacher_transformer,
         raise NotImplementedError(
             "Distillation pipeline must implement this method")
 
-    def log_validation(self, transformer, fastvideo_args, global_step):
+    @torch.no_grad()
+    def _log_validation(self, transformer, fastvideo_args, global_step):
         """Log validation results during training."""
         fastvideo_args.mode = Mode.INFERENCE
         fastvideo_args.use_cpu_offload = False
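The `@torch.no_grad()` decorator disables gradient tracking for the entire method, which is what lets the per-parameter `requires_grad` bookkeeping be deleted in the next hunk. A quick demonstration of the decorator form:

```python
# torch.no_grad() used as a decorator: everything computed inside the
# call is detached from autograd, with no manual requires_grad toggling.
import torch

@torch.no_grad()
def validate_step(x: torch.Tensor) -> torch.Tensor:
    return x * 2

out = validate_step(torch.ones(3, requires_grad=True))
assert not out.requires_grad
```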
@@ -220,8 +229,9 @@ def log_validation(self, transformer, fastvideo_args, global_step):
         validation_dataset = ParquetVideoTextDataset(
             fastvideo_args.validation_prompt_dir,
             batch_size=1,
-            cfg_rate=0,
-            num_latent_t=fastvideo_args.num_latent_t)
+            cfg_rate=fastvideo_args.cfg,
+            num_latent_t=fastvideo_args.num_latent_t,
+            validation=True)
 
         validation_dataloader = StatefulDataLoader(validation_dataset,
                                                    batch_size=1,
@@ -231,21 +241,13 @@ def log_validation(self, transformer, fastvideo_args, global_step):
                                                    pin_memory=True,
                                                    drop_last=False)
 
-        transformer.requires_grad_(False)
-        for p in transformer.parameters():
-            p.requires_grad = False
         transformer.eval()
 
-        # Add the transformer to the validation pipeline
-        self.validation_pipeline.add_module("transformer", transformer)
-        self.validation_pipeline.latent_preparation_stage.transformer = transformer
-        self.validation_pipeline.denoising_stage.transformer = transformer
-
         # Process validation prompts
         videos = []
         captions = []
         for _, embeddings, masks, infos in validation_dataloader:
-            logger.info(f"infos: {infos}")
+            logger.info("infos: %s", infos)
             caption = infos['caption']
             captions.append(caption)
             prompt_embeds = embeddings.to(fastvideo_args.device)

fastvideo/v1/training/training_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from fastvideo.v1.dataset import build_parquet_map_style_dataloader
 from fastvideo.v1.distributed import (get_sp_group, get_torch_device,
                                       get_world_group)
-from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
+from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.forward_context import set_forward_context
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines import ComposedPipelineBase
@@ -143,7 +143,7 @@ def train_one_step(self, transformer, model_type, optimizer, lr_scheduler,
     @torch.no_grad()
     def _log_validation(self, transformer, training_args, global_step) -> None:
         assert training_args is not None
-        training_args.inference_mode = True
+        training_args.mode = Mode.INFERENCE
         training_args.use_cpu_offload = False
         if not training_args.log_validation:
             return
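These call sites switch from mutating a boolean `inference_mode` to assigning the `Mode` enum; per the comment added in `composed_pipeline_base.py`, `inference_mode` is read-only and derived from `mode`. A hedged sketch of that relationship (the property and enum values are assumptions about `FastVideoArgs`, not its actual definition):

```python
# Assumed shape of the mode/inference_mode relationship after this change;
# the real FastVideoArgs may differ in detail.
from dataclasses import dataclass
from enum import Enum

class Mode(str, Enum):
    INFERENCE = "inference"
    TRAINING = "training"
    DISTILL = "distill"

@dataclass
class Args:
    mode: Mode = Mode.TRAINING

    @property
    def inference_mode(self) -> bool:  # derived from mode, hence no setter
        return self.mode == Mode.INFERENCE

args = Args()
args.mode = Mode.INFERENCE   # the new idiom: set the mode...
assert args.inference_mode   # ...and the boolean flag follows
```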

fastvideo/v1/training/wan_distillation_pipeline.py

Lines changed: 40 additions & 44 deletions
@@ -2,6 +2,7 @@
 import time
 from collections import deque
 from copy import deepcopy
+from typing import Dict
 
 import torch
 from tqdm.auto import tqdm
@@ -23,7 +24,8 @@
 logger = init_logger(__name__)
 
 
-def get_norm(model_pred, norms, gradient_accumulation_steps):
+def get_norm(model_pred: torch.Tensor, norms: Dict[str, float],
+             gradient_accumulation_steps: int) -> None:
     """Calculate and aggregate model prediction norms."""
     fro_norm = (
         torch.linalg.matrix_norm(model_pred, ord="fro") /  # codespell:ignore
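The `-> None` annotation documents that `get_norm` aggregates into the `norms` dict in place rather than returning a value. A self-contained sketch of that contract (only the first line of the body appears in this hunk; the aggregation shown here is assumed):

```python
# Hedged sketch: accumulate a scaled Frobenius norm into the caller's dict,
# matching the in-place contract that the -> None annotation documents.
from typing import Dict
import torch

def get_norm(model_pred: torch.Tensor, norms: Dict[str, float],
             gradient_accumulation_steps: int) -> None:
    fro_norm = (torch.linalg.matrix_norm(model_pred, ord="fro") /
                gradient_accumulation_steps)
    norms["fro"] = norms.get("fro", 0.0) + fro_norm.mean().item()

norms: Dict[str, float] = {}
get_norm(torch.randn(2, 8, 8), norms, gradient_accumulation_steps=4)
assert "fro" in norms
```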
@@ -66,7 +68,10 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs):
         args_copy.mode = Mode.INFERENCE
         args_copy.vae_config.load_encoder = False
         validation_pipeline = WanValidationPipeline.from_pretrained(
-            fastvideo_args.model_path, args=args_copy)
+            fastvideo_args.model_path,
+            args=None,
+            mode=Mode.INFERENCE,
+            loaded_modules={"transformer": self.get_module("transformer")})
 
         self.validation_pipeline = validation_pipeline
 
@@ -95,11 +100,7 @@ def distill_one_step(
         pred_decay_weight,
         pred_decay_type,
         hunyuan_teacher_disable_cfg,
-        weighting_scheme,
-        logit_mean,
-        logit_std,
-        mode_scale,
-    ):
+    ) -> tuple[float, float, Dict[str, float]]:
         """Perform one step of distillation training."""
         total_loss = 0.0
         optimizer.zero_grad()
@@ -170,17 +171,16 @@ def distill_one_step(
             noisy_model_input, model_pred, indices, multiphase)
 
         # Get teacher model prediction
-        with torch.no_grad():
-            with torch.autocast("cuda", dtype=torch.bfloat16):
-                with set_forward_context(current_timestep=timesteps,
-                                         attn_metadata=None):
-                    cond_teacher_output = teacher_transformer(
-                        noisy_model_input,
-                        encoder_hidden_states,
-                        timesteps,
-                        encoder_attention_mask,
-                        return_dict=False,
-                    )[0].float()
+        with torch.no_grad(), torch.autocast(
+                "cuda", dtype=torch.bfloat16), set_forward_context(
+                    current_timestep=timesteps, attn_metadata=None):
+            cond_teacher_output = teacher_transformer(
+                noisy_model_input,
+                encoder_hidden_states,
+                timesteps,
+                encoder_attention_mask,
+                return_dict=False,
+            )[0].float()
 
         if not_apply_cfg_solver:
             uncond_teacher_output = cond_teacher_output
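Collapsing the three nested `with` blocks into one statement is purely syntactic: managers listed in a single `with` are entered left to right and exited in reverse, exactly as if nested. A tiny demonstration with stand-in context managers:

```python
# A single multi-manager `with` is equivalent to nesting: enter order is
# left-to-right, exit order is the reverse.
from contextlib import contextmanager

@contextmanager
def tag(name):
    print("enter", name)
    yield
    print("exit", name)

with tag("no_grad"), tag("autocast"), tag("forward_ctx"):
    pass
# enter no_grad / enter autocast / enter forward_ctx
# exit forward_ctx / exit autocast / exit no_grad
```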
@@ -313,31 +313,30 @@ def forward(
         uncond_prompt_embed = self.uncond_prompt_embed
         uncond_prompt_mask = self.uncond_prompt_mask
 
-        # Train!
+        assert self.training_args.sp_size is not None
+        assert self.training_args.gradient_accumulation_steps is not None
         total_batch_size = (self.world_size *
                             self.training_args.gradient_accumulation_steps /
                             self.training_args.sp_size *
                             self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
-        logger.info(f"  Resume training from step {init_steps}")
-        logger.info(
-            f"  Instantaneous batch size per device = {self.training_args.train_batch_size}"
-        )
+        logger.info("  Resume training from step %s", init_steps)
+        logger.info("  Instantaneous batch size per device = %s",
+                    self.training_args.train_batch_size)
         logger.info(
-            f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
-        )
+            "  Total train batch size (w. data & sequence parallel, accumulation) = %s",
+            total_batch_size)
+        logger.info("  Gradient Accumulation steps = %s",
+                    self.training_args.gradient_accumulation_steps)
+        logger.info("  Total optimization steps = %s",
+                    self.training_args.max_train_steps)
         logger.info(
-            f"  Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
-        )
-        logger.info(
-            f"  Total optimization steps = {self.training_args.max_train_steps}"
-        )
-        logger.info(
-            f"  Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B"
-        )
-        logger.info(
-            f"  Master weight dtype: {self.transformer.parameters().__next__().dtype}"
-        )
+            "  Total training parameters per FSDP shard = %s B",
+            sum(p.numel()
+                for p in self.transformer.parameters() if p.requires_grad) /
+            1e9)
+        logger.info("  Master weight dtype: %s",
+                    self.transformer.parameters().__next__().dtype)
 
         # Potentially load in the weights and states from a previous save
         if self.training_args.resume_from_checkpoint:
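The logging rewrite swaps f-strings for %-style arguments, the standard lazy-logging idiom: formatting is deferred until the record is actually emitted, so disabled levels pay no interpolation cost. A minimal illustration:

```python
# %-style logger arguments defer string formatting to emission time;
# an f-string would format even when the level is disabled.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("distill")

total_batch_size = 64
logger.info("  Total train batch size = %s", total_batch_size)  # formatted lazily
logger.debug("  step = %s", 100)  # DEBUG disabled here: no formatting occurs
```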
@@ -352,13 +351,14 @@ def forward(
         )
 
         loader_iter = iter(train_dataloader)
-        step_times = deque(maxlen=100)
+        step_times: deque[float] = deque(maxlen=100)
 
         # Skip steps if resuming
         for i in range(init_steps):
             next(loader_iter)
 
-        def get_num_phases(multi_phased_distill_schedule, step):
+        def get_num_phases(multi_phased_distill_schedule: str,
+                           step: int) -> int:
             # step-phase,step-phase
             multi_phases = multi_phased_distill_schedule.split(",")
             phase = multi_phases[-1].split("-")[-1]
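Only the head of the newly annotated `get_num_phases` is visible in this hunk. A plausible completion under the `step-phase,step-phase` format named in the comment (the lookup loop is an assumption, not the committed code):

```python
# Hedged completion of get_num_phases; the loop body is assumed from the
# "step-phase,step-phase" comment. E.g. "0-4,1000-2": steps <= 0 use
# 4 phases, steps <= 1000 use 2, later steps fall back to the last entry.
def get_num_phases(multi_phased_distill_schedule: str, step: int) -> int:
    multi_phases = multi_phased_distill_schedule.split(",")
    phase = multi_phases[-1].split("-")[-1]        # default: last entry's phase
    for step_phase in multi_phases:
        phase_step, phase_value = step_phase.split("-")
        if step <= int(phase_step):                # first boundary not yet passed
            phase = phase_value
            break
    return int(phase)

assert get_num_phases("0-4,1000-2", 0) == 4
assert get_num_phases("0-4,1000-2", 500) == 2
```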
@@ -400,10 +400,6 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 self.training_args.pred_decay_weight,
                 self.training_args.pred_decay_type,
                 self.training_args.hunyuan_teacher_disable_cfg,
-                self.training_args.weighting_scheme,
-                self.training_args.logit_mean,
-                self.training_args.logit_std,
-                self.training_args.mode_scale,
             )
 
             step_time = time.perf_counter() - start_time
@@ -462,7 +458,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
             self.sp_group.barrier()
 
             if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
-                self.log_validation(self.transformer, self.training_args, step)
+                self._log_validation(self.transformer, self.training_args, step)
 
         # Final checkpoint
         if self.training_args.use_lora:
@@ -476,7 +472,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
         cleanup_dist_env_and_memory()
 
 
-def main(args):
+def main(args) -> None:
     logger.info("Starting distillation pipeline...")
 
     pipeline = WanDistillationPipeline.from_pretrained(

fastvideo/v1/training/wan_training_pipeline.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
 
 from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, get_sp_group,
                                       get_torch_device, get_world_group)
-from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
+from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.forward_context import set_forward_context
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.schedulers.scheduling_flow_unipc_multistep import (
@@ -53,12 +53,12 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
         logger.info("Initializing validation pipeline...")
         args_copy = deepcopy(training_args)
 
-        args_copy.inference_mode = True
+        args_copy.mode = Mode.INFERENCE
         args_copy.vae_config.load_encoder = False
         validation_pipeline = WanValidationPipeline.from_pretrained(
             training_args.model_path,
             args=None,
-            inference_mode=True,
+            mode=Mode.INFERENCE,
             loaded_modules={"transformer": self.get_module("transformer")},
             tp_size=training_args.tp_size,
             sp_size=training_args.sp_size,
