@@ -100,6 +100,10 @@ class FastVideoArgs:
     device_str: Optional[str] = None
     device = None
 
+    @property
+    def training_mode(self) -> bool:
+        return not self.inference_mode
+
     def __post_init__(self):
         pass
 
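A quick sketch of the new property (hedged: assumes `inference_mode` defaults to `True` and that all other `FastVideoArgs` fields have defaults):

    args = FastVideoArgs(inference_mode=False)
    assert args.training_mode  # always the complement of inference_mode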
@@ -132,6 +136,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         help="The distributed executor backend to use",
     )
 
+    parser.add_argument(
+        "--inference-mode",
+        action=StoreBoolean,
+        default=FastVideoArgs.inference_mode,
+        help="Whether to use inference mode",
+    )
+
     # HuggingFace specific parameters
     parser.add_argument(
         "--trust-remote-code",
@@ -423,3 +434,333 @@ def get_current_fastvideo_args() -> FastVideoArgs:
         # TODO(will): may need to handle this for CI.
         raise ValueError("Current fastvideo args is not set.")
     return _current_fastvideo_args
+
+
+@dataclasses.dataclass
+class TrainingArgs(FastVideoArgs):
+    """
+    Training arguments. Inherits from FastVideoArgs and adds training-specific
+    arguments. If there are any conflicts, the training arguments will take
+    precedence.
+    """
+    data_path: str = ""
+    dataloader_num_workers: int = 0
+    num_height: int = 0
+    num_width: int = 0
+    num_frames: int = 0
+
+    train_batch_size: int = 0
+    num_latent_t: int = 0
+    group_frame: bool = False
+    group_resolution: bool = False
+
+    # text encoder & vae & diffusion model
+    pretrained_model_name_or_path: str = ""
+    dit_model_name_or_path: str = ""
+    cache_dir: str = ""
+
+    # diffusion setting
+    ema_decay: float = 0.0
+    ema_start_step: int = 0
+    cfg: float = 0.0
+    precondition_outputs: bool = False
+
+    # validation & logs
+    validation_prompt_dir: str = ""
+    validation_sampling_steps: str = ""
+    validation_guidance_scale: str = ""
+    validation_steps: float = 0.0
+    log_validation: bool = False
+    tracker_project_name: str = ""
+    # seed: int
+
+    # output
+    output_dir: str = ""
+    checkpoints_total_limit: int = 0
+    checkpointing_steps: int = 0
+    resume_from_checkpoint: str = ""  # backs the --resume-from-checkpoint flag
+    logging_dir: str = ""
+
+    # optimizer & scheduler
+    num_train_epochs: int = 0
+    max_train_steps: int = 0
+    gradient_accumulation_steps: int = 0
+    learning_rate: float = 0.0
+    scale_lr: bool = False
+    lr_scheduler: str = ""
+    lr_warmup_steps: int = 0
+    max_grad_norm: float = 0.0
+    gradient_checkpointing: bool = False
+    selective_checkpointing: float = 0.0
+    allow_tf32: bool = False
+    mixed_precision: str = ""
+    train_sp_batch_size: int = 0
+    fsdp_sharding_strategy: str = ""
+
+    weighting_scheme: str = ""
+    logit_mean: float = 0.0
+    logit_std: float = 1.0
+    mode_scale: float = 0.0
+
+    num_euler_timesteps: int = 0
+    lr_num_cycles: int = 0
+    lr_power: float = 0.0
+    not_apply_cfg_solver: bool = False
+    distill_cfg: float = 0.0
+    scheduler_type: str = ""
+    linear_quadratic_threshold: float = 0.0
+    linear_range: float = 0.0
+    weight_decay: float = 0.0
+    use_ema: bool = False
+    multi_phased_distill_schedule: str = ""
+    pred_decay_weight: float = 0.0
+    pred_decay_type: str = ""
+    hunyuan_teacher_disable_cfg: bool = False
+
+    # master_weight_type
+    master_weight_type: str = ""
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
+        # Get all fields from the dataclass
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+
+        # Create a dictionary of attribute values, with defaults for missing
+        # attributes
+        kwargs = {}
+        for attr in attrs:
+            # Handle renamed attributes or those with multiple CLI names
+            if attr == 'tp_size' and hasattr(args, 'tensor_parallel_size'):
+                kwargs[attr] = args.tensor_parallel_size
+            elif attr == 'sp_size' and hasattr(args, 'sequence_parallel_size'):
+                kwargs[attr] = args.sequence_parallel_size
+            elif attr == 'flow_shift' and hasattr(args, 'shift'):
+                kwargs[attr] = args.shift
+            else:
+                # Use getattr with a default from the dataclass for
+                # potentially missing attributes
+                default_value = getattr(cls, attr, None)
+                kwargs[attr] = getattr(args, attr, default_value)
+
+        return cls(**kwargs)
+
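+    # A minimal sketch of the renaming above (hedged: `tp_size` and
+    # `flow_shift` are assumed to be fields inherited from FastVideoArgs):
+    #
+    #     ns = argparse.Namespace(tensor_parallel_size=2, shift=7.0)
+    #     targs = TrainingArgs.from_cli_args(ns)
+    #     assert targs.tp_size == 2 and targs.flow_shift == 7.0
+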
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        parser.add_argument("--data-path",
+                            type=str,
+                            required=True,
+                            help="Path to parquet files")
+        parser.add_argument("--dataloader-num-workers",
+                            type=int,
+                            required=True,
+                            help="Number of workers for the dataloader")
+        parser.add_argument("--num-height",
+                            type=int,
+                            required=True,
+                            help="Video height in pixels")
+        parser.add_argument("--num-width",
+                            type=int,
+                            required=True,
+                            help="Video width in pixels")
+        parser.add_argument("--num-frames",
+                            type=int,
+                            required=True,
+                            help="Number of video frames")
+
+        # Training batch and model configuration
+        parser.add_argument("--train-batch-size",
+                            type=int,
+                            required=True,
+                            help="Training batch size")
+        parser.add_argument("--num-latent-t",
+                            type=int,
+                            required=True,
+                            help="Number of latent frames along the temporal axis")
+        parser.add_argument("--group-frame",
+                            action=StoreBoolean,
+                            help="Whether to group frames during training")
+        parser.add_argument("--group-resolution",
+                            action=StoreBoolean,
+                            help="Whether to group resolutions during training")
+
+        # Model paths
+        parser.add_argument("--pretrained-model-name-or-path",
+                            type=str,
+                            required=True,
+                            help="Path to pretrained model or model name")
+        parser.add_argument("--dit-model-name-or-path",
+                            type=str,
+                            required=False,
+                            help="Path to DiT model or model name")
+        parser.add_argument("--cache-dir",
+                            type=str,
+                            help="Directory to cache models")
+
+        # Diffusion settings
+        parser.add_argument("--ema-decay",
+                            type=float,
+                            default=0.999,
+                            help="EMA decay rate")
+        parser.add_argument("--ema-start-step",
+                            type=int,
+                            default=0,
+                            help="Step to start EMA")
+        parser.add_argument("--cfg",
+                            type=float,
+                            help="Classifier-free guidance scale")
+        parser.add_argument(
+            "--precondition-outputs",
+            action=StoreBoolean,
+            help="Whether to precondition the outputs of the model")
+
+        # Validation and logging
+        parser.add_argument("--validation-prompt-dir",
+                            type=str,
+                            help="Directory containing validation prompts")
+        parser.add_argument("--validation-sampling-steps",
+                            type=str,
+                            help="Validation sampling steps")
+        parser.add_argument("--validation-guidance-scale",
+                            type=str,
+                            help="Validation guidance scale")
+        parser.add_argument("--validation-steps",
+                            type=float,
+                            help="Run validation every N training steps")
+        parser.add_argument("--log-validation",
+                            action=StoreBoolean,
+                            help="Whether to log validation results")
+        parser.add_argument("--tracker-project-name",
+                            type=str,
+                            help="Project name for tracking")
+
+        # Output configuration
+        parser.add_argument("--output-dir",
+                            type=str,
+                            required=True,
+                            help="Output directory for checkpoints and logs")
+        parser.add_argument("--checkpoints-total-limit",
+                            type=int,
+                            help="Maximum number of checkpoints to keep")
+        parser.add_argument("--checkpointing-steps",
+                            type=int,
+                            help="Steps between checkpoints")
+        parser.add_argument("--resume-from-checkpoint",
+                            type=str,
+                            help="Path to checkpoint to resume from")
+        parser.add_argument("--logging-dir",
+                            type=str,
+                            help="Directory for logging")
+
+        # Training configuration
+        parser.add_argument("--num-train-epochs",
+                            type=int,
+                            help="Number of training epochs")
+        parser.add_argument("--max-train-steps",
+                            type=int,
+                            help="Maximum number of training steps")
+        parser.add_argument("--gradient-accumulation-steps",
+                            type=int,
+                            help="Number of steps to accumulate gradients")
+        parser.add_argument("--learning-rate",
+                            type=float,
+                            required=True,
+                            help="Learning rate")
+        parser.add_argument("--scale-lr",
+                            action=StoreBoolean,
+                            help="Whether to scale learning rate")
+        parser.add_argument("--lr-scheduler",
+                            type=str,
+                            default="constant",
+                            help="Learning rate scheduler type")
+        parser.add_argument("--lr-warmup-steps",
+                            type=int,
+                            default=10,
+                            help="Number of warmup steps for learning rate")
+        parser.add_argument("--max-grad-norm",
+                            type=float,
+                            help="Maximum gradient norm")
+        parser.add_argument("--gradient-checkpointing",
+                            action=StoreBoolean,
+                            help="Whether to use gradient checkpointing")
+        parser.add_argument("--selective-checkpointing",
+                            type=float,
+                            help="Selective checkpointing threshold")
+        parser.add_argument("--allow-tf32",
+                            action=StoreBoolean,
+                            help="Whether to allow TF32")
+        parser.add_argument("--mixed-precision",
+                            type=str,
+                            help="Mixed precision training type")
+        parser.add_argument("--train-sp-batch-size",
+                            type=int,
+                            help="Training sequence parallelism batch size")
+
+        parser.add_argument("--fsdp-sharding-strategy",
+                            type=str,
+                            help="FSDP sharding strategy")
+
+        parser.add_argument(
+            "--weighting-scheme",
+            type=str,
+            default="uniform",
+            choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
+            help="Timestep weighting scheme for training")
+        parser.add_argument(
+            "--logit-mean",
+            type=float,
+            default=0.0,
+            help="Mean to use with the `'logit_normal'` weighting scheme")
+        parser.add_argument(
+            "--logit-std",
+            type=float,
+            default=1.0,
+            help="Std to use with the `'logit_normal'` weighting scheme")
+        parser.add_argument(
+            "--mode-scale",
+            type=float,
+            default=1.29,
+            help="Scale of mode weighting scheme. Only effective when using "
+            "`'mode'` as the `weighting_scheme`")
+
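+        # Hedged sketch of how `logit_normal` typically consumes these values
+        # downstream (mirrors diffusers' compute_density_for_timestep_sampling;
+        # not necessarily this repo's exact implementation):
+        #
+        #     u = torch.normal(mean=logit_mean, std=logit_std, size=(bsz,))
+        #     t = torch.sigmoid(u)  # timesteps concentrate around the mean
+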
+        # Additional training parameters
+        parser.add_argument("--num-euler-timesteps",
+                            type=int,
+                            help="Number of Euler timesteps")
+        parser.add_argument("--lr-num-cycles",
+                            type=int,
+                            help="Number of learning rate cycles")
+        parser.add_argument("--lr-power",
+                            type=float,
+                            help="Learning rate power")
+        parser.add_argument("--not-apply-cfg-solver",
+                            action=StoreBoolean,
+                            help="Whether to skip applying the CFG solver")
+        parser.add_argument("--distill-cfg",
+                            type=float,
+                            help="Distillation CFG scale")
+        parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
+        parser.add_argument("--linear-quadratic-threshold",
+                            type=float,
+                            help="Linear quadratic threshold")
+        parser.add_argument("--linear-range", type=float, help="Linear range")
+        parser.add_argument("--weight-decay", type=float, help="Weight decay")
+        parser.add_argument("--use-ema",
+                            action=StoreBoolean,
+                            help="Whether to use EMA")
+        parser.add_argument("--multi-phased-distill-schedule",
+                            type=str,
+                            help="Multi-phased distillation schedule")
+        parser.add_argument("--pred-decay-weight",
+                            type=float,
+                            help="Prediction decay weight")
+        parser.add_argument("--pred-decay-type",
+                            type=str,
+                            help="Prediction decay type")
+        parser.add_argument("--hunyuan-teacher-disable-cfg",
+                            action=StoreBoolean,
+                            help="Whether to disable CFG for the Hunyuan teacher")
+        parser.add_argument("--master-weight-type",
+                            type=str,
+                            help="Master weight type")
+
+        return parser
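Putting it together, a hedged end-to-end sketch (flag values are hypothetical; the required flags match the definitions above):

    parser = TrainingArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([
        "--data-path", "data/train.parquet",
        "--dataloader-num-workers", "4",
        "--num-height", "480",
        "--num-width", "848",
        "--num-frames", "81",
        "--train-batch-size", "1",
        "--num-latent-t", "20",
        "--pretrained-model-name-or-path", "path/to/model",
        "--output-dir", "outputs",
        "--learning-rate", "1e-5",
    ])
    training_args = TrainingArgs.from_cli_args(args)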