
Commit bdfdf1d

[Training] Add distributed checkpointing (#458)
1 parent d156461 commit bdfdf1d

6 files changed (+322, −37 lines)

fastvideo/v1/fastvideo_args.py

Lines changed: 4 additions & 1 deletion
@@ -472,7 +472,7 @@ class TrainingArgs(FastVideoArgs):
     validation_steps: float = 0.0
     log_validation: bool = False
     tracker_project_name: str = ""
-    # seed: int
+    seed: Optional[int] = None

     # output
     output_dir: str = ""
@@ -630,6 +630,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--tracker-project-name",
                         type=str,
                         help="Project name for tracking")
+    parser.add_argument("--seed",
+                        type=int,
+                        help="Seed for deterministic training")

     # Output configuration
     parser.add_argument("--output-dir",

fastvideo/v1/training/checkpointing_utils.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import random
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.distributed.checkpoint.stateful
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+                                                      get_model_state_dict,
+                                                      get_optimizer_state_dict,
+                                                      set_model_state_dict,
+                                                      set_optimizer_state_dict)
+
+
+class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
+
+    def __init__(self, model: torch.nn.Module) -> None:
+        self.model = model
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_model_state_dict(self.model)  # type: ignore[no-any-return]
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_model_state_dict(
+            self.model,
+            model_state_dict=state_dict,
+            options=StateDictOptions(strict=False),
+        )
+
+
+class OptimizerWrapper(torch.distributed.checkpoint.stateful.Stateful):
+
+    def __init__(self, model: torch.nn.Module,
+                 optimizer: torch.optim.Optimizer) -> None:
+        self.model = model
+        self.optimizer = optimizer
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_optimizer_state_dict(  # type: ignore[no-any-return]
+            self.model,
+            self.optimizer,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_optimizer_state_dict(
+            self.model,
+            self.optimizer,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+
+
+class SchedulerWrapper(torch.distributed.checkpoint.stateful.Stateful):
+
+    def __init__(self, scheduler) -> None:
+        self.scheduler = scheduler
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {"scheduler": self.scheduler.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self.scheduler.load_state_dict(state_dict["scheduler"])
+
+
+class RandomStateWrapper(torch.distributed.checkpoint.stateful.Stateful):
+
+    def __init__(self,
+                 noise_generator: Optional[torch.Generator] = None) -> None:
+        self.noise_generator = noise_generator
+
+    def state_dict(self) -> Dict[str, Any]:
+        state = {
+            "torch_rng_state": torch.get_rng_state(),
+            "numpy_rng_state": np.random.get_state(),
+            "python_rng_state": random.getstate(),
+        }
+
+        if torch.cuda.is_available():
+            state["cuda_rng_state"] = torch.cuda.get_rng_state()
+            if torch.cuda.device_count() > 1:
+                state["cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
+
+        if self.noise_generator is not None:
+            state["noise_generator_state"] = self.noise_generator.get_state()
+
+        return state
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if "torch_rng_state" in state_dict:
+            torch.set_rng_state(state_dict["torch_rng_state"])
+
+        if "numpy_rng_state" in state_dict:
+            np.random.set_state(state_dict["numpy_rng_state"])
+
+        if "python_rng_state" in state_dict:
+            random.setstate(state_dict["python_rng_state"])
+
+        # Restore CUDA random state
+        if torch.cuda.is_available():
+            if "cuda_rng_state" in state_dict:
+                torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
+            if "cuda_rng_state_all" in state_dict:
+                torch.cuda.set_rng_state_all(state_dict["cuda_rng_state_all"])
+
+        # Restore noise generator state
+        if "noise_generator_state" in state_dict and self.noise_generator is not None:
+            self.noise_generator.set_state(state_dict["noise_generator_state"])
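
These four wrappers exist so that model weights, optimizer state, LR-scheduler state, and RNG state can all be handed to torch.distributed.checkpoint (DCP) as Stateful objects. A minimal sketch of how they compose (not part of the commit); the tiny nn.Linear model, AdamW optimizer, and /tmp checkpoint path are placeholders, and the script is assumed to be launched with torchrun so a process group can be created:

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

from fastvideo.v1.training.checkpointing_utils import (ModelWrapper,
                                                        OptimizerWrapper,
                                                        RandomStateWrapper)

dist.init_process_group("gloo")

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

states = {
    "model": ModelWrapper(model),
    "optimizer": OptimizerWrapper(model, optimizer),
    "random_state": RandomStateWrapper(),
}

# dcp.save/dcp.load call state_dict()/load_state_dict() on each Stateful
# entry; every rank participates and reads/writes only its own shards.
dcp.save(states, checkpoint_id="/tmp/demo_distributed_checkpoint")
dcp.load(states, checkpoint_id="/tmp/demo_distributed_checkpoint")

dist.destroy_process_group()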

fastvideo/v1/training/training_pipeline.py

Lines changed: 8 additions & 2 deletions
@@ -135,6 +135,13 @@ def log_validation(self, transformer, training_args, global_step) -> None:
         # Create sampling parameters if not provided
         sampling_param = SamplingParam.from_pretrained(training_args.model_path)

+        # Set deterministic seed for validation
+        validation_seed = training_args.seed if training_args.seed is not None else 42
+        torch.manual_seed(validation_seed)
+        torch.cuda.manual_seed_all(validation_seed)
+
+        logger.info("Using validation seed: %s", validation_seed)
+
         # Prepare validation prompts
         logger.info('fastvideo_args.validation_prompt_dir: %s',
                     training_args.validation_prompt_dir)
@@ -192,7 +199,7 @@ def log_validation(self, transformer, training_args, global_step) -> None:
         batch = ForwardBatch(
             data_type="video",
             latents=None,
-            # seed=sampling_param.seed,
+            seed=validation_seed,  # Use deterministic seed
             prompt_embeds=[prompt_embeds],
             prompt_attention_mask=[prompt_attention_mask],
             # make sure we use the same height, width, and num_frames as the training pipeline
@@ -206,7 +213,6 @@ def log_validation(self, transformer, training_args, global_step) -> None:
             n_tokens=n_tokens,
             do_classifier_free_guidance=False,
             eta=0.0,
-            extra={},
         )

         # Run validation inference
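
The point of re-seeding here (a sketch, not from the commit): with the same seed, the noise drawn for validation sampling is bit-identical from run to run, so videos logged at different training steps differ only because the weights changed. The sample_latents helper and the latent shape below are made up for illustration:

import torch


def sample_latents(seed: int, shape=(1, 4, 8, 32, 32)) -> torch.Tensor:
    # Mirrors the seeding done above before validation inference.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return torch.randn(shape)


assert torch.equal(sample_latents(42), sample_latents(42))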

fastvideo/v1/training/training_utils.py

Lines changed: 161 additions & 22 deletions
@@ -1,21 +1,57 @@
 import json
 import math
 import os
-from typing import List, Optional, Tuple, Union
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
-from torch.distributed.fsdp import FullStateDictConfig
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import StateDictType
+import torch.distributed.checkpoint as dcp
+import torch.distributed.checkpoint.stateful
+from safetensors.torch import save_file

 from fastvideo.v1.logger import init_logger
+from fastvideo.v1.training.checkpointing_utils import (ModelWrapper,
+                                                        OptimizerWrapper,
+                                                        RandomStateWrapper,
+                                                        SchedulerWrapper)

 logger = init_logger(__name__)

 _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False


+def gather_state_dict_on_cpu_rank0(
+    model,
+    device: Optional[torch.device] = None,
+) -> Dict[str, Any]:
+    rank = dist.get_rank()
+    cpu_state_dict = {}
+    sharded_sd = model.state_dict()
+    for param_name, param in sharded_sd.items():
+        if hasattr(param, "_local_tensor"):
+            # DTensor case
+            if param.is_cpu:
+                # Gather directly on CPU
+                param = param.full_tensor()
+            else:
+                if device is not None:
+                    param = param.to(device)
+                param = param.full_tensor()
+        else:
+            # Regular tensor case
+            if param.is_cpu:
+                pass
+            else:
+                if device is not None:
+                    param = param.to(device)
+
+        if rank == 0:
+            cpu_state_dict[param_name] = param.cpu()
+
+    return cpu_state_dict
+
+
 def compute_density_for_timestep_sampling(
     weighting_scheme: str,
     batch_size: int,
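
gather_state_dict_on_cpu_rank0 consolidates FSDP-sharded parameters by calling DTensor.full_tensor(), which all-gathers the shards, and keeps the result only on rank 0. A minimal sketch of that DTensor behaviour (not from the commit), meant to run under torchrun with a CPU gloo group; it assumes a PyTorch recent enough to expose torch.distributed.tensor publicly (older releases ship it as torch.distributed._tensor):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Row-shard a small tensor across ranks; each rank keeps only its local slice.
full = torch.arange(16, dtype=torch.float32).reshape(4, 4)
sharded = distribute_tensor(full, mesh, placements=[Shard(0)])

# DTensors expose the local shard via `_local_tensor` (the attribute the helper
# checks for) and reassemble the global tensor with `.full_tensor()`.
assert hasattr(sharded, "_local_tensor")
assert torch.equal(sharded.full_tensor(), full)

dist.destroy_process_group()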
@@ -66,24 +102,67 @@ def get_sigmas(noise_scheduler,
     return sigma


-def save_checkpoint(transformer, rank, output_dir, step) -> None:
-    # Configure FSDP to save full state dict
-    FSDP.set_state_dict_type(
-        transformer,
-        state_dict_type=StateDictType.FULL_STATE_DICT,
-        state_dict_config=FullStateDictConfig(offload_to_cpu=True,
-                                              rank0_only=True),
-    )
-
-    # Now get the state dict
-    cpu_state = transformer.state_dict()
-
-    # Save it (only on rank 0 since we used rank0_only=True)
-    if rank <= 0:
-        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
-        os.makedirs(save_dir, exist_ok=True)
-        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.pt")
-        torch.save(cpu_state, weight_path)
+def save_checkpoint(transformer,
+                    rank,
+                    output_dir,
+                    step,
+                    optimizer=None,
+                    dataloader=None,
+                    scheduler=None,
+                    noise_generator=None) -> None:
+    """
+    Save checkpoint following finetrainer's distributed checkpoint approach.
+    Saves both distributed checkpoint and consolidated model weights.
+    """
+    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
+    os.makedirs(save_dir, exist_ok=True)
+
+    states = {
+        "model": ModelWrapper(transformer),
+        "random_state": RandomStateWrapper(noise_generator),
+    }
+
+    if optimizer is not None:
+        states["optimizer"] = OptimizerWrapper(transformer, optimizer)
+
+    if dataloader is not None:
+        states["dataloader"] = dataloader
+
+    if scheduler is not None:
+        states["scheduler"] = SchedulerWrapper(scheduler)
+
+    dcp_dir = os.path.join(save_dir, "distributed_checkpoint")
+    logger.info("rank: %s, saving distributed checkpoint to %s",
+                rank,
+                dcp_dir,
+                local_main_process_only=False)
+
+    begin_time = time.perf_counter()
+    dcp.save(states, checkpoint_id=dcp_dir)
+    end_time = time.perf_counter()
+
+    logger.info("rank: %s, distributed checkpoint saved in %.2f seconds",
+                rank,
+                end_time - begin_time,
+                local_main_process_only=False)
+
+    cpu_state = gather_state_dict_on_cpu_rank0(transformer, device=None)
+
+    if rank == 0:
+        # Save model weights (consolidated)
+        weight_path = os.path.join(save_dir,
+                                   "diffusion_pytorch_model.safetensors")
+        logger.info("rank: %s, saving consolidated checkpoint to %s",
+                    rank,
+                    weight_path,
+                    local_main_process_only=False)
+        save_file(cpu_state, weight_path)
+        logger.info("rank: %s, consolidated checkpoint saved to %s",
+                    rank,
+                    weight_path,
+                    local_main_process_only=False)
+
+        # Save model config
         config_dict = transformer.hf_config
         if "dtype" in config_dict:
             del config_dict["dtype"]  # TODO
@@ -94,6 +173,66 @@ def save_checkpoint(transformer, rank, output_dir, step) -> None:
         logger.info("--> checkpoint saved at step %s to %s", step, weight_path)


+def load_checkpoint(transformer,
+                    rank,
+                    checkpoint_path,
+                    optimizer=None,
+                    dataloader=None,
+                    scheduler=None,
+                    noise_generator=None) -> int:
+    """
+    Load checkpoint following finetrainer's distributed checkpoint approach.
+    Returns the step number from which training should resume.
+    """
+    if not os.path.exists(checkpoint_path):
+        logger.warning("Checkpoint path %s does not exist", checkpoint_path)
+        return 0
+
+    # Extract step number from checkpoint path
+    step = int(os.path.basename(checkpoint_path).split('-')[-1])
+
+    if rank == 0:
+        logger.info("Loading checkpoint from step %s", step)
+
+    dcp_dir = os.path.join(checkpoint_path, "distributed_checkpoint")
+
+    if not os.path.exists(dcp_dir):
+        logger.warning("Distributed checkpoint directory %s does not exist",
+                       dcp_dir)
+        return 0
+
+    states = {
+        "model": ModelWrapper(transformer),
+        "random_state": RandomStateWrapper(noise_generator),
+    }
+
+    if optimizer is not None:
+        states["optimizer"] = OptimizerWrapper(transformer, optimizer)
+
+    if dataloader is not None:
+        states["dataloader"] = dataloader
+
+    if scheduler is not None:
+        states["scheduler"] = SchedulerWrapper(scheduler)
+
+    logger.info("rank: %s, loading distributed checkpoint from %s",
+                rank,
+                dcp_dir,
+                local_main_process_only=False)
+
+    begin_time = time.perf_counter()
+    dcp.load(states, checkpoint_id=dcp_dir)
+    end_time = time.perf_counter()
+
+    logger.info("rank: %s, distributed checkpoint loaded in %.2f seconds",
+                rank,
+                end_time - begin_time,
+                local_main_process_only=False)
+    logger.info("--> checkpoint loaded from step %s", step)
+
+    return step
+
+
 def normalize_dit_input(model_type, latents, args=None) -> torch.Tensor:
     if model_type == "hunyuan_hf" or model_type == "hunyuan":
         return latents * 0.476986
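
Taken together, each checkpoint-<step> directory now holds both the sharded DCP state (resumable on the same parallel layout) and a consolidated safetensors file for inference. A rough sketch of the intended save/resume flow (not from the commit), runnable under torchrun; the nn.Linear model standing in for the transformer, the faked hf_config attribute, and the /tmp output directory are placeholders:

import os

import torch
import torch.distributed as dist

from fastvideo.v1.training.training_utils import load_checkpoint, save_checkpoint

dist.init_process_group("gloo")
rank = dist.get_rank()

# Stand-in for the real transformer: save_checkpoint also serializes
# `transformer.hf_config`, so the dummy needs that attribute.
transformer = torch.nn.Linear(8, 8)
transformer.hf_config = {"hidden_size": 8}
optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-4)

output_dir = "/tmp/fastvideo_ckpt_demo"
save_checkpoint(transformer, rank, output_dir, step=500, optimizer=optimizer)
# -> /tmp/fastvideo_ckpt_demo/checkpoint-500/distributed_checkpoint/              (all ranks)
# -> /tmp/fastvideo_ckpt_demo/checkpoint-500/diffusion_pytorch_model.safetensors  (rank 0)

# Resuming: the step number is parsed from the directory name.
resume_step = load_checkpoint(transformer,
                              rank,
                              os.path.join(output_dir, "checkpoint-500"),
                              optimizer=optimizer)
assert resume_step == 500

dist.destroy_process_group()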
