
Commit 130229f

[checkpointio]support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent aaafb38 commit 130229f

17 files changed: +774 -186 lines
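
At a high level, the change routes a use_async flag through the 3D-parallel checkpoint paths: state dicts are staged into reusable pinned host buffers (create_pinned_state_dict and the pinned_state_dicts caches), written in the background by safetensors writers (colossalai.utils.safetensors.save, async_save_state_dict_shards), tracked in self.async_writers, and flat safetensors optimizer shards are read back with load_flat. A rough usage sketch follows; it assumes the Booster-level save calls in this version accept use_async, so treat the exact flags as assumptions rather than documented API.

    # Hedged usage sketch (run under torchrun): assumes booster.save_model /
    # booster.save_optimizer expose `use_async`; names outside this commit are
    # assumptions, not documented API.
    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()
    booster = Booster(plugin=GeminiPlugin())

    model = torch.nn.Linear(1024, 1024)
    optimizer = HybridAdam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # With use_async=True the call returns once the device-to-host copy into
    # pinned buffers is done; the file writes finish in background writers that
    # the plugin's CheckpointIO keeps in self.async_writers.
    booster.save_model(model, "ckpt/model", shard=True, use_async=True)
    booster.save_optimizer(optimizer, "ckpt/optim", shard=True, use_async=True)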

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 85 additions & 28 deletions
@@ -17,6 +17,8 @@
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
@@ -28,6 +30,7 @@
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats

@@ -82,7 +85,15 @@ def save_unsharded_model(
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                for k, v in state_dict.items():
+                    self.pinned_state_dicts[id(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                writer = save(checkpoint, state_dict)
+                self.async_writers.append(writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors)

@@ -106,7 +117,19 @@ def save_unsharded_optimizer(
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
+                flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, flatten_state_dict, metadata)
+                self.async_writers.append(writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)

     def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
@@ -137,17 +160,29 @@ def save_sharded_model(

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

-        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = model.state_dict_shard(
+            max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)

         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
         if use_async:
-            super().save_sharded_model(
-                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
             )
-
+            self.async_writers.extend(writers)
         else:
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
@@ -158,17 +193,17 @@ def save_sharded_model(
                 use_safetensors=use_safetensors,
             )

-            # only save the index file on the master rank
-            if self.coordinator.is_master():
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model.unwrap(), checkpoint_path)
-                self.logger.info(
-                    f"The model is split into checkpoint shards. "
-                    f"You can find where each parameters has been saved in the "
-                    f"index located at {save_index_file}.",
-                    ranks=[0],
-                )
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            save_config_file(model.unwrap(), checkpoint_path)
+            self.logger.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}.",
+                ranks=[0],
+            )

     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
@@ -201,7 +236,7 @@ def save_sharded_optimizer(
         Path(checkpoint).mkdir(parents=True, exist_ok=True)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)

@@ -212,17 +247,36 @@ def save_sharded_optimizer(
         torch.save(param_groups, group_file_path)

         # States are broken into shards within max_shard_size.
-        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = optimizer.state_shard(
+            prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )

         # Save shards of optimizer states.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=states_name,
-            is_master=self.coordinator.is_master(),
-            use_safetensors=False,
-        )
+        if use_async:
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                state_preprocess=True,
+            )
+            self.async_writers.extend(writers)
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )

         # Wrap up index file. Only save it on master rank.
         if self.coordinator.is_master():
@@ -264,7 +318,10 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
         for shard_file in checkpoint_files:
-            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                state_dict_shard = load_flat(shard_file)
+            else:
+                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             optimizer.load_param_states(state_dict_shard)
             del state_dict_shard
             gc.collect()
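
Both sharded paths above (model and optimizer) funnel into async_save_state_dict_shards, which writes each shard through a background writer while still recording the parameter-to-file mapping and total size in the index file, mirroring what save_state_dict_shards does synchronously. A loose, self-contained sketch of that bookkeeping (the shard and index file names below are hypothetical, not the CheckpointIndexFile format):

    # Illustrative shard-wise async save; not ColossalAI's async_save_state_dict_shards.
    import json
    import os
    from concurrent.futures import ThreadPoolExecutor

    from safetensors.torch import save_file

    _shard_pool = ThreadPoolExecutor(max_workers=2)  # hypothetical writer pool


    def async_save_shards(shards, checkpoint_dir: str):
        """shards: iterable of {param name: CPU tensor} dicts."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        futures, weight_map, total_size = [], {}, 0
        for i, shard in enumerate(shards):
            filename = f"model-{i + 1:05d}.safetensors"  # hypothetical naming
            for name, tensor in shard.items():
                weight_map[name] = filename
                total_size += tensor.numel() * tensor.element_size()
            futures.append(_shard_pool.submit(save_file, shard, os.path.join(checkpoint_dir, filename)))
        # record where every parameter landed, plus the total size, in an index
        index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
        with open(os.path.join(checkpoint_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)
        return total_size, futures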

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -1488,7 +1488,7 @@ def seed_worker(worker_id):
         )

     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)

     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 7 additions & 1 deletion
@@ -404,7 +404,13 @@ def __init__(

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            self.dp_group,
+            self.pp_group,
+            self.tp_group,
+            self.sp_group,
+            self.ep_group,
+            self.moe_dp_group,
+            self.zero_stage,
         )

     def configure(

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def save_unsharded_optimizer(
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
