
Commit 6ef8fcb

[Training] [4/n] add training save checkpoint (#441)
Parent: 016e24d

File tree

6 files changed (+62, -12 lines)

fastvideo/v1/models/dits/base.py

Lines changed: 3 additions & 1 deletion
@@ -33,9 +33,11 @@ def __init_subclass__(cls) -> None:
                     f"Subclasses of BaseDiT must define '{attr}' class variable"
                 )
 
-    def __init__(self, config: DiTConfig, **kwargs) -> None:
+    def __init__(self, config: DiTConfig, hf_config: dict[str, Any],
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.hf_config = hf_config
         if not self.supported_attention_backends:
             raise ValueError(
                 f"Subclass {self.__class__.__name__} must define _supported_attention_backends"

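With this change, hf_config becomes part of the constructor contract for every DiT subclass: each subclass must accept the raw diffusers config dict and forward it to BaseDiT.__init__ so that it can be re-serialized at checkpoint time. A minimal self-contained sketch of that contract (BaseDiTSketch and MyDiT are illustrative stand-ins, not names from the repo; the real subclasses are updated in the diffs below):

from typing import Any

class BaseDiTSketch:  # stands in for fastvideo.v1.models.dits.base.BaseDiT
    def __init__(self, config: Any, hf_config: dict[str, Any], **kwargs) -> None:
        self.config = config        # parsed FastVideo model config
        self.hf_config = hf_config  # untouched diffusers config.json contents

class MyDiT(BaseDiTSketch):  # placeholder subclass following the new contract
    def __init__(self, config: Any, hf_config: dict[str, Any]) -> None:
        super().__init__(config=config, hf_config=hf_config)
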
fastvideo/v1/models/dits/hunyuanvideo.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -442,8 +442,8 @@ class HunyuanVideoTransformer3DModel(CachableDiT):
     )._supported_attention_backends
     _param_names_mapping = HunyuanVideoConfig()._param_names_mapping
 
-    def __init__(self, config: HunyuanVideoConfig):
-        super().__init__(config=config)
+    def __init__(self, config: HunyuanVideoConfig, hf_config: dict[str, Any]):
+        super().__init__(config=config, hf_config=hf_config)
 
         self.patch_size = [
             config.patch_size_t, config.patch_size, config.patch_size

fastvideo/v1/models/dits/stepvideo.py

Lines changed: 4 additions & 3 deletions
@@ -10,7 +10,7 @@
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
 # ==============================================================================
-from typing import Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 from einops import rearrange, repeat
@@ -462,8 +462,9 @@ class StepVideoModel(BaseDiT):
     _supported_attention_backends = StepVideoConfig(
     )._supported_attention_backends
 
-    def __init__(self, config: StepVideoConfig) -> None:
-        super().__init__(config=config)
+    def __init__(self, config: StepVideoConfig, hf_config: dict[str,
+                                                                Any]) -> None:
+        super().__init__(config=config, hf_config=hf_config)
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_dim = config.attention_head_dim
         self.in_channels = config.in_channels

fastvideo/v1/models/dits/wanvideo.py

Lines changed: 5 additions & 4 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -298,7 +298,7 @@ def forward(
         hidden_states = hidden_states.squeeze(1)
         bs, seq_length, _ = hidden_states.shape
         orig_dtype = hidden_states.dtype
-        assert orig_dtype != torch.float32
+        # assert orig_dtype != torch.float32
         e = self.scale_shift_table + temb.float()
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
             6, dim=1)
@@ -360,8 +360,9 @@ class WanTransformer3DModel(CachableDiT):
     )._supported_attention_backends
     _param_names_mapping = WanVideoConfig()._param_names_mapping
 
-    def __init__(self, config: WanVideoConfig) -> None:
-        super().__init__(config=config)
+    def __init__(self, config: WanVideoConfig, hf_config: dict[str,
+                                                               Any]) -> None:
+        super().__init__(config=config, hf_config=hf_config)
 
         inner_dim = config.num_attention_heads * config.attention_head_dim
         self.hidden_size = config.hidden_size

fastvideo/v1/models/loader/component_loader.py

Lines changed: 6 additions & 1 deletion
@@ -6,6 +6,7 @@
 import os
 import time
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from typing import Any, Generator, Iterable, List, Optional, Tuple, cast
 
 import torch
@@ -366,6 +367,7 @@ def load(self, model_path: str, architecture: str,
              fastvideo_args: FastVideoArgs):
         """Load the transformer based on the model path, architecture, and inference args."""
         config = get_diffusers_config(model=model_path)
+        hf_config = deepcopy(config)
         cls_name = config.pop("_class_name")
         if cls_name is None:
             raise ValueError(
@@ -394,7 +396,10 @@ def load(self, model_path: str, architecture: str,
         # Load the model using FSDP loader
         logger.info("Loading model from %s", cls_name)
         model = load_fsdp_model(model_cls=model_cls,
-                                init_params={"config": dit_config},
+                                init_params={
+                                    "config": dit_config,
+                                    "hf_config": hf_config
+                                },
                                 weight_dir_list=safetensors_list,
                                 device=fastvideo_args.device,
                                 cpu_offload=fastvideo_args.use_cpu_offload,
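
One detail worth noting in the loader change: the deepcopy is taken before config.pop("_class_name"), so the dict handed to the model as hf_config still contains every key of the original diffusers config.json. A self-contained illustration with a plain dict (the literal config values are made up for the example):

from copy import deepcopy

config = {"_class_name": "SomeTransformer3DModel", "num_layers": 30}  # illustrative
hf_config = deepcopy(config)          # snapshot taken first, as in the diff above
cls_name = config.pop("_class_name")  # mutates config, not the snapshot

assert "_class_name" in hf_config     # the copy stored on the model stays complete
assert "_class_name" not in config
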
Lines changed: 41 additions & 0 deletions (new file)
@@ -0,0 +1,41 @@
+import json
+import os
+
+import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+
+from fastvideo.v1.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def save_checkpoint(transformer, rank, output_dir, step):
+    # Configure FSDP to save full state dict
+    FSDP.set_state_dict_type(
+        transformer,
+        state_dict_type=StateDictType.FULL_STATE_DICT,
+        state_dict_config=FullStateDictConfig(offload_to_cpu=True,
+                                              rank0_only=True),
+    )
+
+    # Now get the state dict
+    cpu_state = transformer.state_dict()
+
+    # Save it (only on rank 0 since we used rank0_only=True)
+    if rank <= 0:
+        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
+        os.makedirs(save_dir, exist_ok=True)
+        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.pt")
+        torch.save(cpu_state, weight_path)
+        config_dict = transformer.hf_config
+        if "dtype" in config_dict:
+            del config_dict["dtype"]  # TODO
+        config_path = os.path.join(save_dir, "config.json")
+        # save dict as json
+        with open(config_path, "w") as f:
+            json.dump(config_dict, f, indent=4)
+        logger.info("--> checkpoint saved at step {step} to {weight_path}",
+                    step=step,
+                    weight_path=weight_path)
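
A hedged usage sketch for the new helper (not part of the commit): it assumes transformer is the FSDP-wrapped DiT produced by the loader above, and num_steps, save_every, output_dir, and train_one_step are hypothetical training-loop names.

import torch.distributed as dist

num_steps, save_every, output_dir = 1000, 200, "outputs"  # illustrative values

def train_one_step() -> None:
    ...  # forward / backward / optimizer step elided

rank = dist.get_rank() if dist.is_initialized() else 0
for step in range(1, num_steps + 1):
    train_one_step()
    if step % save_every == 0:
        # Every rank must call save_checkpoint: FSDP gathers the full state
        # dict collectively; only rank 0 (rank <= 0) then writes the files.
        save_checkpoint(transformer, rank, output_dir, step)
        # Expected layout on disk:
        #   {output_dir}/checkpoint-{step}/diffusion_pytorch_model.pt  (full weights)
        #   {output_dir}/checkpoint-{step}/config.json                 (the model's hf_config)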
