Skip to content

Commit f89d869

Browse files
[Feature][Training] Add diffusers-format checkpoint saving for inference (#542)
1 parent 8741d20 commit f89d869

File tree

10 files changed

+98
-7
lines changed

10 files changed

+98
-7
lines changed

fastvideo/v1/configs/models/dits/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class DiTArchConfig(ArchConfig):
1212
_fsdp_shard_conditions: list = field(default_factory=list)
1313
_compile_conditions: list = field(default_factory=list)
1414
_param_names_mapping: dict = field(default_factory=dict)
15+
_reverse_param_names_mapping: dict = field(default_factory=dict)
1516
_lora_param_names_mapping: dict = field(default_factory=dict)
1617
_supported_attention_backends: Tuple[AttentionBackendEnum, ...] = (
1718
AttentionBackendEnum.SLIDING_TILE_ATTN, AttentionBackendEnum.SAGE_ATTN,

fastvideo/v1/configs/models/dits/hunyuanvideo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ class HunyuanVideoArchConfig(DiTArchConfig):
147147
r"final_layer.linear.\1",
148148
})
149149

150+
# Reverse mapping for saving checkpoints: training -> diffusers
151+
_reverse_param_names_mapping: dict = field(default_factory=lambda: {})
152+
150153
patch_size: int = 2
151154
patch_size_t: int = 1
152155
in_channels: int = 16

fastvideo/v1/configs/models/dits/wanvideo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ class WanVideoArchConfig(DiTArchConfig):
4949
r"blocks.\1.ffn.fc_in.\2",
5050
r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$":
5151
r"blocks.\1.ffn.fc_out.\2",
52-
r"blocks\.(\d+)\.norm2\.(.*)$":
52+
r"^blocks\.(\d+)\.norm2\.(.*)$":
5353
r"blocks.\1.self_attn_residual_norm.norm.\2",
5454
})
55+
56+
# Reverse mapping for saving checkpoints: training -> diffusers
57+
_reverse_param_names_mapping: dict = field(default_factory=lambda: {})
58+
5559
# Some LoRA adapters use the original official layer names instead of hf layer names,
5660
# so apply this before the param_names_mapping
5761
_lora_param_names_mapping: dict = field(

fastvideo/v1/models/dits/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class BaseDiT(nn.Module, ABC):
1414
_fsdp_shard_conditions: list = []
1515
_compile_conditions: list = []
1616
_param_names_mapping: dict
17+
_reverse_param_names_mapping: dict
1718
hidden_size: int
1819
num_attention_heads: int
1920
num_channels_latents: int
@@ -78,6 +79,7 @@ class CachableDiT(BaseDiT):
7879
# These are required class attributes that should be overridden by concrete implementations
7980
_fsdp_shard_conditions = []
8081
_param_names_mapping = {}
82+
_reverse_param_names_mapping = {}
8183
_lora_param_names_mapping: dict = {}
8284
# Ensure these instance attributes are properly defined in subclasses
8385
hidden_size: int

fastvideo/v1/models/dits/hunyuanvideo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ class HunyuanVideoTransformer3DModel(CachableDiT):
442442
_supported_attention_backends = HunyuanVideoConfig(
443443
)._supported_attention_backends
444444
_param_names_mapping = HunyuanVideoConfig()._param_names_mapping
445+
_reverse_param_names_mapping = HunyuanVideoConfig(
446+
)._reverse_param_names_mapping
445447
_lora_param_names_mapping = HunyuanVideoConfig()._lora_param_names_mapping
446448

447449
def __init__(self, config: HunyuanVideoConfig, hf_config: dict[str, Any]):

fastvideo/v1/models/dits/stepvideo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ class StepVideoModel(BaseDiT):
460460
# lambda n, m: "pos_embed" in n # If needed for the patch embedding.
461461
]
462462
_param_names_mapping = StepVideoConfig()._param_names_mapping
463+
_reverse_param_names_mapping = StepVideoConfig(
464+
)._reverse_param_names_mapping
463465
_lora_param_names_mapping = StepVideoConfig()._lora_param_names_mapping
464466
_supported_attention_backends = StepVideoConfig(
465467
)._supported_attention_backends

fastvideo/v1/models/dits/wanvideo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class WanTransformer3DModel(CachableDiT):
518518
_supported_attention_backends = WanVideoConfig(
519519
)._supported_attention_backends
520520
_param_names_mapping = WanVideoConfig()._param_names_mapping
521+
_reverse_param_names_mapping = WanVideoConfig()._reverse_param_names_mapping
521522
_lora_param_names_mapping = WanVideoConfig()._lora_param_names_mapping
522523

523524
def __init__(self, config: WanVideoConfig, hf_config: dict[str,

fastvideo/v1/models/loader/fsdp_load.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,14 @@ def load_model_from_full_model_state_dict(
222222
used_keys = set()
223223
sharded_sd = {}
224224
to_merge_params: DefaultDict[str, Dict[Any, Any]] = defaultdict(dict)
225+
reverse_param_names_mapping = {}
226+
assert param_names_mapping is not None
225227
for source_param_name, full_tensor in full_sd_iterator:
226-
assert param_names_mapping is not None
227228
target_param_name, merge_index, num_params_to_merge = param_names_mapping(
228229
source_param_name)
230+
reverse_param_names_mapping[target_param_name] = (source_param_name,
231+
merge_index,
232+
num_params_to_merge)
229233
used_keys.add(target_param_name)
230234
if merge_index is not None:
231235
to_merge_params[target_param_name][merge_index] = full_tensor
@@ -260,6 +264,7 @@ def load_model_from_full_model_state_dict(
260264
sharded_tensor = sharded_tensor.cpu()
261265
sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)
262266

267+
model._reverse_param_names_mapping = reverse_param_names_mapping
263268
unused_keys = set(meta_sd.keys()) - used_keys
264269
if unused_keys:
265270
logger.warning("Found new parameters in meta state dict: %s",

fastvideo/v1/tests/training/Vanilla/test_training_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def test_distributed_training():
116116
'avg_step_time': 1.0,
117117
'grad_norm': 0.2,
118118
'step_time': 0.5,
119-
'train_loss': 0.001
119+
'train_loss': 0.0025
120120
}
121121

122122
failures = []

fastvideo/v1/training/training_utils.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch
99
import torch.distributed as dist
1010
import torch.distributed.checkpoint as dcp
11-
import torch.distributed.checkpoint.stateful
1211
from einops import rearrange
1312
from safetensors.torch import save_file
1413

@@ -154,13 +153,20 @@ def save_checkpoint(transformer,
154153

155154
if rank == 0:
156155
# Save model weights (consolidated)
157-
weight_path = os.path.join(save_dir,
156+
transformer_save_dir = os.path.join(save_dir, "transformer")
157+
os.makedirs(transformer_save_dir, exist_ok=True)
158+
weight_path = os.path.join(transformer_save_dir,
158159
"diffusion_pytorch_model.safetensors")
159160
logger.info("rank: %s, saving consolidated checkpoint to %s",
160161
rank,
161162
weight_path,
162163
local_main_process_only=False)
163-
save_file(cpu_state, weight_path)
164+
165+
# Convert training format to diffusers format and save
166+
diffusers_state_dict = convert_training_to_diffusers_format(
167+
cpu_state, transformer)
168+
save_file(diffusers_state_dict, weight_path)
169+
164170
logger.info("rank: %s, consolidated checkpoint saved to %s",
165171
rank,
166172
weight_path,
@@ -170,7 +176,7 @@ def save_checkpoint(transformer,
170176
config_dict = transformer.hf_config
171177
if "dtype" in config_dict:
172178
del config_dict["dtype"] # TODO
173-
config_path = os.path.join(save_dir, "config.json")
179+
config_path = os.path.join(transformer_save_dir, "config.json")
174180
# save dict as json
175181
with open(config_path, "w") as f:
176182
json.dump(config_dict, f, indent=4)
@@ -479,3 +485,68 @@ def _has_foreach_support(tensors: List[torch.Tensor],
479485
device: torch.device) -> bool:
480486
return _device_has_foreach_support(device) and all(
481487
t is None or type(t) in [torch.Tensor] for t in tensors)
488+
489+
490+
def convert_training_to_diffusers_format(state_dict: Dict[str, Any],
                                         transformer) -> Dict[str, Any]:
    """Convert a training-format state dict to diffusers format.

    Uses the ``_reverse_param_names_mapping`` recorded on the transformer
    when the checkpoint was loaded: ``training_key -> (diffusers_key,
    merge_index, num_params_to_merge)``.

    Args:
        state_dict: State dict in training format.
        transformer: Model object carrying ``_reverse_param_names_mapping``.

    Returns:
        State dict in diffusers format. Keys absent from the reverse mapping
        are passed through unchanged.

    Raises:
        ValueError: If the transformer's reverse mapping is empty, i.e. the
            model was not loaded through the mapping-aware loader.
    """
    reverse_param_names_mapping = transformer._reverse_param_names_mapping
    # Raise instead of assert: validation must still fire under `python -O`.
    if not reverse_param_names_mapping:
        raise ValueError("reverse_param_names_mapping is empty")

    new_state_dict: Dict[str, Any] = {}

    # First pass: collect training keys whose tensor was produced by merging
    # several diffusers parameters and therefore must be split again.
    # NOTE(review): the reverse mapping stores a single
    # (diffusers_key, merge_index) tuple per training key, so each group here
    # can hold at most one split target — the other slices of a merged tensor
    # are dropped. Confirm against the loader whether the mapping should
    # accumulate every contributing diffusers key instead.
    merge_groups: Dict[str, List[Tuple[str, int, int]]] = {}
    for training_key, (diffusers_key, merge_index,
                       num_params_to_merge) in reverse_param_names_mapping.items():
        if merge_index is not None:
            merge_groups.setdefault(training_key, []).append(
                (diffusers_key, merge_index, num_params_to_merge))

    # Second pass: split merged tensors back into their diffusers parts.
    used_keys = set()
    for training_key, splits in merge_groups.items():
        if training_key not in state_dict:
            continue
        merged = state_dict[training_key]
        # Sort by merge_index so slice i is assigned to the i-th original key.
        splits.sort(key=lambda item: item[1])
        num_parts = splits[0][2]
        # Assumes the merge concatenated equally-sized tensors along dim 0
        # — TODO confirm against the loader's merge logic.
        part_size = merged.shape[0] // num_parts
        parts = torch.split(merged, part_size, dim=0)
        for diffusers_key, split_index, _ in splits:
            new_state_dict[diffusers_key] = parts[split_index]
        used_keys.add(training_key)

    # Third pass: directly-mapped parameters; unmapped keys pass through.
    for training_key, tensor in state_dict.items():
        if training_key in used_keys:
            continue
        if training_key in reverse_param_names_mapping:
            diffusers_key, merge_index, _ = reverse_param_names_mapping[
                training_key]
            if merge_index is None:
                new_state_dict[diffusers_key] = tensor
        else:
            # No mapping found, keep as is.
            new_state_dict[training_key] = tensor

    return new_state_dict

0 commit comments

Comments
 (0)