Skip to content

Commit 3a83130

Browse files
committed
[Fix] adapt to main
1 parent 6fa53cb commit 3a83130

File tree

5 files changed

+73
-33
lines changed

5 files changed

+73
-33
lines changed

fastvideo/v1/fastvideo_args.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,25 @@
66
import dataclasses
77
from contextlib import contextmanager
88
from dataclasses import field
9+
from enum import Enum
910
from typing import Any, Callable, List, Optional, Tuple
1011

12+
import torch
13+
1114
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
1215
from fastvideo.v1.logger import init_logger
1316
from fastvideo.v1.utils import FlexibleArgumentParser, StoreBoolean
1417

1518
logger = init_logger(__name__)
1619

1720

21+
class Mode(Enum):
22+
"""Enumeration for FastVideo execution modes."""
23+
INFERENCE = "inference"
24+
TRAINING = "training"
25+
DISTILL = "distill"
26+
27+
1828
def preprocess_text(prompt: str) -> str:
1929
return prompt
2030

@@ -34,7 +44,7 @@ class FastVideoArgs:
3444
# Distributed executor backend
3545
distributed_executor_backend: str = "mp"
3646

37-
mode: str = "inference" # Options: "inference", "training", "distill"
47+
mode: Mode = Mode.INFERENCE
3848

3949
# HuggingFace specific parameters
4050
trust_remote_code: bool = False
@@ -111,15 +121,15 @@ class FastVideoArgs:
111121

112122
@property
113123
def training_mode(self) -> bool:
114-
return self.mode == "training"
124+
return self.mode == Mode.TRAINING
115125

116126
@property
117127
def distill_mode(self) -> bool:
118-
return self.mode == "distill"
128+
return self.mode == Mode.DISTILL
119129

120130
@property
121131
def inference_mode(self) -> bool:
122-
return self.mode == "inference"
132+
return self.mode == Mode.INFERENCE
123133

124134
def __post_init__(self):
125135
self.check_fastvideo_args()
@@ -156,8 +166,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
156166
parser.add_argument(
157167
"--mode",
158168
type=str,
159-
default=FastVideoArgs.mode,
160-
choices=["inference", "training", "distill"],
169+
default=FastVideoArgs.mode.value,
170+
choices=[mode.value for mode in Mode],
161171
help="The mode to use",
162172
)
163173

@@ -371,9 +381,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
371381

372382
@classmethod
373383
def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
374-
args.tp_size = args.tensor_parallel_size
375-
args.sp_size = args.sequence_parallel_size
376-
args.flow_shift = getattr(args, "shift", args.flow_shift)
384+
assert getattr(args, 'model_path', None) is not None, "model_path must be set in args"
385+
# Handle attribute mapping with safe getattr
386+
if hasattr(args, 'tensor_parallel_size'):
387+
args.tp_size = args.tensor_parallel_size
388+
if hasattr(args, 'sequence_parallel_size'):
389+
args.sp_size = args.sequence_parallel_size
390+
if hasattr(args, 'shift'):
391+
args.flow_shift = args.shift
392+
elif hasattr(args, 'flow_shift'):
393+
args.flow_shift = args.flow_shift
377394

378395
# Get all fields from the dataclass
379396
attrs = [attr.name for attr in dataclasses.fields(cls)]
@@ -388,6 +405,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
388405
kwargs[attr] = args.sequence_parallel_size
389406
elif attr == 'flow_shift' and hasattr(args, 'shift'):
390407
kwargs[attr] = args.shift
408+
elif attr == 'mode':
409+
# Convert string mode to Mode enum
410+
mode_value = getattr(args, attr, None)
411+
if mode_value:
412+
if isinstance(mode_value, Mode):
413+
kwargs[attr] = mode_value
414+
else:
415+
kwargs[attr] = Mode(mode_value)
416+
else:
417+
kwargs[attr] = Mode.INFERENCE
418+
elif attr == 'device_str':
419+
kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
391420
# Use getattr with default value from the dataclass for potentially missing attributes
392421
else:
393422
default_value = getattr(cls, attr, None)
@@ -587,9 +616,6 @@ class TrainingArgs(FastVideoArgs):
587616
# master_weight_type
588617
master_weight_type: str = ""
589618

590-
# For fast checking in LoRA pipeline
591-
training_mode: bool = True
592-
593619
@classmethod
594620
def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
595621
# Get all fields from the dataclass
@@ -605,6 +631,19 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
605631
kwargs[attr] = args.sequence_parallel_size
606632
elif attr == 'flow_shift' and hasattr(args, 'shift'):
607633
kwargs[attr] = args.shift
634+
elif attr == 'mode':
635+
# Convert string mode to Mode enum
636+
mode_value = getattr(args, attr, None)
637+
if mode_value:
638+
if isinstance(mode_value, Mode):
639+
kwargs[attr] = mode_value
640+
else:
641+
kwargs[attr] = Mode(mode_value)
642+
else:
643+
kwargs[attr] = Mode.TRAINING # Default to training for TrainingArgs
644+
elif attr == 'device_str':
645+
kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
646+
# Use getattr with default value from the dataclass for potentially missing attributes
608647
else:
609648
default_value = getattr(cls, attr, None)
610649
if getattr(args, attr, default_value) is not None:

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from fastvideo.v1.distributed import (
1919
maybe_init_distributed_environment_and_model_parallel)
2020
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
21+
from fastvideo.v1.distributed import (init_distributed_environment,
22+
initialize_model_parallel,
23+
model_parallel_is_initialized)
24+
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
2125
from fastvideo.v1.logger import init_logger
2226
from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
2327
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -94,7 +98,6 @@ def __init__(self,
9498
self.initialize_validation_pipeline(self.training_args)
9599
self.initialize_training_pipeline(self.training_args)
96100

97-
# TODO(jinzhe): discuss this
98101
if fastvideo_args.distill_mode:
99102
assert self.training_args is not None
100103
if self.training_args.log_validation:
@@ -159,32 +162,32 @@ def from_pretrained(cls,
159162
config_args = shallow_asdict(config)
160163
config_args.update(kwargs)
161164

162-
if args.mode == "inference":
163-
fastvideo_args = FastVideoArgs(model_path=model_path,
164-
device_str=device or "cuda" if
165-
torch.cuda.is_available() else "cpu",
166-
**config_args)
167-
fastvideo_args.model_path = model_path
165+
args.model_path = model_path
166+
# Handle both string mode and Mode enum values
167+
mode_str = args.mode if isinstance(args.mode, str) else args.mode.value
168+
169+
if mode_str == "inference":
170+
fastvideo_args = FastVideoArgs.from_cli_args(args)
168171
for key, value in config_args.items():
169172
setattr(fastvideo_args, key, value)
170-
else:
173+
174+
elif mode_str == "training" or mode_str == "distill":
171175
assert args is not None, "args must be provided for training mode"
172176
fastvideo_args = TrainingArgs.from_cli_args(args)
173-
# TODO(will): fix this so that its not so ugly
174-
fastvideo_args.model_path = model_path
175177
for key, value in config_args.items():
176178
setattr(fastvideo_args, key, value)
177179

178180
fastvideo_args.use_cpu_offload = False
179181
# make sure we are in training mode
180-
fastvideo_args.mode = args.mode
181182
# we hijack the precision to be the master weight type so that the
182183
# model is loaded with the correct precision. Subsequently we will
183184
# use FSDP2's MixedPrecisionPolicy to set the precision for the
184185
# fwd, bwd, and other operations' precision.
185186
# fastvideo_args.precision = fastvideo_args.master_weight_type
186187
assert fastvideo_args.master_weight_type == 'fp32', 'only fp32 is supported for training'
187-
# assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training'
188+
else:
189+
raise ValueError(f"Invalid mode: {mode_str}")
190+
188191

189192
logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
190193

fastvideo/v1/training/distillation_pipeline.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from fastvideo.v1.configs.sample import SamplingParam
1818
from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset
1919
from fastvideo.v1.distributed import get_sp_group
20-
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
20+
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
2121
from fastvideo.v1.logger import init_logger
2222
from fastvideo.v1.pipelines import ComposedPipelineBase
2323
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -150,8 +150,6 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
150150
train_dataset = ParquetVideoTextDataset(
151151
fastvideo_args.data_path,
152152
batch_size=fastvideo_args.train_batch_size,
153-
rank=self.rank,
154-
world_size=self.world_size,
155153
cfg_rate=fastvideo_args.cfg,
156154
num_latent_t=fastvideo_args.num_latent_t)
157155

@@ -203,7 +201,7 @@ def distill_one_step(self, transformer, model_type, teacher_transformer,
203201

204202
def log_validation(self, transformer, fastvideo_args, global_step):
205203
"""Log validation results during training."""
206-
fastvideo_args.mode = "inference"
204+
fastvideo_args.mode = Mode.INFERENCE
207205
fastvideo_args.use_cpu_offload = False
208206
if not fastvideo_args.log_validation:
209207
return
@@ -218,8 +216,6 @@ def log_validation(self, transformer, fastvideo_args, global_step):
218216
validation_dataset = ParquetVideoTextDataset(
219217
fastvideo_args.validation_prompt_dir,
220218
batch_size=1,
221-
rank=0,
222-
world_size=1,
223219
cfg_rate=0,
224220
num_latent_t=fastvideo_args.num_latent_t)
225221

@@ -324,7 +320,7 @@ def log_validation(self, transformer, fastvideo_args, global_step):
324320
wandb.log(logs, step=global_step)
325321

326322
# Re-enable gradients for training
327-
fastvideo_args.mode = "distill"
323+
fastvideo_args.mode = Mode.DISTILL
328324
transformer.requires_grad_(True)
329325
transformer.train()
330326

fastvideo/v1/training/wan_distillation_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from fastvideo.distill.solver import extract_into_tensor
1111
from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group
12-
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
12+
from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
1313
from fastvideo.v1.forward_context import set_forward_context
1414
from fastvideo.v1.logger import init_logger
1515
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -60,7 +60,7 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs):
6060
logger.info("Initializing validation pipeline...")
6161
args_copy = deepcopy(fastvideo_args)
6262

63-
args_copy.mode = "inference"
63+
args_copy.mode = Mode.INFERENCE
6464
args_copy.vae_config.load_encoder = False
6565
validation_pipeline = WanValidationPipeline.from_pretrained(
6666
fastvideo_args.model_path, args=args_copy)

scripts/distill/distill_v1.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ torchrun --nnodes 1 --nproc_per_node $num_gpus\
2121
--train_batch_size=1 \
2222
--num_latent_t 4 \
2323
--sp_size $num_gpus \
24+
--dp_size $num_gpus \
25+
--dp_shards $num_gpus \
2426
--train_sp_batch_size 1 \
2527
--dataloader_num_workers $num_gpus \
2628
--gradient_accumulation_steps=1 \

0 commit comments

Comments
 (0)