Skip to content

Commit 555b7a0

Browse files
committed
[Fix] adapt to main
1 parent 34523f8 commit 555b7a0

File tree

5 files changed, with 69 additions and 39 deletions.

fastvideo/v1/fastvideo_args.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,25 @@
66
import dataclasses
77
from contextlib import contextmanager
88
from dataclasses import field
9+
from enum import Enum
910
from typing import Any, Callable, List, Optional, Tuple
1011

12+
import torch
13+
1114
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
1215
from fastvideo.v1.logger import init_logger
1316
from fastvideo.v1.utils import FlexibleArgumentParser, StoreBoolean
1417

1518
logger = init_logger(__name__)
1619

1720

21+
class Mode(Enum):
    """Execution modes that FastVideo can run in.

    Every member's value is the lowercase string spelling of the mode, so a
    member can be looked up from that string, e.g. ``Mode("training")``.
    Iterating the enum yields the members in declaration order:
    inference, training, distill.
    """

    INFERENCE = "inference"
    TRAINING = "training"
    DISTILL = "distill"
26+
27+
1828
def preprocess_text(prompt: str) -> str:
    """Return ``prompt`` unchanged.

    This is a pass-through: no normalization or cleaning is applied to the
    text, and the same string object that was passed in is returned.
    """
    return prompt
2030

@@ -34,7 +44,7 @@ class FastVideoArgs:
3444
# Distributed executor backend
3545
distributed_executor_backend: str = "mp"
3646

37-
mode: str = "inference" # Options: "inference", "training", "distill"
47+
mode: Mode = Mode.INFERENCE
3848

3949
# HuggingFace specific parameters
4050
trust_remote_code: bool = False
@@ -115,15 +125,15 @@ class FastVideoArgs:
115125

116126
@property
117127
def training_mode(self) -> bool:
118-
return self.mode == "training"
128+
return self.mode == Mode.TRAINING
119129

120130
@property
121131
def distill_mode(self) -> bool:
122-
return self.mode == "distill"
132+
return self.mode == Mode.DISTILL
123133

124134
@property
125135
def inference_mode(self) -> bool:
126-
return self.mode == "inference"
136+
return self.mode == Mode.INFERENCE
127137

128138
def __post_init__(self):
129139
pass
@@ -160,8 +170,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
160170
parser.add_argument(
161171
"--mode",
162172
type=str,
163-
default=FastVideoArgs.mode,
164-
choices=["inference", "training", "distill"],
173+
default=FastVideoArgs.mode.value,
174+
choices=[mode.value for mode in Mode],
165175
help="The mode to use",
166176
)
167177

@@ -376,9 +386,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
376386

377387
@classmethod
378388
def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
379-
args.tp_size = args.tensor_parallel_size
380-
args.sp_size = args.sequence_parallel_size
381-
args.flow_shift = getattr(args, "shift", args.flow_shift)
389+
assert getattr(args, 'model_path', None) is not None, "model_path must be set in args"
390+
# Handle attribute mapping with safe getattr
391+
if hasattr(args, 'tensor_parallel_size'):
392+
args.tp_size = args.tensor_parallel_size
393+
if hasattr(args, 'sequence_parallel_size'):
394+
args.sp_size = args.sequence_parallel_size
395+
if hasattr(args, 'shift'):
396+
args.flow_shift = args.shift
397+
elif hasattr(args, 'flow_shift'):
398+
args.flow_shift = args.flow_shift
382399

383400
# Get all fields from the dataclass
384401
attrs = [attr.name for attr in dataclasses.fields(cls)]
@@ -397,6 +414,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
397414
kwargs[attr] = args.data_parallel_shards
398415
elif attr == 'flow_shift' and hasattr(args, 'shift'):
399416
kwargs[attr] = args.shift
417+
elif attr == 'mode':
418+
# Convert string mode to Mode enum
419+
mode_value = getattr(args, attr, None)
420+
if mode_value:
421+
if isinstance(mode_value, Mode):
422+
kwargs[attr] = mode_value
423+
else:
424+
kwargs[attr] = Mode(mode_value)
425+
else:
426+
kwargs[attr] = Mode.INFERENCE
427+
elif attr == 'device_str':
428+
kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
400429
# Use getattr with default value from the dataclass for potentially missing attributes
401430
else:
402431
default_value = getattr(cls, attr, None)
@@ -595,9 +624,6 @@ class TrainingArgs(FastVideoArgs):
595624
# master_weight_type
596625
master_weight_type: str = ""
597626

598-
# For fast checking in LoRA pipeline
599-
training_mode: bool = True
600-
601627
@classmethod
602628
def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
603629
# Get all fields from the dataclass
@@ -617,6 +643,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
617643
kwargs[attr] = args.data_parallel_size
618644
elif attr == 'dp_shards' and hasattr(args, 'data_parallel_shards'):
619645
kwargs[attr] = args.data_parallel_shards
646+
elif attr == 'mode':
647+
# Convert string mode to Mode enum
648+
mode_value = getattr(args, attr, None)
649+
if mode_value:
650+
if isinstance(mode_value, Mode):
651+
kwargs[attr] = mode_value
652+
else:
653+
kwargs[attr] = Mode(mode_value)
654+
else:
655+
kwargs[attr] = Mode.TRAINING # Default to training for TrainingArgs
656+
elif attr == 'device_str':
657+
kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu"
620658
# Use getattr with default value from the dataclass for potentially missing attributes
621659
else:
622660
default_value = getattr(cls, attr, None)

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from fastvideo.v1.distributed import (init_distributed_environment,
1919
initialize_model_parallel,
2020
model_parallel_is_initialized)
21-
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
21+
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
2222
from fastvideo.v1.logger import init_logger
2323
from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader
2424
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -94,7 +94,6 @@ def __init__(self,
9494
self.initialize_validation_pipeline(self.training_args)
9595
self.initialize_training_pipeline(self.training_args)
9696

97-
# TODO(jinzhe): discuss this
9897
if fastvideo_args.distill_mode:
9998
assert self.training_args is not None
10099
if self.training_args.log_validation:
@@ -159,38 +158,33 @@ def from_pretrained(cls,
159158
config_args = shallow_asdict(config)
160159
config_args.update(kwargs)
161160

162-
if args.mode == "inference":
163-
fastvideo_args = FastVideoArgs(model_path=model_path,
164-
device_str=device or "cuda" if
165-
torch.cuda.is_available() else "cpu",
166-
**config_args)
167-
168-
fastvideo_args.model_path = model_path
169-
fastvideo_args.device_str = device or "cuda" if torch.cuda.is_available(
170-
) else "cpu"
161+
args.model_path = model_path
162+
# Handle both string mode and Mode enum values
163+
mode_str = args.mode if isinstance(args.mode, str) else args.mode.value
164+
165+
if mode_str == "inference":
166+
fastvideo_args = FastVideoArgs.from_cli_args(args)
171167
for key, value in config_args.items():
172168
setattr(fastvideo_args, key, value)
173-
else:
169+
170+
elif mode_str == "training" or mode_str == "distill":
174171
assert args is not None, "args must be provided for training mode"
175172
fastvideo_args = TrainingArgs.from_cli_args(args)
176-
# TODO(will): fix this so that its not so ugly
177-
fastvideo_args.model_path = model_path
178-
fastvideo_args.device_str = device or "cuda" if torch.cuda.is_available(
179-
) else "cpu"
180173
for key, value in config_args.items():
181174
setattr(fastvideo_args, key, value)
182175

183176
fastvideo_args.num_gpus = int(os.environ.get("WORLD_SIZE", 1))
184177
fastvideo_args.use_cpu_offload = False
185178
# make sure we are in training mode
186-
fastvideo_args.mode = args.mode
187179
# we hijack the precision to be the master weight type so that the
188180
# model is loaded with the correct precision. Subsequently we will
189181
# use FSDP2's MixedPrecisionPolicy to set the precision for the
190182
# fwd, bwd, and other operations' precision.
191183
# fastvideo_args.precision = fastvideo_args.master_weight_type
192184
assert fastvideo_args.master_weight_type == 'fp32', 'only fp32 is supported for training'
193-
# assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training'
185+
else:
186+
raise ValueError(f"Invalid mode: {mode_str}")
187+
194188

195189
fastvideo_args.check_fastvideo_args()
196190

fastvideo/v1/training/distillation_pipeline.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from fastvideo.v1.configs.sample import SamplingParam
1818
from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset
1919
from fastvideo.v1.distributed import get_sp_group
20-
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
20+
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode
2121
from fastvideo.v1.logger import init_logger
2222
from fastvideo.v1.pipelines import ComposedPipelineBase
2323
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -150,8 +150,6 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
150150
train_dataset = ParquetVideoTextDataset(
151151
fastvideo_args.data_path,
152152
batch_size=fastvideo_args.train_batch_size,
153-
rank=self.rank,
154-
world_size=self.world_size,
155153
cfg_rate=fastvideo_args.cfg,
156154
num_latent_t=fastvideo_args.num_latent_t)
157155

@@ -203,7 +201,7 @@ def distill_one_step(self, transformer, model_type, teacher_transformer,
203201

204202
def log_validation(self, transformer, fastvideo_args, global_step):
205203
"""Log validation results during training."""
206-
fastvideo_args.mode = "inference"
204+
fastvideo_args.mode = Mode.INFERENCE
207205
fastvideo_args.use_cpu_offload = False
208206
if not fastvideo_args.log_validation:
209207
return
@@ -218,8 +216,6 @@ def log_validation(self, transformer, fastvideo_args, global_step):
218216
validation_dataset = ParquetVideoTextDataset(
219217
fastvideo_args.validation_prompt_dir,
220218
batch_size=1,
221-
rank=0,
222-
world_size=1,
223219
cfg_rate=0,
224220
num_latent_t=fastvideo_args.num_latent_t)
225221

@@ -324,7 +320,7 @@ def log_validation(self, transformer, fastvideo_args, global_step):
324320
wandb.log(logs, step=global_step)
325321

326322
# Re-enable gradients for training
327-
fastvideo_args.mode = "distill"
323+
fastvideo_args.mode = Mode.DISTILL
328324
transformer.requires_grad_(True)
329325
transformer.train()
330326

fastvideo/v1/training/wan_distillation_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from fastvideo.distill.solver import extract_into_tensor
1111
from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group
12-
from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
12+
from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs
1313
from fastvideo.v1.forward_context import set_forward_context
1414
from fastvideo.v1.logger import init_logger
1515
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
@@ -60,7 +60,7 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs):
6060
logger.info("Initializing validation pipeline...")
6161
args_copy = deepcopy(fastvideo_args)
6262

63-
args_copy.mode = "inference"
63+
args_copy.mode = Mode.INFERENCE
6464
args_copy.vae_config.load_encoder = False
6565
validation_pipeline = WanValidationPipeline.from_pretrained(
6666
fastvideo_args.model_path, args=args_copy)

scripts/distill/distill_v1.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ torchrun --nnodes 1 --nproc_per_node $num_gpus\
2121
--train_batch_size=1 \
2222
--num_latent_t 4 \
2323
--sp_size $num_gpus \
24+
--dp_size $num_gpus \
25+
--dp_shards $num_gpus \
2426
--train_sp_batch_size 1 \
2527
--dataloader_num_workers $num_gpus \
2628
--gradient_accumulation_steps=1 \

0 commit comments

Comments
 (0)