
Commit f797e79

[Fmt] fmt code
1 parent 2d66b3b commit f797e79

File tree: 5 files changed, +64 -43 lines changed

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 0 additions & 1 deletion
@@ -192,7 +192,6 @@ def get_validation_negative_prompt(
             lat = lat[:, self.rank_in_sp_group, :, :, :]
         return lat, emb, mask, info
 
-
     def __len__(self):
         if self.local_indices is None:
             try:

fastvideo/v1/fastvideo_args.py

Lines changed: 10 additions & 4 deletions
@@ -386,7 +386,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
-        assert getattr(args, 'model_path', None) is not None, "model_path must be set in args"
+        assert getattr(args, 'model_path',
+                       None) is not None, "model_path must be set in args"
         # Handle attribute mapping with safe getattr
         if hasattr(args, 'tensor_parallel_size'):
             args.tp_size = args.tensor_parallel_size
@@ -425,7 +426,9 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
                 else:
                     kwargs[attr] = Mode.INFERENCE
             elif attr == 'device_str':
-                kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
+                kwargs[attr] = getattr(
+                    args, 'device',
+                    None) or "cuda" if torch.cuda.is_available() else "cpu"
             # Use getattr with default value from the dataclass for potentially missing attributes
             else:
                 default_value = getattr(cls, attr, None)
@@ -652,9 +655,12 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
                     else:
                         kwargs[attr] = Mode(mode_value)
                 else:
-                    kwargs[attr] = Mode.TRAINING  # Default to training for TrainingArgs
+                    kwargs[
+                        attr] = Mode.TRAINING  # Default to training for TrainingArgs
             elif attr == 'device_str':
-                kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
+                kwargs[attr] = getattr(
+                    args, 'device',
+                    None) or "cuda" if torch.cuda.is_available() else "cpu"
             # Use getattr with default value from the dataclass for potentially missing attributes
             else:
                 default_value = getattr(cls, attr, None)
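
Note on the two device_str hunks: they only re-wrap an existing expression, but the new line breaks make the grouping easy to misread. In Python, `or` binds tighter than a conditional expression, so the right-hand side is `(getattr(args, 'device', None) or "cuda") if torch.cuda.is_available() else "cpu"`; the explicitly passed device is only honored when CUDA is available. A minimal standalone sketch with a made-up `args` namespace (not the real CLI parser output):

```python
from types import SimpleNamespace

import torch

# Made-up namespace standing in for the parsed CLI args.
args = SimpleNamespace(device="mps")

# Same grouping as the diff, despite the wrap:
# (getattr(args, 'device', None) or "cuda") if torch.cuda.is_available() else "cpu"
device_str = getattr(args, 'device',
                     None) or "cuda" if torch.cuda.is_available() else "cpu"

print(device_str)  # "mps" on a CUDA machine, "cpu" otherwise
```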

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,7 @@
 from fastvideo.v1.distributed import (init_distributed_environment,
                                       initialize_model_parallel,
                                       model_parallel_is_initialized)
-from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
+from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -161,7 +161,7 @@ def from_pretrained(cls,
         args.model_path = model_path
         # Handle both string mode and Mode enum values
         mode_str = args.mode if isinstance(args.mode, str) else args.mode.value
-
+
         if mode_str == "inference":
             fastvideo_args = FastVideoArgs.from_cli_args(args)
             for key, value in config_args.items():
@@ -185,7 +185,6 @@ def from_pretrained(cls,
         else:
             raise ValueError(f"Invalid mode: {mode_str}")
 
-
         fastvideo_args.check_fastvideo_args()
 
         logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
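
The context lines around the second hunk normalize `args.mode` before dispatching to `FastVideoArgs.from_cli_args` or `TrainingArgs.from_cli_args`. A minimal sketch of that string/enum normalization, using a simplified stand-in for the `Mode` enum (the real one lives in `fastvideo.v1.fastvideo_args`; the member names and values here are assumed):

```python
from enum import Enum


class Mode(Enum):
    # Simplified stand-in; member values assumed to match the mode strings.
    INFERENCE = "inference"
    TRAINING = "training"


def normalize_mode(mode) -> str:
    # Handle both string mode and Mode enum values, as from_pretrained does.
    return mode if isinstance(mode, str) else mode.value


assert normalize_mode("inference") == "inference"
assert normalize_mode(Mode.TRAINING) == "training"
```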

fastvideo/v1/training/distillation_pipeline.py

Lines changed: 9 additions & 7 deletions
@@ -6,18 +6,18 @@
 import numpy as np
 import torch
 import torchvision
-import wandb
 from diffusers.optimization import get_scheduler
 from einops import rearrange
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
 from torchdata.stateful_dataloader import StatefulDataLoader
 
+import wandb
 from fastvideo.distill.solver import EulerSolver
 from fastvideo.v1.configs.sample import SamplingParam
 from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset
 from fastvideo.v1.distributed import get_sp_group
-from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
+from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines import ComposedPipelineBase
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -54,6 +54,7 @@ def reshard_fsdp(model):
         if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
             torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True)
 
+
 class DistillationPipeline(ComposedPipelineBase, ABC):
     """
     A pipeline for distillation training. All distillation pipelines should inherit from this class.
@@ -77,10 +78,12 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
 
         # Initialize teacher model without deepcopy to avoid FSDP issues
         logger.info("Creating teacher model...")
-        from fastvideo.v1.models.loader.component_loader import TransformerLoader
+        from fastvideo.v1.models.loader.component_loader import (
+            TransformerLoader)
         teacher_loader = TransformerLoader()
         transformer_path = os.path.join(self.model_path, "transformer")
-        self.teacher_transformer = teacher_loader.load(transformer_path, "", fastvideo_args)
+        self.teacher_transformer = teacher_loader.load(transformer_path, "",
+                                                       fastvideo_args)
         self.teacher_transformer.requires_grad_(False)
         self.teacher_transformer.eval()
         logger.info("Teacher model initialized")
@@ -89,7 +92,8 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
         if fastvideo_args.use_ema:
             logger.info("Creating EMA model...")
             ema_loader = TransformerLoader()
-            self.ema_transformer = ema_loader.load(transformer_path, "", fastvideo_args)
+            self.ema_transformer = ema_loader.load(transformer_path, "",
+                                                   fastvideo_args)
             self.ema_transformer.requires_grad_(False)
             self.ema_transformer.eval()
             logger.info("EMA model initialized")
@@ -326,5 +330,3 @@ def log_validation(self, transformer, fastvideo_args, global_step):
 
         gc.collect()
         torch.cuda.empty_cache()
-
-
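
The teacher and EMA hunks only re-wrap calls to `TransformerLoader.load`, but the surrounding lines show the usual frozen-teacher setup for distillation: load a second copy of the model, disable gradients, and keep it in eval mode. A generic sketch of that pattern in plain PyTorch, using toy `nn.Linear` modules rather than the FastVideo loader API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the student and teacher transformers.
student = nn.Linear(16, 16)
teacher = nn.Linear(16, 16)
teacher.load_state_dict(student.state_dict())  # start teacher from the same weights

# Same freezing pattern as initialize_distillation_pipeline: no grads, eval mode.
teacher.requires_grad_(False)
teacher.eval()

x = torch.randn(2, 16)
with torch.no_grad():
    target = teacher(x)              # teacher only provides targets
loss = F.mse_loss(student(x), target)
loss.backward()                      # gradients flow into the student only
```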

fastvideo/v1/training/wan_distillation_pipeline.py

Lines changed: 43 additions & 28 deletions
@@ -4,23 +4,25 @@
 from copy import deepcopy
 
 import torch
-import wandb
 from tqdm.auto import tqdm
 
+import wandb
 from fastvideo.distill.solver import extract_into_tensor
 from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group
 from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
 from fastvideo.v1.forward_context import set_forward_context
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
-from fastvideo.v1.training.training_utils import (
-    clip_grad_norm_while_handling_failing_dtensor_cases,
-    save_checkpoint, normalize_dit_input)
 from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline
-from fastvideo.v1.training.distillation_pipeline import DistillationPipeline, reshard_fsdp
+from fastvideo.v1.training.distillation_pipeline import (DistillationPipeline,
+                                                         reshard_fsdp)
+from fastvideo.v1.training.training_utils import (
+    clip_grad_norm_while_handling_failing_dtensor_cases, normalize_dit_input,
+    save_checkpoint)
 
 logger = init_logger(__name__)
 
+
 def get_norm(model_pred, norms, gradient_accumulation_steps):
     """Calculate and aggregate model prediction norms."""
     fro_norm = (
@@ -44,6 +46,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps):
     norms["absolute mean"] += absolute_mean.item()
     norms["absolute max"] += absolute_max.item()
 
+
 class WanDistillationPipeline(DistillationPipeline):
     """
     A distillation pipeline for Wan.
@@ -124,15 +127,14 @@ def distill_one_step(
         noise = torch.randn_like(latents)
 
         indices = torch.randint(0,
-                                 num_euler_timesteps, (batch_size, ),
-                                 device=latents.device).long()
+                                num_euler_timesteps, (batch_size, ),
+                                device=latents.device).long()
 
         if sp_size > 1:
             self.sp_group.broadcast(indices, src=0)
 
         # Add noise according to flow matching
-        sigmas = extract_into_tensor(solver.sigmas, indices,
-                                     latents.shape)
+        sigmas = extract_into_tensor(solver.sigmas, indices, latents.shape)
         sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices,
                                           latents.shape)
 
@@ -186,16 +188,23 @@ def distill_one_step(
             # Get teacher model prediction on unconditional embedding
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 input_kwargs = {
-                    "hidden_states": noisy_model_input,
-                    "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand(
-                        batch_size, -1, -1),
-                    "timestep": timesteps,
-                    "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1),
-                    "return_dict": False,
+                    "hidden_states":
+                    noisy_model_input,
+                    "encoder_hidden_states":
+                    uncond_prompt_embed.unsqueeze(0).expand(
+                        batch_size, -1, -1),
+                    "timestep":
+                    timesteps,
+                    "encoder_attention_mask":
+                    uncond_prompt_mask.unsqueeze(0).expand(
+                        batch_size, -1),
+                    "return_dict":
+                    False,
                 }
                 with set_forward_context(current_timestep=timesteps,
                                          attn_metadata=None):
-                    uncond_teacher_output = teacher_transformer(**input_kwargs)[0]
+                    uncond_teacher_output = teacher_transformer(
+                        **input_kwargs)[0]
                 teacher_output = uncond_teacher_output + distill_cfg * (
                     cond_teacher_output - uncond_teacher_output)
                 x_prev = solver.euler_step(noisy_model_input, teacher_output,
@@ -305,19 +314,24 @@ def forward(
         uncond_prompt_mask = self.uncond_prompt_mask
 
         # Train!
-        total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
-                            self.training_args.sp_size * self.training_args.train_sp_batch_size)
+        total_batch_size = (self.world_size *
+                            self.training_args.gradient_accumulation_steps /
+                            self.training_args.sp_size *
+                            self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
         logger.info(f" Resume training from step {init_steps}")
         logger.info(
-            f" Instantaneous batch size per device = {self.training_args.train_batch_size}")
+            f" Instantaneous batch size per device = {self.training_args.train_batch_size}"
+        )
         logger.info(
             f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
         )
         logger.info(
             f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
         )
-        logger.info(f" Total optimization steps = {self.training_args.max_train_steps}")
+        logger.info(
+            f" Total optimization steps = {self.training_args.max_train_steps}"
+        )
         logger.info(
             f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B"
         )
@@ -354,12 +368,13 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 return int(phase)
             return int(phase)
 
-        for step in range(init_steps + 1, self.training_args.max_train_steps + 1):
+        for step in range(init_steps + 1,
+                          self.training_args.max_train_steps + 1):
             start_time = time.perf_counter()
 
             assert self.training_args.multi_phased_distill_schedule is not None
-            num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule,
-                                        step)
+            num_phases = get_num_phases(
+                self.training_args.multi_phased_distill_schedule, step)
             try:
                 loss, grad_norm, pred_norm = self.distill_one_step(
                     self.transformer,
@@ -407,7 +422,6 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 step -= 1
                 continue
 
-
             if self.rank <= 0:
                 wandb.log(
                     {
@@ -441,10 +455,10 @@ def get_num_phases(multi_phased_distill_schedule, step):
                 else:
                     if self.training_args.use_ema:
                         save_checkpoint(self.ema_transformer, self.rank,
-                                         self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                     else:
                         save_checkpoint(self.transformer, self.rank,
-                                         self.training_args.output_dir, step)
+                                        self.training_args.output_dir, step)
                 self.sp_group.barrier()
 
             if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
@@ -454,8 +468,9 @@ def get_num_phases(multi_phased_distill_schedule, step):
         if self.training_args.use_lora:
             raise NotImplementedError("LoRA is not supported now")
         else:
-            save_checkpoint(self.transformer, self.rank, self.training_args.output_dir,
-                            self.training_args.max_train_steps)
+            save_checkpoint(self.transformer, self.rank,
+                            self.training_args.output_dir,
+                            self.training_args.max_train_steps)
 
         if get_sp_group():
             cleanup_dist_env_and_memory()
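
The `total_batch_size` hunk is likewise a pure re-wrap; splitting the expression over four lines does not change that `*` and `/` evaluate left to right, i.e. `((world_size * gradient_accumulation_steps) / sp_size) * train_sp_batch_size`. A short worked example with made-up values:

```python
# Made-up values; the real ones come from self.world_size and self.training_args.
world_size = 8
gradient_accumulation_steps = 4
sp_size = 2
train_sp_batch_size = 1

# Same grouping as the diff: ((8 * 4) / 2) * 1
total_batch_size = (world_size * gradient_accumulation_steps /
                    sp_size * train_sp_batch_size)

print(total_batch_size)  # 16.0, a float because of true division
```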
