Skip to content

Commit 288ea88

Browse files
[Feat][Training] Rename weight conversion function and update gradient checkpoint in scripts (#589)
1 parent eb0f131 commit 288ea88

File tree

5 files changed

+15
-17
lines changed

5 files changed

+15
-17
lines changed

fastvideo/v1/fastvideo_args.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -413,8 +413,7 @@ class TrainingArgs(FastVideoArgs):
     lr_scheduler: str = "constant"
     lr_warmup_steps: int = 0
     max_grad_norm: float = 0.0
-    gradient_checkpointing: bool = False
-    gradient_checkpointing_type: str = "full"
+    enable_gradient_checkpointing_type: Optional[str] = None
     selective_checkpointing: float = 0.0
     allow_tf32: bool = False
     mixed_precision: str = ""
```
```diff
@@ -613,13 +612,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--max-grad-norm",
                         type=float,
                         help="Maximum gradient norm")
-    parser.add_argument("--gradient-checkpointing",
-                        action=StoreBoolean,
-                        help="Whether to use gradient checkpointing")
-    parser.add_argument("--gradient-checkpointing-type",
+    parser.add_argument("--enable-gradient-checkpointing-type",
                         type=str,
                         choices=["full", "ops", "block_skip"],
-                        default="full",
+                        default=None,
                         help="Gradient checkpointing type")
     parser.add_argument("--selective-checkpointing",
                         type=float,
```

fastvideo/v1/training/training_pipeline.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -84,11 +84,11 @@ def initialize_training_pipeline(self, training_args: TrainingArgs):

         self.transformer.requires_grad_(True)
         self.transformer.train()
-
-        if training_args.gradient_checkpointing:
+        if training_args.enable_gradient_checkpointing_type is not None:
             self.transformer = apply_activation_checkpointing(
                 self.transformer,
-                checkpointing_type=training_args.gradient_checkpointing_type)
+                checkpointing_type=training_args.
+                enable_gradient_checkpointing_type)

         noise_scheduler = self.modules["scheduler"]
         params_to_optimize = self.transformer.parameters()
```

fastvideo/v1/training/training_utils.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -162,8 +162,8 @@ def save_checkpoint(transformer,
                     weight_path,
                     local_main_process_only=False)

-    # Convert training format to diffusers format and save
-    diffusers_state_dict = convert_training_to_diffusers_format(
+    # Convert fastvideo custom format to diffusers format and save
+    diffusers_state_dict = convert_custom_format_to_diffusers_format(
         cpu_state, transformer)
     save_file(diffusers_state_dict, weight_path)

```
```diff
@@ -487,10 +487,10 @@ def _has_foreach_support(tensors: List[torch.Tensor],
         t is None or type(t) in [torch.Tensor] for t in tensors)


-def convert_training_to_diffusers_format(state_dict: Dict[str, Any],
-                                         transformer) -> Dict[str, Any]:
+def convert_custom_format_to_diffusers_format(state_dict: Dict[str, Any],
+                                              transformer) -> Dict[str, Any]:
     """
-    Convert training format state dict to diffusers format using reverse_param_names_mapping.
+    Convert fastvideo custom format state dict to diffusers format using reverse_param_names_mapping.

     Args:
         state_dict: State dict in training format
```

scripts/finetune/finetune_v1.sh

Lines changed: 2 additions & 1 deletion
```diff
@@ -48,4 +48,5 @@ torchrun --nnodes 1 --nproc_per_node $NUM_GPUS\
     --weight_decay 0.01 \
     --not_apply_cfg_solver \
     --dit_precision "fp32" \
-    --max_grad_norm 1.0
+    --max_grad_norm 1.0 \
+    --enable_gradient_checkpointing_type "full"
```

scripts/finetune/finetune_v1_VSA.sh

Lines changed: 2 additions & 1 deletion
```diff
@@ -58,5 +58,6 @@ torchrun --nnodes 1 --nproc_per_node $NUM_GPUS \
     --VSA_decay_sparsity 0.9 \
     --VSA_decay_rate 0.03 \
     --VSA_decay_interval_steps 30 \
-    --VSA_val_sparsity 0.9
+    --VSA_val_sparsity 0.9 \
+    --enable_gradient_checkpointing_type "full"
     # --resume_from_checkpoint "$CHECKPOINT_PATH"
```

0 commit comments

Comments (0)