 import json
 import math
 import os
-from typing import List, Optional, Tuple, Union
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
-from torch.distributed.fsdp import FullStateDictConfig
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import StateDictType
+import torch.distributed.checkpoint as dcp
+import torch.distributed.checkpoint.stateful
+from safetensors.torch import save_file
 
 from fastvideo.v1.logger import init_logger
+from fastvideo.v1.training.checkpointing_utils import (ModelWrapper,
+                                                        OptimizerWrapper,
+                                                        RandomStateWrapper,
+                                                        SchedulerWrapper)
 
 logger = init_logger(__name__)
 
 _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False
 
 
+def gather_state_dict_on_cpu_rank0(
+    model,
+    device: Optional[torch.device] = None,
+) -> Dict[str, Any]:
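+    """
+    Gather a model state dict onto CPU on rank 0. Sharded DTensor parameters
+    are materialized with ``full_tensor()`` (optionally after moving the shard
+    to ``device`` so the gather runs there); only rank 0 keeps the gathered
+    copy, all other ranks return an empty dict.
+    """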
+    rank = dist.get_rank()
+    cpu_state_dict = {}
+    sharded_sd = model.state_dict()
+    for param_name, param in sharded_sd.items():
+        if hasattr(param, "_local_tensor"):
+            # DTensor case
+            if param.is_cpu:
+                # Gather directly on CPU
+                param = param.full_tensor()
+            else:
+                if device is not None:
+                    param = param.to(device)
+                param = param.full_tensor()
+        else:
+            # Regular tensor case
+            if param.is_cpu:
+                pass
+            else:
+                if device is not None:
+                    param = param.to(device)
+
+        if rank == 0:
+            cpu_state_dict[param_name] = param.cpu()
+
+    return cpu_state_dict
+
+
 def compute_density_for_timestep_sampling(
     weighting_scheme: str,
     batch_size: int,
@@ -66,24 +102,67 @@ def get_sigmas(noise_scheduler,
     return sigma
 
 
-def save_checkpoint(transformer, rank, output_dir, step) -> None:
-    # Configure FSDP to save full state dict
-    FSDP.set_state_dict_type(
-        transformer,
-        state_dict_type=StateDictType.FULL_STATE_DICT,
-        state_dict_config=FullStateDictConfig(offload_to_cpu=True,
-                                              rank0_only=True),
-    )
-
-    # Now get the state dict
-    cpu_state = transformer.state_dict()
-
-    # Save it (only on rank 0 since we used rank0_only=True)
-    if rank <= 0:
-        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
-        os.makedirs(save_dir, exist_ok=True)
-        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.pt")
-        torch.save(cpu_state, weight_path)
+def save_checkpoint(transformer,
+                    rank,
+                    output_dir,
+                    step,
+                    optimizer=None,
+                    dataloader=None,
+                    scheduler=None,
+                    noise_generator=None) -> None:
113+ """
114+ Save checkpoint following finetrainer's distributed checkpoint approach.
115+ Saves both distributed checkpoint and consolidated model weights.
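+
+    Illustrative usage (a sketch; assumes a training loop where
+    ``transformer``, ``optimizer``, ``train_dataloader``, ``lr_scheduler``,
+    ``noise_generator`` and an ``output_dir`` are already in scope)::
+
+        save_checkpoint(transformer, dist.get_rank(), output_dir, step,
+                        optimizer=optimizer,
+                        dataloader=train_dataloader,
+                        scheduler=lr_scheduler,
+                        noise_generator=noise_generator)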
116+ """
+    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
+    os.makedirs(save_dir, exist_ok=True)
+
+    states = {
+        "model": ModelWrapper(transformer),
+        "random_state": RandomStateWrapper(noise_generator),
+    }
+
+    if optimizer is not None:
+        states["optimizer"] = OptimizerWrapper(transformer, optimizer)
+
+    if dataloader is not None:
+        states["dataloader"] = dataloader
+
+    if scheduler is not None:
+        states["scheduler"] = SchedulerWrapper(scheduler)
+
+    dcp_dir = os.path.join(save_dir, "distributed_checkpoint")
+    logger.info("rank: %s, saving distributed checkpoint to %s",
+                rank,
+                dcp_dir,
+                local_main_process_only=False)
+
+    begin_time = time.perf_counter()
+    dcp.save(states, checkpoint_id=dcp_dir)
+    end_time = time.perf_counter()
+
+    logger.info("rank: %s, distributed checkpoint saved in %.2f seconds",
+                rank,
+                end_time - begin_time,
+                local_main_process_only=False)
+
+    cpu_state = gather_state_dict_on_cpu_rank0(transformer, device=None)
+
+    if rank == 0:
+        # Save model weights (consolidated)
+        weight_path = os.path.join(save_dir,
+                                   "diffusion_pytorch_model.safetensors")
+        logger.info("rank: %s, saving consolidated checkpoint to %s",
+                    rank,
+                    weight_path,
+                    local_main_process_only=False)
+        save_file(cpu_state, weight_path)
+        logger.info("rank: %s, consolidated checkpoint saved to %s",
+                    rank,
+                    weight_path,
+                    local_main_process_only=False)
+
+        # Save model config
         config_dict = transformer.hf_config
         if "dtype" in config_dict:
             del config_dict["dtype"]  # TODO
@@ -94,6 +173,66 @@ def save_checkpoint(transformer, rank, output_dir, step) -> None:
     logger.info("--> checkpoint saved at step %s to %s", step, weight_path)
 
 
+def load_checkpoint(transformer,
+                    rank,
+                    checkpoint_path,
+                    optimizer=None,
+                    dataloader=None,
+                    scheduler=None,
+                    noise_generator=None) -> int:
183+ """
184+ Load checkpoint following finetrainer's distributed checkpoint approach.
185+ Returns the step number from which training should resume.
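+
+    Illustrative usage (a sketch; ``resume_path`` stands for a
+    "checkpoint-<step>" directory written earlier by ``save_checkpoint``)::
+
+        init_steps = load_checkpoint(transformer, dist.get_rank(), resume_path,
+                                     optimizer=optimizer,
+                                     dataloader=train_dataloader,
+                                     scheduler=lr_scheduler,
+                                     noise_generator=noise_generator)
+        # the training loop then resumes from ``init_steps``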
186+ """
+    if not os.path.exists(checkpoint_path):
+        logger.warning("Checkpoint path %s does not exist", checkpoint_path)
+        return 0
+
+    # Extract step number from checkpoint path
+    step = int(os.path.basename(checkpoint_path).split('-')[-1])
+
+    if rank == 0:
+        logger.info("Loading checkpoint from step %s", step)
+
+    dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint")
+
+    if not os.path.exists(dcp_dir):
+        logger.warning("Distributed checkpoint directory %s does not exist",
+                       dcp_dir)
+        return 0
+
+    states = {
+        "model": ModelWrapper(transformer),
+        "random_state": RandomStateWrapper(noise_generator),
+    }
+
+    if optimizer is not None:
+        states["optimizer"] = OptimizerWrapper(transformer, optimizer)
+
+    if dataloader is not None:
+        states["dataloader"] = dataloader
+
+    if scheduler is not None:
+        states["scheduler"] = SchedulerWrapper(scheduler)
+
+    logger.info("rank: %s, loading distributed checkpoint from %s",
+                rank,
+                dcp_dir,
+                local_main_process_only=False)
+
+    begin_time = time.perf_counter()
+    dcp.load(states, checkpoint_id=dcp_dir)
+    end_time = time.perf_counter()
+
+    logger.info("rank: %s, distributed checkpoint loaded in %.2f seconds",
+                rank,
+                end_time - begin_time,
+                local_main_process_only=False)
+    logger.info("--> checkpoint loaded from step %s", step)
+
+    return step
+
+
 def normalize_dit_input(model_type, latents, args=None) -> torch.Tensor:
     if model_type == "hunyuan_hf" or model_type == "hunyuan":
         return latents * 0.476986