@@ -315,12 +315,13 @@ def async_save_state_dict_shards(
315
315
checkpoint_file_path = os .path .join (checkpoint , shard_file )
316
316
317
317
if state_preprocess :
318
- state_dict , _ = _flatten_optim_state_dict (state_dict = shard , seperator = "." )
318
+ state_dict , metadata = _flatten_optim_state_dict (state_dict = shard , seperator = "." )
319
319
else :
320
320
state_dict = shard
321
+ metadata = None
321
322
322
323
# Only save on master rank.
323
- writer = save (checkpoint_file_path , state_dict = state_dict )
324
+ writer = save (checkpoint_file_path , state_dict = state_dict , metadata = metadata )
324
325
writers .append (writer )
325
326
shard_filenames .append (shard_file )
326
327
del shard
@@ -377,9 +378,10 @@ def async_move_save_state_dict_shards(
377
378
checkpoint_file_path = os .path .join (checkpoint , shard_file )
378
379
379
380
if state_preprocess :
380
- state_dict , _ = _flatten_optim_state_dict (state_dict = shard )
381
+ state_dict , metadata = _flatten_optim_state_dict (state_dict = shard )
381
382
else :
382
383
state_dict = shard
384
+ metadata = None
383
385
384
386
if pinned_state_dict is not None :
385
387
sub_pinned_state_dict = {k : pinned_state_dict [k ] for k in state_dict .keys ()}
@@ -388,7 +390,7 @@ def async_move_save_state_dict_shards(
388
390
returned_state_dict .update (sub_pinned_state_dict )
389
391
390
392
# Only save on master rank.
391
- writer = move_and_save (checkpoint_file_path , state_dict , sub_pinned_state_dict )
393
+ writer = move_and_save (checkpoint_file_path , state_dict , sub_pinned_state_dict , metadata )
392
394
writers .append (writer )
393
395
shard_filenames .append (shard_file )
394
396
del shard
0 commit comments