66import dataclasses
77from contextlib import contextmanager
88from dataclasses import field
9+ from enum import Enum
910from typing import Any , Callable , List , Optional , Tuple
1011
12+ import torch
13+
1114from fastvideo .v1 .configs .models import DiTConfig , EncoderConfig , VAEConfig
1215from fastvideo .v1 .logger import init_logger
1316from fastvideo .v1 .utils import FlexibleArgumentParser , StoreBoolean
1417
1518logger = init_logger (__name__ )
1619
1720
class Mode(str, Enum):
    """Enumeration for FastVideo execution modes.

    Members also subclass ``str`` so that each member compares equal to its
    plain string value (for example ``Mode.TRAINING == "training"``). This
    keeps code that still passes or compares the mode as a raw string working
    alongside code that uses the enum members directly.

    ``Mode("training")`` looks a member up by value and raises ``ValueError``
    for a string that is not one of the values below.
    """
    INFERENCE = "inference"
    TRAINING = "training"
    DISTILL = "distill"
26+
27+
def preprocess_text(prompt: str) -> str:
    """Return ``prompt`` exactly as given.

    No transformation is applied; the input string is handed back unchanged.
    """
    text = prompt
    return text
2030
@@ -34,7 +44,7 @@ class FastVideoArgs:
3444 # Distributed executor backend
3545 distributed_executor_backend : str = "mp"
3646
37- mode : str = "inference" # Options: "inference", "training", "distill"
47+ mode : Mode = Mode . INFERENCE
3848
3949 # HuggingFace specific parameters
4050 trust_remote_code : bool = False
@@ -111,15 +121,15 @@ class FastVideoArgs:
111121
@property
def training_mode(self) -> bool:
    """Report whether ``self.mode`` equals ``Mode.TRAINING``."""
    is_training = self.mode == Mode.TRAINING
    return is_training
115125
@property
def distill_mode(self) -> bool:
    """Report whether ``self.mode`` equals ``Mode.DISTILL``."""
    is_distill = self.mode == Mode.DISTILL
    return is_distill
119129
@property
def inference_mode(self) -> bool:
    """Report whether ``self.mode`` equals ``Mode.INFERENCE``."""
    is_inference = self.mode == Mode.INFERENCE
    return is_inference
123133
124134 def __post_init__ (self ):
125135 self .check_fastvideo_args ()
@@ -156,8 +166,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
156166 parser .add_argument (
157167 "--mode" ,
158168 type = str ,
159- default = FastVideoArgs .mode ,
160- choices = ["inference" , "training" , "distill" ],
169+ default = FastVideoArgs .mode . value ,
170+ choices = [mode . value for mode in Mode ],
161171 help = "The mode to use" ,
162172 )
163173
@@ -371,9 +381,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
371381
372382 @classmethod
373383 def from_cli_args (cls , args : argparse .Namespace ) -> "FastVideoArgs" :
374- args .tp_size = args .tensor_parallel_size
375- args .sp_size = args .sequence_parallel_size
376- args .flow_shift = getattr (args , "shift" , args .flow_shift )
384+ assert getattr (args , 'model_path' , None ) is not None , "model_path must be set in args"
385+ # Handle attribute mapping with safe getattr
386+ if hasattr (args , 'tensor_parallel_size' ):
387+ args .tp_size = args .tensor_parallel_size
388+ if hasattr (args , 'sequence_parallel_size' ):
389+ args .sp_size = args .sequence_parallel_size
390+ if hasattr (args , 'shift' ):
391+ args .flow_shift = args .shift
392+ elif hasattr (args , 'flow_shift' ):
393+ args .flow_shift = args .flow_shift
377394
378395 # Get all fields from the dataclass
379396 attrs = [attr .name for attr in dataclasses .fields (cls )]
@@ -388,6 +405,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
388405 kwargs [attr ] = args .sequence_parallel_size
389406 elif attr == 'flow_shift' and hasattr (args , 'shift' ):
390407 kwargs [attr ] = args .shift
408+ elif attr == 'mode' :
409+ # Convert string mode to Mode enum
410+ mode_value = getattr (args , attr , None )
411+ if mode_value :
412+ if isinstance (mode_value , Mode ):
413+ kwargs [attr ] = mode_value
414+ else :
415+ kwargs [attr ] = Mode (mode_value )
416+ else :
417+ kwargs [attr ] = Mode .INFERENCE
418+ elif attr == 'device_str' :
419+ kwargs [attr ] = getattr (args , 'device' , None ) or "cuda" if torch .cuda .is_available () else "cpu"
391420 # Use getattr with default value from the dataclass for potentially missing attributes
392421 else :
393422 default_value = getattr (cls , attr , None )
@@ -587,9 +616,6 @@ class TrainingArgs(FastVideoArgs):
587616 # master_weight_type
588617 master_weight_type : str = ""
589618
590- # For fast checking in LoRA pipeline
591- training_mode : bool = True
592-
593619 @classmethod
594620 def from_cli_args (cls , args : argparse .Namespace ) -> "TrainingArgs" :
595621 # Get all fields from the dataclass
@@ -605,6 +631,19 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
605631 kwargs [attr ] = args .sequence_parallel_size
606632 elif attr == 'flow_shift' and hasattr (args , 'shift' ):
607633 kwargs [attr ] = args .shift
634+ elif attr == 'mode' :
635+ # Convert string mode to Mode enum
636+ mode_value = getattr (args , attr , None )
637+ if mode_value :
638+ if isinstance (mode_value , Mode ):
639+ kwargs [attr ] = mode_value
640+ else :
641+ kwargs [attr ] = Mode (mode_value )
642+ else :
643+ kwargs [attr ] = Mode .TRAINING # Default to training for TrainingArgs
644+ elif attr == 'device_str' :
645+ kwargs [attr ] = getattr (args , 'device' , None ) or "cuda" if torch .cuda .is_available () else "cpu"
646+ # Use getattr with default value from the dataclass for potentially missing attributes
608647 else :
609648 default_value = getattr (cls , attr , None )
610649 if getattr (args , attr , default_value ) is not None :
0 commit comments