Commit 016e24d

[Training] [3/n] Add training args and dependencies (#440)
1 parent 85b8717 commit 016e24d

File tree

fastvideo/v1/fastvideo_args.py
pyproject.toml

2 files changed: +349 −2 lines changed


fastvideo/v1/fastvideo_args.py

Lines changed: 341 additions & 0 deletions
@@ -100,6 +100,10 @@ class FastVideoArgs:
     device_str: Optional[str] = None
     device = None

+    @property
+    def training_mode(self) -> bool:
+        return not self.inference_mode
+
     def __post_init__(self):
         pass

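The new training_mode property is simply the logical inverse of the existing inference_mode flag. A minimal self-contained sketch of the behavior, using a stand-in dataclass rather than the real FastVideoArgs (whose inference_mode default is assumed here to be True):

    import dataclasses

    @dataclasses.dataclass
    class _ArgsSketch:  # stand-in for FastVideoArgs, for illustration only
        inference_mode: bool = True  # assumed default

        @property
        def training_mode(self) -> bool:
            return not self.inference_mode

    assert _ArgsSketch().training_mode is False
    assert _ArgsSketch(inference_mode=False).training_mode is True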

@@ -132,6 +136,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help="The distributed executor backend to use",
         )

+        parser.add_argument(
+            "--inference-mode",
+            action=StoreBoolean,
+            default=FastVideoArgs.inference_mode,
+            help="Whether to use inference mode",
+        )
+
         # HuggingFace specific parameters
         parser.add_argument(
             "--trust-remote-code",
@@ -423,3 +434,333 @@ def get_current_fastvideo_args() -> FastVideoArgs:
         # TODO(will): may need to handle this for CI.
         raise ValueError("Current fastvideo args is not set.")
     return _current_fastvideo_args
+
+
+@dataclasses.dataclass
+class TrainingArgs(FastVideoArgs):
+    """
+    Training arguments. Inherits from FastVideoArgs and adds training-specific
+    arguments. If there are any conflicts, the training arguments will take
+    precedence.
+    """
+    data_path: str = ""
+    dataloader_num_workers: int = 0
+    num_height: int = 0
+    num_width: int = 0
+    num_frames: int = 0
+
+    train_batch_size: int = 0
+    num_latent_t: int = 0
+    group_frame: bool = False
+    group_resolution: bool = False
+
+    # text encoder & vae & diffusion model
+    pretrained_model_name_or_path: str = ""
+    dit_model_name_or_path: str = ""
+    cache_dir: str = ""
+
+    # diffusion setting
+    ema_decay: float = 0.0
+    ema_start_step: int = 0
+    cfg: float = 0.0
+    precondition_outputs: bool = False
+
+    # validation & logs
+    validation_prompt_dir: str = ""
+    validation_sampling_steps: str = ""
+    validation_guidance_scale: str = ""
+    validation_steps: float = 0.0
+    log_validation: bool = False
+    tracker_project_name: str = ""
+    # seed: int
+
+    # output
+    output_dir: str = ""
+    checkpoints_total_limit: int = 0
+    checkpointing_steps: int = 0
+    logging_dir: str = ""
+
+    # optimizer & scheduler
+    num_train_epochs: int = 0
+    max_train_steps: int = 0
+    gradient_accumulation_steps: int = 0
+    learning_rate: float = 0.0
+    scale_lr: bool = False
+    lr_scheduler: str = ""
+    lr_warmup_steps: int = 0
+    max_grad_norm: float = 0.0
+    gradient_checkpointing: bool = False
+    selective_checkpointing: float = 0.0
+    allow_tf32: bool = False
+    mixed_precision: str = ""
+    train_sp_batch_size: int = 0
+    fsdp_sharding_strategy: str = ""
+
+    weighting_scheme: str = ""
+    logit_mean: float = 0.0
+    logit_std: float = 1.0
+    mode_scale: float = 0.0
+
+    num_euler_timesteps: int = 0
+    lr_num_cycles: int = 0
+    lr_power: float = 0.0
+    not_apply_cfg_solver: bool = False
+    distill_cfg: float = 0.0
+    scheduler_type: str = ""
+    linear_quadratic_threshold: float = 0.0
+    linear_range: float = 0.0
+    weight_decay: float = 0.0
+    use_ema: bool = False
+    multi_phased_distill_schedule: str = ""
+    pred_decay_weight: float = 0.0
+    pred_decay_type: str = ""
+    hunyuan_teacher_disable_cfg: bool = False
+
+    # master_weight_type
+    master_weight_type: str = ""
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
+        # Get all fields from the dataclass
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+
+        # Create a dictionary of attribute values, with defaults for missing attributes
+        kwargs = {}
+        for attr in attrs:
+            # Handle renamed attributes or those with multiple CLI names
+            if attr == 'tp_size' and hasattr(args, 'tensor_parallel_size'):
+                kwargs[attr] = args.tensor_parallel_size
+            elif attr == 'sp_size' and hasattr(args, 'sequence_parallel_size'):
+                kwargs[attr] = args.sequence_parallel_size
+            elif attr == 'flow_shift' and hasattr(args, 'shift'):
+                kwargs[attr] = args.shift
+            # Use getattr with default value from the dataclass for potentially missing attributes
+            else:
+                default_value = getattr(cls, attr, None)
+                kwargs[attr] = getattr(args, attr, default_value)
+
+        return cls(**kwargs)
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        parser.add_argument("--data-path",
+                            type=str,
+                            required=True,
+                            help="Path to parquet files")
+        parser.add_argument("--dataloader-num-workers",
+                            type=int,
+                            required=True,
+                            help="Number of workers for dataloader")
+        parser.add_argument("--num-height",
+                            type=int,
+                            required=True,
+                            help="Number of heights")
+        parser.add_argument("--num-width",
+                            type=int,
+                            required=True,
+                            help="Number of widths")
+        parser.add_argument("--num-frames",
+                            type=int,
+                            required=True,
+                            help="Number of frames")
+
+        # Training batch and model configuration
+        parser.add_argument("--train-batch-size",
+                            type=int,
+                            required=True,
+                            help="Training batch size")
+        parser.add_argument("--num-latent-t",
+                            type=int,
+                            required=True,
+                            help="Number of latent time steps")
+        parser.add_argument("--group-frame",
+                            action=StoreBoolean,
+                            help="Whether to group frames during training")
+        parser.add_argument("--group-resolution",
+                            action=StoreBoolean,
+                            help="Whether to group resolutions during training")
+
+        # Model paths
+        parser.add_argument("--pretrained-model-name-or-path",
+                            type=str,
+                            required=True,
+                            help="Path to pretrained model or model name")
+        parser.add_argument("--dit-model-name-or-path",
+                            type=str,
+                            required=False,
+                            help="Path to DiT model or model name")
+        parser.add_argument("--cache-dir",
+                            type=str,
+                            help="Directory to cache models")
+
+        # Diffusion settings
+        parser.add_argument("--ema-decay",
+                            type=float,
+                            default=0.999,
+                            help="EMA decay rate")
+        parser.add_argument("--ema-start-step",
+                            type=int,
+                            default=0,
+                            help="Step to start EMA")
+        parser.add_argument("--cfg",
+                            type=float,
+                            help="Classifier-free guidance scale")
+        parser.add_argument(
+            "--precondition-outputs",
+            action=StoreBoolean,
+            help="Whether to precondition the outputs of the model")
+
+        # Validation and logging
+        parser.add_argument("--validation-prompt-dir",
+                            type=str,
+                            help="Directory containing validation prompts")
+        parser.add_argument("--validation-sampling-steps",
+                            type=str,
+                            help="Validation sampling steps")
+        parser.add_argument("--validation-guidance-scale",
+                            type=str,
+                            help="Validation guidance scale")
+        parser.add_argument("--validation-steps",
+                            type=float,
+                            help="Number of validation steps")
+        parser.add_argument("--log-validation",
+                            action=StoreBoolean,
+                            help="Whether to log validation results")
+        parser.add_argument("--tracker-project-name",
+                            type=str,
+                            help="Project name for tracking")
+
+        # Output configuration
+        parser.add_argument("--output-dir",
+                            type=str,
+                            required=True,
+                            help="Output directory for checkpoints and logs")
+        parser.add_argument("--checkpoints-total-limit",
+                            type=int,
+                            help="Maximum number of checkpoints to keep")
+        parser.add_argument("--checkpointing-steps",
+                            type=int,
+                            help="Steps between checkpoints")
+        parser.add_argument("--resume-from-checkpoint",
+                            type=str,
+                            help="Path to checkpoint to resume from")
+        parser.add_argument("--logging-dir",
+                            type=str,
+                            help="Directory for logging")
+
+        # Training configuration
+        parser.add_argument("--num-train-epochs",
+                            type=int,
+                            help="Number of training epochs")
+        parser.add_argument("--max-train-steps",
+                            type=int,
+                            help="Maximum number of training steps")
+        parser.add_argument("--gradient-accumulation-steps",
+                            type=int,
+                            help="Number of steps to accumulate gradients")
+        parser.add_argument("--learning-rate",
+                            type=float,
+                            required=True,
+                            help="Learning rate")
+        parser.add_argument("--scale-lr",
+                            action=StoreBoolean,
+                            help="Whether to scale learning rate")
+        parser.add_argument("--lr-scheduler",
+                            type=str,
+                            default="constant",
+                            help="Learning rate scheduler type")
+        parser.add_argument("--lr-warmup-steps",
+                            type=int,
+                            default=10,
+                            help="Number of warmup steps for learning rate")
+        parser.add_argument("--max-grad-norm",
+                            type=float,
+                            help="Maximum gradient norm")
+        parser.add_argument("--gradient-checkpointing",
+                            action=StoreBoolean,
+                            help="Whether to use gradient checkpointing")
+        parser.add_argument("--selective-checkpointing",
+                            type=float,
+                            help="Selective checkpointing threshold")
+        parser.add_argument("--allow-tf32",
+                            action=StoreBoolean,
+                            help="Whether to allow TF32")
+        parser.add_argument("--mixed-precision",
+                            type=str,
+                            help="Mixed precision training type")
+        parser.add_argument("--train-sp-batch-size",
+                            type=int,
+                            help="Training sequence parallelism batch size")
+
+        parser.add_argument("--fsdp-sharding-strategy",
+                            type=str,
+                            help="FSDP sharding strategy")
+
+        parser.add_argument(
+            "--weighting_scheme",
+            type=str,
+            default="uniform",
+            choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
+        )
+        parser.add_argument(
+            "--logit_mean",
+            type=float,
+            default=0.0,
+            help="mean to use when using the `'logit_normal'` weighting scheme.",
+        )
+        parser.add_argument(
+            "--logit_std",
+            type=float,
+            default=1.0,
+            help="std to use when using the `'logit_normal'` weighting scheme.",
+        )
+        parser.add_argument(
+            "--mode_scale",
+            type=float,
+            default=1.29,
+            help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+        )
+
+        # Additional training parameters
+        parser.add_argument("--num-euler-timesteps",
+                            type=int,
+                            help="Number of Euler timesteps")
+        parser.add_argument("--lr-num-cycles",
+                            type=int,
+                            help="Number of learning rate cycles")
+        parser.add_argument("--lr-power",
+                            type=float,
+                            help="Learning rate power")
+        parser.add_argument("--not-apply-cfg-solver",
+                            action=StoreBoolean,
+                            help="Whether to not apply CFG solver")
+        parser.add_argument("--distill-cfg",
+                            type=float,
+                            help="Distillation CFG scale")
+        parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
+        parser.add_argument("--linear-quadratic-threshold",
+                            type=float,
+                            help="Linear quadratic threshold")
+        parser.add_argument("--linear-range", type=float, help="Linear range")
+        parser.add_argument("--weight-decay", type=float, help="Weight decay")
+        parser.add_argument("--use-ema",
+                            action=StoreBoolean,
+                            help="Whether to use EMA")
+        parser.add_argument("--multi-phased-distill-schedule",
+                            type=str,
+                            help="Multi-phased distillation schedule")
+        parser.add_argument("--pred-decay-weight",
+                            type=float,
+                            help="Prediction decay weight")
+        parser.add_argument("--pred-decay-type",
+                            type=str,
+                            help="Prediction decay type")
+        parser.add_argument("--hunyuan-teacher-disable-cfg",
+                            action=StoreBoolean,
+                            help="Whether to disable CFG for Hunyuan teacher")
+        parser.add_argument("--master-weight-type",
+                            type=str,
+                            help="Master weight type")
+
+        return parser
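Putting the two halves together: add_cli_args registers the flags and from_cli_args maps the parsed namespace back onto the dataclass fields. A minimal sketch of that round trip, assuming the import path for FlexibleArgumentParser and using placeholder values for the required flags (base FastVideoArgs flags, if also registered by an entrypoint, would come on top of this):

    from fastvideo.v1.fastvideo_args import TrainingArgs
    from fastvideo.v1.utils import FlexibleArgumentParser  # import path is an assumption

    parser = TrainingArgs.add_cli_args(FlexibleArgumentParser())
    cli_args = parser.parse_args([
        "--data-path", "/data/videos.parquet",  # placeholder values throughout
        "--dataloader-num-workers", "4",
        "--num-height", "480",
        "--num-width", "848",
        "--num-frames", "77",
        "--train-batch-size", "1",
        "--num-latent-t", "20",
        "--pretrained-model-name-or-path", "some-org/some-model",
        "--output-dir", "./outputs",
        "--learning-rate", "1e-5",
    ])
    training_args = TrainingArgs.from_cli_args(cli_args)
    assert training_args.learning_rate == 1e-5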

pyproject.toml

Lines changed: 8 additions & 2 deletions
@@ -19,7 +19,7 @@ dependencies = [

    # Machine Learning & Transformers
    "transformers>=4.46.1", "tokenizers>=0.20.1", "sentencepiece==0.2.0",
-   "timm==1.0.11", "peft==0.13.2", "diffusers>=0.33.0", "bitsandbytes",
+   "timm==1.0.11", "peft==0.13.2", "diffusers>=0.33.1", "bitsandbytes",
    "torch==2.6.0", "torchvision",

    # Acceleration & Optimization
@@ -47,6 +47,12 @@ dependencies = [

    # flash-attn: pip install flash-attn==2.7.4.post1 --no-cache-dir --no-build-isolation

+train = [
+    "torchdata",
+    "pyarrow",
+    "datasets",
+]
+
 lint = [
     "pre-commit==4.0.1",
 ]
@@ -57,7 +63,7 @@ test = [
     "pytest",
 ]

-dev = [ "fastvideo[lint]", "fastvideo[test]", ]
+dev = [ "fastvideo[lint]", "fastvideo[test]", "fastvideo[train]", ]

 [project.scripts]
 fastvideo = "fastvideo.v1.entrypoints.cli.main:main"
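A note on usage: assuming the new train table sits under [project.optional-dependencies] alongside lint and test (the hunk context suggests it does), the training dependencies would be installed with something like pip install -e ".[train]", or pulled in transitively via ".[dev]" now that the dev extra references fastvideo[train].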
