
Commit eb69e64

flybird11111 authored and ver217 committed
[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
1 parent b90835b · commit eb69e64

15 files changed (+374 −46 lines)

colossalai/booster/booster.py

Lines changed: 4 additions & 1 deletion
@@ -359,6 +359,7 @@ def save_optimizer(
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ) -> None:
         """
         Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ def save_optimizer(
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+        self.checkpoint_io.save_optimizer(
+            optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
+        )
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 10 additions & 2 deletions
@@ -94,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ def load_sharded_model(
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 58 additions & 7 deletions
@@ -24,6 +24,7 @@
     get_shard_filename,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -113,7 +114,9 @@ def _hook_context(self):
 
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.
 
         Args:
@@ -125,9 +128,34 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
     def save_sharded_optimizer(
         self,
@@ -136,6 +164,7 @@ def save_sharded_optimizer(
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@ def save_sharded_optimizer(
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
 
         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)
 
@@ -184,7 +219,18 @@ def save_sharded_optimizer(
 
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = AsyncFileWriter(
+                        fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                    )
+                    save_nested(f_writer, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
 
             # Wrap up index file.
             index_file.append_meta_data("total_size", total_size)
@@ -223,7 +269,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
         for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                from colossalai.utils.safetensors import load_flat
+
+                state_dict = load_flat(shard_file)
+            else:
+                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
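
The hunks above carry the substance of the change: optimizer state is copied into pinned-memory buffers cached per optimizer (self.pinned_state_dicts keyed by id(optimizer)), flattened into safetensors form by save_nested, and handed to tensornvme's AsyncFileWriter so the disk write overlaps with training; loading dispatches on the ".safetensors" extension. A condensed sketch of that save/load pattern, using only the calls that appear in the diff — n_entries and the pending_writers list are illustrative stand-ins for self.N_WRITE_ENTRIES and self.async_writers, whose definitions live outside this hunk:

# Condensed sketch of the async optimizer save/load pattern used above.
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.checkpoint_io.utils import load_state_dict
from colossalai.utils.safetensors import load_flat, save_nested

pending_writers = []  # kept alive until the checkpoint IO synchronizes the writes


def async_save_optimizer_state(state_dict: dict, checkpoint: str, n_entries: int = 32) -> None:
    # Queue the write and return immediately; the pthread backend drains the
    # (pinned) buffers to disk while training continues.
    f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=n_entries, backend="pthread")
    save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
    pending_writers.append(f_writer)


def load_optimizer_state(checkpoint: str) -> dict:
    # The async path writes safetensors, so the extension selects the reader,
    # mirroring LowLevelZeroCheckpointIO.load_unsharded_optimizer above.
    if checkpoint.endswith(".safetensors"):
        return load_flat(checkpoint)
    return load_state_dict(checkpoint)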

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 7 additions & 2 deletions
@@ -52,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str)
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -113,13 +115,16 @@ def save_sharded_optimizer(
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
 
     def load_sharded_optimizer(
         self,

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 10 additions & 2 deletions
@@ -67,7 +67,9 @@ def save_unsharded_model(
         full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -157,7 +159,13 @@ def load_sharded_model(
         model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 15 additions & 5 deletions
@@ -213,6 +213,7 @@ def save_optimizer(
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ def save_optimizer(
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
-
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
 
     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
 
     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ def save_sharded_optimizer(
         """
 
     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
 

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,7 @@ def save_sharded_optimizer(
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ def save_unsharded_optimizer(
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 4 additions & 1 deletion
@@ -416,6 +416,7 @@ def save_sharded_optimizer(
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.

colossalai/checkpoint_io/moe_checkpoint.py

Lines changed: 8 additions & 1 deletion
@@ -369,6 +369,7 @@ def save_sharded_optimizer(
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ def save_unsharded_model(
         dist.barrier()
 
     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.

colossalai/checkpoint_io/utils.py

Lines changed: 5 additions & 3 deletions
@@ -24,9 +24,11 @@
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"
 
 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file
 
 
-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)
 
-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)
 
     param_group_file = GROUP_FILE_NAME
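
With the two new constants, get_optimizer_base_filenames switches the whole optimizer-state naming scheme when the async (safetensors) path is used, which is what lets the loaders above dispatch on the file extension. A small self-contained sketch of the resulting names (prefix handling via add_prefix is omitted; the constants match this diff):

# Sketch of the filename selection introduced above, with add_prefix omitted.
STATES_NAME = "pytorch_optim.bin"
SAFE_STATE_NAME = "optimizer.safetensors"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"


def optimizer_base_filenames(use_safetensors: bool = False) -> tuple:
    # Async (safetensors) saves and classic .bin saves get distinct state and index
    # names; the param-group file stays the same in both modes.
    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
    return states_name, save_index_file, GROUP_FILE_NAME


print(optimizer_base_filenames(use_safetensors=True))
# ('optimizer.safetensors', 'optimizer.safetensors.index.json', 'pytorch_optim_group.bin')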
