66import  dataclasses 
77from  contextlib  import  contextmanager 
88from  dataclasses  import  field 
9+ from  enum  import  Enum 
910from  typing  import  Any , Callable , List , Optional , Tuple 
1011
12+ import  torch 
13+ 
1114from  fastvideo .v1 .configs .models  import  DiTConfig , EncoderConfig , VAEConfig 
1215from  fastvideo .v1 .logger  import  init_logger 
1316from  fastvideo .v1 .utils  import  FlexibleArgumentParser , StoreBoolean 
1417
1518logger  =  init_logger (__name__ )
1619
1720
21+ class  Mode (Enum ):
22+     """Enumeration for FastVideo execution modes.""" 
23+     INFERENCE  =  "inference" 
24+     TRAINING  =  "training" 
25+     DISTILL  =  "distill" 
26+ 
27+ 
1828def  preprocess_text (prompt : str ) ->  str :
1929    return  prompt 
2030
@@ -34,7 +44,7 @@ class FastVideoArgs:
3444    # Distributed executor backend 
3545    distributed_executor_backend : str  =  "mp" 
3646
37-     mode : str  =  "inference"    # Options: "inference", "training", "distill" 
47+     mode : Mode  =  Mode . INFERENCE 
3848
3949    # HuggingFace specific parameters 
4050    trust_remote_code : bool  =  False 
@@ -115,15 +125,15 @@ class FastVideoArgs:
115125
116126    @property  
117127    def  training_mode (self ) ->  bool :
118-         return  self .mode  ==  "training" 
128+         return  self .mode  ==  Mode . TRAINING 
119129
120130    @property  
121131    def  distill_mode (self ) ->  bool :
122-         return  self .mode  ==  "distill" 
132+         return  self .mode  ==  Mode . DISTILL 
123133
124134    @property  
125135    def  inference_mode (self ) ->  bool :
126-         return  self .mode  ==  "inference" 
136+         return  self .mode  ==  Mode . INFERENCE 
127137
128138    def  __post_init__ (self ):
129139        pass 
@@ -160,8 +170,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
160170        parser .add_argument (
161171            "--mode" ,
162172            type = str ,
163-             default = FastVideoArgs .mode ,
164-             choices = ["inference" ,  "training" ,  "distill" ],
173+             default = FastVideoArgs .mode . value ,
174+             choices = [mode . value   for   mode   in   Mode ],
165175            help = "The mode to use" ,
166176        )
167177
@@ -376,9 +386,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
376386
377387    @classmethod  
378388    def  from_cli_args (cls , args : argparse .Namespace ) ->  "FastVideoArgs" :
379-         args .tp_size  =  args .tensor_parallel_size 
380-         args .sp_size  =  args .sequence_parallel_size 
381-         args .flow_shift  =  getattr (args , "shift" , args .flow_shift )
389+         assert  getattr (args , 'model_path' , None ) is  not None , "model_path must be set in args" 
390+         # Handle attribute mapping with safe getattr 
391+         if  hasattr (args , 'tensor_parallel_size' ):
392+             args .tp_size  =  args .tensor_parallel_size 
393+         if  hasattr (args , 'sequence_parallel_size' ):
394+             args .sp_size  =  args .sequence_parallel_size 
395+         if  hasattr (args , 'shift' ):
396+             args .flow_shift  =  args .shift 
397+         elif  hasattr (args , 'flow_shift' ):
398+             args .flow_shift  =  args .flow_shift 
382399
383400        # Get all fields from the dataclass 
384401        attrs  =  [attr .name  for  attr  in  dataclasses .fields (cls )]
@@ -397,6 +414,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
397414                kwargs [attr ] =  args .data_parallel_shards 
398415            elif  attr  ==  'flow_shift'  and  hasattr (args , 'shift' ):
399416                kwargs [attr ] =  args .shift 
417+             elif  attr  ==  'mode' :
418+                 # Convert string mode to Mode enum 
419+                 mode_value  =  getattr (args , attr , None )
420+                 if  mode_value :
421+                     if  isinstance (mode_value , Mode ):
422+                         kwargs [attr ] =  mode_value 
423+                     else :
424+                         kwargs [attr ] =  Mode (mode_value )
425+                 else :
426+                     kwargs [attr ] =  Mode .INFERENCE 
427+             elif  attr  ==  'device_str' :
428+                 kwargs [attr ] =  getattr (args , 'device' , None ) or  "cuda"  if  torch .cuda .is_available () else  "cpu" 
400429            # Use getattr with default value from the dataclass for potentially missing attributes 
401430            else :
402431                default_value  =  getattr (cls , attr , None )
@@ -595,9 +624,6 @@ class TrainingArgs(FastVideoArgs):
595624    # master_weight_type 
596625    master_weight_type : str  =  "" 
597626
598-     # For fast checking in LoRA pipeline 
599-     training_mode : bool  =  True 
600- 
601627    @classmethod  
602628    def  from_cli_args (cls , args : argparse .Namespace ) ->  "TrainingArgs" :
603629        # Get all fields from the dataclass 
@@ -617,6 +643,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
617643                kwargs [attr ] =  args .data_parallel_size 
618644            elif  attr  ==  'dp_shards'  and  hasattr (args , 'data_parallel_shards' ):
619645                kwargs [attr ] =  args .data_parallel_shards 
646+             elif  attr  ==  'mode' :
647+                 # Convert string mode to Mode enum 
648+                 mode_value  =  getattr (args , attr , None )
649+                 if  mode_value :
650+                     if  isinstance (mode_value , Mode ):
651+                         kwargs [attr ] =  mode_value 
652+                     else :
653+                         kwargs [attr ] =  Mode (mode_value )
654+                 else :
655+                     kwargs [attr ] =  Mode .TRAINING   # Default to training for TrainingArgs 
656+             elif  attr  ==  'device_str' :
657+                 kwargs [attr ] =  getattr (args , 'device' , None ) or  "cuda"  if  torch .cuda .is_available () else  "cpu" 
620658            # Use getattr with default value from the dataclass for potentially missing attributes 
621659            else :
622660                default_value  =  getattr (cls , attr , None )
0 commit comments