@@ -177,6 +177,7 @@ def _load_extra_state_from_sharded_checkpoint(
177177 model : torch .nn .Module ,
178178 checkpoint_name : str | Path ,
179179 prefix : str ,
180+ metadata : dict [str , Any ] | None = None ,
180181) -> None :
181182 """Load extra state from sharded checkpoint.
182183
@@ -187,6 +188,12 @@ def _load_extra_state_from_sharded_checkpoint(
187188 model: the model to load extra state into
188189 checkpoint_name: the checkpoint folder path
189190 prefix: the prefix to add to the modelopt_state keys
191+ metadata: the metadata for distributed checkpointing
192+
193+ Note:
194+ The metadata includes several breaking changes. For example, `singleton_local_shards`
195+ is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the
196+ sharded state_dict format and must be consistent between saving and loading.
190197 """
191198 sharded_state_dict = model .sharded_state_dict (prefix = prefix )
192199 extra_sharded_state_dict = {k : v for k , v in sharded_state_dict .items () if "_extra_state" in k }
@@ -208,13 +215,20 @@ def restore_sharded_modelopt_state(
208215 model : list [torch .nn .Module ],
209216 checkpoint_name : str | Path ,
210217 prefix : str = "" ,
218+ metadata : dict [str , Any ] | None = None ,
211219) -> None :
212220 """Restore modelopt_state from the sharded state_dict format.
213221
214222 Args:
215223 model: the model to restore the modelopt optimization
216224 checkpoint_name: the checkpoint folder path
217225 prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
226+ metadata: the metadata for distributed checkpointing
227+
228+ Note:
229+ The metadata includes several breaking changes. For example, `singleton_local_shards`
230+ is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the
231+ sharded state_dict format and must be consistent between saving and loading.
218232 """
219233 if len (model ) > 1 :
220234 raise ValueError ("sharded_modelopt_state does not support virtual pipeline parallel!" )
@@ -247,4 +261,4 @@ def restore_sharded_modelopt_state(
247261 #
248262 model [0 ] = mto .restore_from_modelopt_state (model [0 ], common_modelopt_state )
249263
250- _load_extra_state_from_sharded_checkpoint (model [0 ], checkpoint_name , prefix )
264+ _load_extra_state_from_sharded_checkpoint (model [0 ], checkpoint_name , prefix , metadata = metadata )
0 commit comments