Skip to content

Commit 8e08c27

Browse files
wangbluover217
authored and committed
[ckpt] Add async ckpt api (#6136)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
1 parent d4a4360 commit 8e08c27

File tree

12 files changed

+172
-84
lines changed

12 files changed

+172
-84
lines changed

.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
cd TensorNVMe
118118
conda install cmake
119119
pip install -r requirements.txt
120-
DISABLE_URING=1 pip install -v .
120+
DISABLE_URING=1 pip install -v --no-cache-dir .
121121
122122
- name: Store TensorNVMe Cache
123123
run: |

colossalai/booster/booster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def save_model(
325325
names to compose the keys in state_dict. Defaults to None.
326326
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
327327
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
328+
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
328329
"""
329330
self.checkpoint_io.save_model(
330331
model,

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ def __init__(self) -> None:
6565
self.coordinator = DistCoordinator()
6666
self.logger = get_dist_logger()
6767

68-
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
68+
def save_unsharded_model(
69+
self,
70+
model: GeminiDDP,
71+
checkpoint: str,
72+
gather_dtensor: bool,
73+
use_safetensors: bool,
74+
use_async: bool = False,
75+
):
6976
"""
7077
Save unsharded model to checkpoint but only on master process.
7178
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@@ -74,7 +81,10 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
7481
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
7582
state_dict = model.state_dict(only_rank_0=True)
7683
if self.coordinator.is_master():
77-
save_state_dict(state_dict, checkpoint, use_safetensors)
84+
if use_async:
85+
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
86+
else:
87+
save_state_dict(state_dict, checkpoint, use_safetensors)
7888

7989
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
8090
"""
@@ -112,6 +122,7 @@ def save_sharded_model(
112122
prefix: Optional[str] = None,
113123
max_shard_size: int = 1024,
114124
use_safetensors: bool = False,
125+
use_async: bool = False,
115126
):
116127
"""
117128
Save sharded model.
@@ -130,27 +141,33 @@ def save_sharded_model(
130141

131142
# Save shards of model states.
132143
is_master = self.coordinator.is_master()
133-
total_size = save_state_dict_shards(
134-
sharded_state_dict=state_dict_shard,
135-
checkpoint=checkpoint_path,
136-
index_file=index_file,
137-
base_filename=weights_name,
138-
is_master=is_master,
139-
use_safetensors=use_safetensors,
140-
)
144+
if use_async:
145+
super().save_sharded_model(
146+
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
147+
)
141148

142-
# only save the index file on the master rank
143-
if self.coordinator.is_master():
144-
index_file.append_meta_data("total_size", total_size)
145-
index_file.write_index_file(save_index_file)
146-
save_config_file(model.unwrap(), checkpoint_path)
147-
self.logger.info(
148-
f"The model is split into checkpoint shards. "
149-
f"You can find where each parameters has been saved in the "
150-
f"index located at {save_index_file}.",
151-
ranks=[0],
149+
else:
150+
total_size = save_state_dict_shards(
151+
sharded_state_dict=state_dict_shard,
152+
checkpoint=checkpoint_path,
153+
index_file=index_file,
154+
base_filename=weights_name,
155+
is_master=is_master,
156+
use_safetensors=use_safetensors,
152157
)
153158

159+
# only save the index file on the master rank
160+
if self.coordinator.is_master():
161+
index_file.append_meta_data("total_size", total_size)
162+
index_file.write_index_file(save_index_file)
163+
save_config_file(model.unwrap(), checkpoint_path)
164+
self.logger.info(
165+
f"The model is split into checkpoint shards. "
166+
f"You can find where each parameters has been saved in the "
167+
f"index located at {save_index_file}.",
168+
ranks=[0],
169+
)
170+
154171
def load_sharded_model(
155172
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
156173
):

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path
5454
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
5555
optimizer.load_state_dict(sharded_osd)
5656

57-
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
57+
def save_unsharded_model(
58+
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
59+
):
5860
"""
5961
Save model to checkpoint but only on master process.
6062
"""
@@ -82,6 +84,7 @@ def save_sharded_model(
8284
prefix: Optional[str] = None,
8385
size_per_shard: int = 1024,
8486
use_safetensors: bool = False,
87+
use_async: bool = False,
8588
):
8689
"""
8790
Save model to checkpoint but only on master process.

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,10 @@ def save_model(
176176

177177
if shard:
178178
self.save_sharded_model(
179-
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
179+
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
180180
)
181181
else:
182-
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
182+
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
183183

184184
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
185185
"""

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def save_unsharded_model(
6161
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
6262
self.async_writers.append(writer)
6363
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
64-
else:
6564

65+
else:
6666
# save the checkpoint
6767
save_state_dict(state_dict, checkpoint, use_safetensors)
6868

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from .index_file import CheckpointIndexFile
2828
from .utils import (
2929
StateDictSharder,
30+
async_save_state_dict_shards,
31+
create_pinned_state_dict,
3032
gather_distributed_param,
3133
get_model_base_filenames,
3234
get_optimizer_base_filenames,
@@ -177,6 +179,7 @@ def save_sharded_model(
177179
prefix: Optional[str] = None,
178180
size_per_shard: int = 1024,
179181
use_safetensors: bool = False,
182+
use_async: bool = False,
180183
) -> None:
181184
"""
182185
Save sharded model checkpoint under the given checkpointing path.
@@ -194,6 +197,7 @@ def save_sharded_model(
194197
prefix (str, optional): Prefix of file to save. Defaults to None.
195198
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
196199
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
200+
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
197201
"""
198202

199203
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
@@ -219,24 +223,27 @@ def save_sharded_model(
219223

220224
if self.pp_size == 1:
221225
# When pipeline is not used, save the model shards as in general checkpointIO
222-
total_size = save_state_dict_shards(
223-
sharded_state_dict=state_dict_shard,
224-
checkpoint=checkpoint,
225-
index_file=index_file,
226-
base_filename=weights_name,
227-
is_master=control_saving,
228-
use_safetensors=use_safetensors,
229-
)
230-
if control_saving:
231-
index_file.append_meta_data("total_size", total_size)
232-
index_file.write_index_file(save_index_file)
233-
save_config_file(model, checkpoint)
234-
if self.verbose and self.coordinator.is_master():
235-
logging.info(
236-
f"The model is split into checkpoint shards. "
237-
f"You can find where each parameters has been saved in the "
238-
f"index located at {save_index_file}."
239-
)
226+
if use_async:
227+
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
228+
else:
229+
total_size = save_state_dict_shards(
230+
sharded_state_dict=state_dict_shard,
231+
checkpoint=checkpoint,
232+
index_file=index_file,
233+
base_filename=weights_name,
234+
is_master=control_saving,
235+
use_safetensors=use_safetensors,
236+
)
237+
if control_saving:
238+
index_file.append_meta_data("total_size", total_size)
239+
index_file.write_index_file(save_index_file)
240+
save_config_file(model, checkpoint)
241+
if self.verbose and self.coordinator.is_master():
242+
logging.info(
243+
f"The model is split into checkpoint shards. "
244+
f"You can find where each parameters has been saved in the "
245+
f"index located at {save_index_file}."
246+
)
240247

241248
else:
242249
# When pipeline is used, each stage produces its own shard files and index files.
@@ -251,7 +258,16 @@ def save_sharded_model(
251258
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
252259
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
253260
save_index_file = os.path.join("tmp_index_files", save_index_file)
254-
261+
if use_async:
262+
total_size, returned_state_dict, writers = async_save_state_dict_shards(
263+
sharded_state_dict=state_dict_shard,
264+
checkpoint=checkpoint,
265+
index_file=index_file,
266+
base_filename=weights_name,
267+
is_master=control_saving,
268+
use_pp_format=True,
269+
n_write_entries=191,
270+
)
255271
total_size = save_state_dict_shards(
256272
sharded_state_dict=state_dict_shard,
257273
checkpoint=checkpoint,
@@ -626,7 +642,9 @@ def _get_param_id_from_optimizer_param(
626642
if self.verbose and self.coordinator.is_master():
627643
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
628644

629-
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
645+
def save_unsharded_model(
646+
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
647+
):
630648
"""
631649
Save model state dict to a single file with given checkpointing path.
632650
@@ -635,6 +653,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
635653
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
636654
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
637655
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
656+
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
638657
"""
639658
if self.coordinator.is_master():
640659
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
@@ -651,7 +670,10 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
651670
if self.pp_size == 1:
652671
# When pipeline is not used, let master rank directly save the collected state_dict.
653672
if self.tp_rank == 0:
654-
save_state_dict(state_dict, checkpoint, use_safetensors)
673+
if use_async:
674+
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
675+
else:
676+
save_state_dict(state_dict, checkpoint, use_safetensors)
655677
else:
656678
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
657679
state_dict_list = [None for _ in range(self.pp_size)]
@@ -662,7 +684,18 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
662684
complete_state_dict = dict()
663685
for _state_dict in state_dict_list:
664686
complete_state_dict.update(_state_dict)
665-
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
687+
if use_async:
688+
from tensornvme.async_file_io import AsyncFileWriter
689+
690+
from colossalai.utils.safetensors import move_and_save
691+
692+
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
693+
if id(model) not in self.pinned_state_dicts:
694+
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
695+
self.async_writers.append(writer)
696+
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
697+
else:
698+
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
666699

667700
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
668701
"""

colossalai/checkpoint_io/moe_checkpoint.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def save_sharded_model(
117117
prefix: Optional[str] = None,
118118
size_per_shard: int = 1024,
119119
use_safetensors: bool = False,
120+
use_async: bool = False,
120121
) -> None:
121122
"""
122123
Save sharded model checkpoint under the given checkpointing path.
@@ -161,24 +162,27 @@ def save_sharded_model(
161162

162163
if self.pp_size == 1 and self.ep_size == 1:
163164
# When pipeline is not used, save the model shards as in general checkpointIO
164-
total_size = save_state_dict_shards(
165-
sharded_state_dict=state_dict_shard,
166-
checkpoint=checkpoint,
167-
index_file=index_file,
168-
base_filename=weights_name,
169-
is_master=control_saving,
170-
use_safetensors=use_safetensors,
171-
)
172-
if control_saving:
173-
index_file.append_meta_data("total_size", total_size)
174-
index_file.write_index_file(save_index_file)
175-
save_config_file(model, checkpoint)
176-
if self.verbose and self.coordinator.is_master():
177-
logging.info(
178-
f"The model is split into checkpoint shards. "
179-
f"You can find where each parameters has been saved in the "
180-
f"index located at {save_index_file}."
181-
)
165+
if use_async:
166+
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
167+
else:
168+
total_size = save_state_dict_shards(
169+
sharded_state_dict=state_dict_shard,
170+
checkpoint=checkpoint,
171+
index_file=index_file,
172+
base_filename=weights_name,
173+
is_master=control_saving,
174+
use_safetensors=use_safetensors,
175+
)
176+
if control_saving:
177+
index_file.append_meta_data("total_size", total_size)
178+
index_file.write_index_file(save_index_file)
179+
save_config_file(model, checkpoint)
180+
if self.verbose and self.coordinator.is_master():
181+
logging.info(
182+
f"The model is split into checkpoint shards. "
183+
f"You can find where each parameters has been saved in the "
184+
f"index located at {save_index_file}."
185+
)
182186

183187
dist.barrier()
184188
else:
@@ -708,10 +712,20 @@ def save_unsharded_model(
708712
checkpoint: str,
709713
gather_dtensor: bool,
710714
use_safetensors: bool,
715+
use_async: bool = False,
711716
):
712717
state_dict = self.pre_save_model(model)
713718
if dist.get_rank() == 0:
714-
torch.save(state_dict, checkpoint)
719+
if use_async:
720+
super().save_unsharded_model(
721+
model=model,
722+
checkpoint=checkpoint,
723+
gather_dtensor=gather_dtensor,
724+
use_safetensors=use_safetensors,
725+
use_async=use_async,
726+
)
727+
else:
728+
torch.save(state_dict, checkpoint)
715729
dist.barrier()
716730

717731
# Copied from colossalai.moe

colossalai/checkpoint_io/utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
371371
# ======================================
372372

373373

374-
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
374+
def save_state_dict(
375+
state_dict: dict,
376+
checkpoint_file_path: str,
377+
use_safetensors: bool,
378+
) -> None:
375379
"""
376380
Save state dict to checkpoint.
377381
@@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
581585
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
582586
if use_safetensors:
583587
from safetensors.torch import load_file as safe_load_file
584-
from safetensors.torch import safe_open
585588

586-
with safe_open(checkpoint_file, framework="pt") as f:
587-
metadata = f.metadata()
588-
if metadata["format"] != "pt":
589-
raise NotImplementedError(
590-
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
591-
)
592589
return safe_load_file(checkpoint_file)
593590
else:
594591
return torch.load(checkpoint_file, map_location=torch.device("cpu"))

0 commit comments

Comments
 (0)