Skip to content

Commit 8c6b915

Browse files
Fix megatron distributed checkpoint metadata pass through (#431)
Signed-off-by: Chenhan Yu <[email protected]> Co-authored-by: Keval Morabia <[email protected]>
1 parent c692074 commit 8c6b915

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def _load_extra_state_from_sharded_checkpoint(
     model: torch.nn.Module,
     checkpoint_name: str | Path,
     prefix: str,
+    metadata: dict[str, Any] | None = None,
 ) -> None:
     """Load extra state from sharded checkpoint.
@@ -187,6 +188,12 @@ def _load_extra_state_from_sharded_checkpoint(
         model: the model to load extra state into
         checkpoint_name: the checkpoint folder path
         prefix: the prefix to add to the modelopt_state keys
+        metadata: the metadata for distributed checkpointing
+
+    Note:
+        The metadata includes several breaking changes. For example, `singleton_local_shards`
+        is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the
+        sharded state_dict format and must be consistent between saving and loading.
     """
     sharded_state_dict = model.sharded_state_dict(prefix=prefix)
     extra_sharded_state_dict = {k: v for k, v in sharded_state_dict.items() if "_extra_state" in k}
@@ -208,13 +215,20 @@ def restore_sharded_modelopt_state(
     model: list[torch.nn.Module],
     checkpoint_name: str | Path,
     prefix: str = "",
+    metadata: dict[str, Any] | None = None,
 ) -> None:
     """Restore modelopt_state from the sharded state_dict format.

     Args:
         model: the model to restore the modelopt optimization
         checkpoint_name: the checkpoint folder path
         prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
+        metadata: the metadata for distributed checkpointing
+
+    Note:
+        The metadata includes several breaking changes. For example, `singleton_local_shards`
+        is set to `True` (was not set before) in megatron-core-0.15.0. This flag affects the
+        sharded state_dict format and must be consistent between saving and loading.
     """
     if len(model) > 1:
         raise ValueError("sharded_modelopt_state does not support virtual pipeline parallel!")
@@ -247,4 +261,4 @@ def restore_sharded_modelopt_state(
     #
     model[0] = mto.restore_from_modelopt_state(model[0], common_modelopt_state)

-    _load_extra_state_from_sharded_checkpoint(model[0], checkpoint_name, prefix)
+    _load_extra_state_from_sharded_checkpoint(model[0], checkpoint_name, prefix, metadata=metadata)

0 commit comments

Comments
 (0)