66import dataclasses
77from contextlib import contextmanager
88from dataclasses import field
9+ from enum import Enum
910from typing import Any , Callable , List , Optional , Tuple
1011
12+ import torch
13+
1114from fastvideo .v1 .configs .models import DiTConfig , EncoderConfig , VAEConfig
1215from fastvideo .v1 .logger import init_logger
1316from fastvideo .v1 .utils import FlexibleArgumentParser , StoreBoolean
1417
1518logger = init_logger (__name__ )
1619
1720
21+ class Mode (Enum ):
22+ """Enumeration for FastVideo execution modes."""
23+ INFERENCE = "inference"
24+ TRAINING = "training"
25+ DISTILL = "distill"
26+
27+
1828def preprocess_text (prompt : str ) -> str :
1929 return prompt
2030
@@ -34,7 +44,7 @@ class FastVideoArgs:
3444 # Distributed executor backend
3545 distributed_executor_backend : str = "mp"
3646
37- mode : str = "inference" # Options: "inference", "training", "distill"
47+ mode : Mode = Mode . INFERENCE
3848
3949 # HuggingFace specific parameters
4050 trust_remote_code : bool = False
@@ -115,15 +125,15 @@ class FastVideoArgs:
115125
116126 @property
117127 def training_mode (self ) -> bool :
118- return self .mode == "training"
128+ return self .mode == Mode . TRAINING
119129
120130 @property
121131 def distill_mode (self ) -> bool :
122- return self .mode == "distill"
132+ return self .mode == Mode . DISTILL
123133
124134 @property
125135 def inference_mode (self ) -> bool :
126- return self .mode == "inference"
136+ return self .mode == Mode . INFERENCE
127137
128138 def __post_init__ (self ):
129139 pass
@@ -160,8 +170,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
160170 parser .add_argument (
161171 "--mode" ,
162172 type = str ,
163- default = FastVideoArgs .mode ,
164- choices = ["inference" , "training" , "distill" ],
173+ default = FastVideoArgs .mode . value ,
174+ choices = [mode . value for mode in Mode ],
165175 help = "The mode to use" ,
166176 )
167177
@@ -376,9 +386,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
376386
377387 @classmethod
378388 def from_cli_args (cls , args : argparse .Namespace ) -> "FastVideoArgs" :
379- args .tp_size = args .tensor_parallel_size
380- args .sp_size = args .sequence_parallel_size
381- args .flow_shift = getattr (args , "shift" , args .flow_shift )
389+ assert getattr (args , 'model_path' , None ) is not None , "model_path must be set in args"
390+ # Handle attribute mapping with safe getattr
391+ if hasattr (args , 'tensor_parallel_size' ):
392+ args .tp_size = args .tensor_parallel_size
393+ if hasattr (args , 'sequence_parallel_size' ):
394+ args .sp_size = args .sequence_parallel_size
395+ if hasattr (args , 'shift' ):
396+ args .flow_shift = args .shift
397+ elif hasattr (args , 'flow_shift' ):
398+ args .flow_shift = args .flow_shift
382399
383400 # Get all fields from the dataclass
384401 attrs = [attr .name for attr in dataclasses .fields (cls )]
@@ -397,6 +414,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
397414 kwargs [attr ] = args .data_parallel_shards
398415 elif attr == 'flow_shift' and hasattr (args , 'shift' ):
399416 kwargs [attr ] = args .shift
417+ elif attr == 'mode' :
418+ # Convert string mode to Mode enum
419+ mode_value = getattr (args , attr , None )
420+ if mode_value :
421+ if isinstance (mode_value , Mode ):
422+ kwargs [attr ] = mode_value
423+ else :
424+ kwargs [attr ] = Mode (mode_value )
425+ else :
426+ kwargs [attr ] = Mode .INFERENCE
427+ elif attr == 'device_str' :
428+ kwargs [attr ] = getattr (args , 'device' , None ) or "cuda" if torch .cuda .is_available () else "cpu"
400429 # Use getattr with default value from the dataclass for potentially missing attributes
401430 else :
402431 default_value = getattr (cls , attr , None )
@@ -595,9 +624,6 @@ class TrainingArgs(FastVideoArgs):
595624 # master_weight_type
596625 master_weight_type : str = ""
597626
598- # For fast checking in LoRA pipeline
599- training_mode : bool = True
600-
601627 @classmethod
602628 def from_cli_args (cls , args : argparse .Namespace ) -> "TrainingArgs" :
603629 # Get all fields from the dataclass
@@ -617,6 +643,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
617643 kwargs [attr ] = args .data_parallel_size
618644 elif attr == 'dp_shards' and hasattr (args , 'data_parallel_shards' ):
619645 kwargs [attr ] = args .data_parallel_shards
646+ elif attr == 'mode' :
647+ # Convert string mode to Mode enum
648+ mode_value = getattr (args , attr , None )
649+ if mode_value :
650+ if isinstance (mode_value , Mode ):
651+ kwargs [attr ] = mode_value
652+ else :
653+ kwargs [attr ] = Mode (mode_value )
654+ else :
655+ kwargs [attr ] = Mode .TRAINING # Default to training for TrainingArgs
656+ elif attr == 'device_str' :
657+ kwargs [attr ] = getattr (args , 'device' , None ) or "cuda" if torch .cuda .is_available () else "cpu"
620658 # Use getattr with default value from the dataclass for potentially missing attributes
621659 else :
622660 default_value = getattr (cls , attr , None )
0 commit comments