
Commit d4a4360

[checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5a03d26 commit d4a4360

File tree

7 files changed (+209, -28 lines)


colossalai/booster/booster.py

Lines changed: 2 additions & 0 deletions
@@ -310,6 +310,7 @@ def save_model(
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """Save model to checkpoint.
@@ -333,6 +334,7 @@ def save_model(
             prefix=prefix,
             size_per_shard=size_per_shard,
             use_safetensors=use_safetensors,
+            use_async=use_async,
         )

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
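For orientation, a minimal usage sketch of the new flag at the Booster level. This is not part of the commit: the model, optimizer, plugin choice and file name are illustrative, distributed initialization is assumed to have already run, and reaching the CheckpointIO as booster.checkpoint_io is an assumption.

import torch.nn as nn
from torch.optim import SGD
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Distributed init (e.g. colossalai.launch_from_torch) is assumed to have run already.
model = nn.Linear(8, 8)
optimizer = SGD(model.parameters(), lr=1e-3)
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# New in this commit: use_async=True hands the file write to background writers
# instead of blocking; async saves always go through safetensors.
booster.save_model(model, "model.safetensors", use_async=True)

# Assumption: the plugin's CheckpointIO is reachable as booster.checkpoint_io;
# its synchronize() must be called before the next weight update.
booster.checkpoint_io.synchronize()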

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 6 additions & 3 deletions
@@ -259,10 +259,12 @@ def load_sharded_model(
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

     def save_sharded_model(
         self,
@@ -272,11 +274,12 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
         return super().save_sharded_model(
-            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )

     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 14 additions & 3 deletions
@@ -33,13 +33,17 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
         super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(
+                model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
+            )

     def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
         """
@@ -71,14 +75,21 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
             super().save_sharded_model(
-                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+                model.unwrap(),
+                checkpoint_path,
+                gather_dtensor,
+                prefix,
+                max_shard_size,
+                use_safetensors,
+                use_async=use_async,
             )

     def load_sharded_model(
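Continuing the Booster sketch above, the sharded path through this plugin is exercised the same way; a hedged example (the directory name and shard size are illustrative), where only the master process actually writes files.

# Sharded variant: `checkpoint` is a directory, each shard gets its own
# background writer, and non-master ranks return without writing.
booster.save_model(
    model,
    "ckpt_dir",            # directory for shards + index file (illustrative path)
    shard=True,
    size_per_shard=1024,    # MB per shard
    use_async=True,
)
booster.checkpoint_io.synchronize()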

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 44 additions & 4 deletions
@@ -1,13 +1,14 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper
+from colossalai.logging import get_dist_logger

 from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

@@ -58,9 +59,34 @@ class CheckpointIO(ABC):
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """

+    N_WRITE_ENTRIES: int = 32
+
     # ======================================
     # Public methods
     # ======================================
+    def __init__(self):
+        super().__init__()
+        self.pinned_state_dicts: Dict[int, dict] = {}
+        self.async_writers = []
+
+    def _sync_io(self):
+        for writer in self.async_writers:
+            writer.synchronize()
+            writer.fp.close()
+        self.async_writers.clear()
+
+    def _sync_d2h(self):
+        for writer in self.async_writers:
+            writer.sync_before_step()
+
+    def synchronize(self):
+        """This method must be called before updating the model weights."""
+        self._sync_d2h()
+
+    def __del__(self):
+        self._sync_d2h()
+        self._sync_io()
+
     def load_model(
         self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
     ) -> Union[nn.Module, ModelWrapper]:
@@ -111,6 +137,7 @@ def save_model(
         prefix: str = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint.
@@ -138,11 +165,21 @@ def save_model(
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
            use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
        """
+        self._sync_io()
+        if use_async and not use_safetensors:
+            logger = get_dist_logger()
+            logger.warning(
+                "Async save is only supported when use_safetensors is set to True. "
+                "Setting use_safetensors to True for async save."
+            )
+            use_safetensors = True

         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+            self.save_sharded_model(
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
+            )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
@@ -234,6 +271,7 @@ def save_sharded_model(
         prefix: Optional[str],
         size_per_shard: int,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         """
         Save model to sharded checkpoint.
@@ -248,7 +286,9 @@ def save_sharded_model(
         """

     @abstractmethod
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to unsharded checkpoint.
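My reading of the lifecycle introduced here, as a hedged training-loop sketch: synchronize() (which calls _sync_d2h) only waits for the device-to-pinned-host copy, so the weights may be updated while the disk write continues; _sync_io() drains and closes the writers lazily at the start of the next save_model() and again in __del__(). Names such as dataloader and save_every are illustrative, and booster/model/optimizer continue the earlier sketch.

for step, batch in enumerate(dataloader):
    loss = model(batch).sum()
    loss.backward()

    if step % save_every == 0:
        # save_model() first drains any writers left over from the previous save via _sync_io()
        booster.save_model(model, f"step_{step}.safetensors", use_async=True)
        # wait only for the GPU -> pinned-CPU copy; the file write keeps running in the background
        booster.checkpoint_io.synchronize()

    optimizer.step()       # safe: the pinned buffers already hold a snapshot of the weights
    optimizer.zero_grad()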

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 44 additions & 13 deletions
@@ -8,9 +8,13 @@
 import torch.nn as nn
 from torch.optim import Optimizer

+from colossalai.utils.safetensors import move_and_save
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
@@ -40,15 +44,27 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         state_dict = model.state_dict()

         # TODO(FrankLeeeee): add support for gather_dtensor
         if gather_dtensor:
             pass

-        # save the checkpoint
-        save_state_dict(state_dict, checkpoint, use_safetensors)
+        if use_async:
+            from tensornvme.async_file_io import AsyncFileWriter
+
+            writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            self.async_writers.append(writer)
+            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+        else:
+
+            # save the checkpoint
+            save_state_dict(state_dict, checkpoint, use_safetensors)

     def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
         """
@@ -151,6 +167,7 @@ def save_sharded_model(
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         implement this method as it can be supported by Huggingface model,
@@ -168,16 +185,30 @@ def save_sharded_model(
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)

-        # Save shards of optimizer states.
-        # In general cases, is_master is set to True to get the right behavior.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-        )
+        if use_async:
+            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                pinned_state_dict=pinned_state_dict,
+                n_write_entries=self.N_WRITE_ENTRIES,
+            )
+            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
+        else:
+            # Save shards of optimizer states.
+            # In general cases, is_master is set to True to get the right behavior.
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                use_safetensors=use_safetensors,
+            )

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
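For a single-process view of the logic above, a minimal sketch (file names illustrative) driving GeneralCheckpointIO directly; it assumes tensornvme is installed and shows that repeated async saves of the same model reuse the pinned buffers cached under id(model).

import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(16, 16)
ckpt_io = GeneralCheckpointIO()

# The first async save allocates pinned CPU buffers for this model (keyed by id(model));
# use_safetensors is forced on for async saves.
ckpt_io.save_model(model, "model_a.safetensors", use_async=True)
ckpt_io.synchronize()

# The second save reuses the cached pinned buffers; the earlier writers are
# drained and closed by _sync_io() at the top of save_model().
ckpt_io.save_model(model, "model_b.safetensors", use_async=True)
ckpt_io.synchronize()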

colossalai/checkpoint_io/utils.py

Lines changed: 74 additions & 1 deletion
@@ -5,7 +5,7 @@
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn
@@ -19,6 +19,7 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
+from colossalai.utils.safetensors import move_and_save

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -263,6 +264,71 @@ def save_state_dict_shards(
     return total_size


+def async_save_state_dict_shards(
+    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+    checkpoint: str,
+    index_file: "CheckpointIndexFile",
+    base_filename: str,
+    is_master: bool,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    use_pp_format: bool = False,
+) -> Tuple[int, Dict[str, torch.Tensor], list]:
+    """
+    Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
+    Args:
+        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+        checkpoint (str): The path of checkpoint directory as string.
+        index_file (CheckpointIndexFile): The index file object to be updated.
+        base_filename (str): Decides the prefix of filenames of shards.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    Returns:
+        int: the total size of shards
+    """
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    total_size = 0
+    shard_filenames = []
+    if pinned_state_dict is None:
+        returned_state_dict = {}
+    else:
+        returned_state_dict = pinned_state_dict
+    writers = []
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        shard, current_size = shard_pair
+        # Just loop over the sharder and gather to other ranks if not master
+        if not is_master:
+            del shard
+            continue
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
+        writers.append(writer)
+
+        if pinned_state_dict is not None:
+            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
+        else:
+            sub_pinned_state_dict = create_pinned_state_dict(shard)
+            returned_state_dict.update(sub_pinned_state_dict)
+
+        # Only save on master rank.
+        move_and_save(writer, shard, sub_pinned_state_dict)
+        shard_filenames.append(shard_file)
+        del shard
+
+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+    return total_size, returned_state_dict, writers
+
+
 def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
     shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
+
+
+def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
+    pin_mem = dict()
+    for name, tensor in state_dict.items():
+        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return pin_mem
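The helpers above lean on page-locked (pinned) host buffers so the device-to-host copy can run asynchronously; a standalone PyTorch sketch of that mechanism (not ColossalAI code, assumes a CUDA device; names are illustrative).

import torch

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

def pinned_like(state_dict):
    # Same-shape/dtype buffers in page-locked CPU memory, mirroring create_pinned_state_dict.
    return {k: torch.empty_like(v, pin_memory=True, device="cpu") for k, v in state_dict.items()}

src = {"weight": torch.randn(4, 4, device="cuda")}
dst = pinned_like(src)
for name, tensor in src.items():
    # With a CUDA source and a pinned destination the copy is issued asynchronously,
    # which is what lets training overlap with the checkpoint write.
    dst[name].copy_(tensor, non_blocking=True)
torch.cuda.synchronize()  # the analogue of CheckpointIO.synchronize() before the next weight update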
