Skip to content

Commit ce0ec40

Browse files
authored
[checkpointio] fix for async io (#6189)
1 parent 5ff5323 commit ce0ec40

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

colossalai/checkpoint_io/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,13 @@ def async_save_state_dict_shards(
315315
checkpoint_file_path = os.path.join(checkpoint, shard_file)
316316

317317
if state_preprocess:
318-
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
318+
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".")
319319
else:
320320
state_dict = shard
321+
metadata = None
321322

322323
# Only save on master rank.
323-
writer = save(checkpoint_file_path, state_dict=state_dict)
324+
writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata)
324325
writers.append(writer)
325326
shard_filenames.append(shard_file)
326327
del shard
@@ -377,9 +378,10 @@ def async_move_save_state_dict_shards(
377378
checkpoint_file_path = os.path.join(checkpoint, shard_file)
378379

379380
if state_preprocess:
380-
state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
381+
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard)
381382
else:
382383
state_dict = shard
384+
metadata = None
383385

384386
if pinned_state_dict is not None:
385387
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
@@ -388,7 +390,7 @@ def async_move_save_state_dict_shards(
388390
returned_state_dict.update(sub_pinned_state_dict)
389391

390392
# Only save on master rank.
391-
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
393+
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata)
392394
writers.append(writer)
393395
shard_filenames.append(shard_file)
394396
del shard

0 commit comments

Comments
 (0)