Skip to content

Commit 73ae313

Browse files
fix: parallel state initialization error in Megatron to HF model conversion (NVIDIA-NeMo#1120)
Signed-off-by: Stan Kirdey <stan@inflection.ai>
1 parent 731f16f commit 73ae313

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

nemo_rl/models/megatron/community_import.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,24 @@ def export_model_from_megatron(
104104
f"HF checkpoint already exists at {output_path}. Delete it to run or set overwrite=True."
105105
)
106106

107+
try:
108+
from megatron.bridge.training.model_load_save import (
109+
temporary_distributed_context,
110+
)
111+
except ImportError:
112+
raise ImportError("megatron.bridge.training is not available.")
113+
107114
bridge = AutoBridge.from_hf_pretrained(hf_model_name, trust_remote_code=True)
108-
megatron_model = bridge.load_megatron_model(input_path)
109-
bridge.save_hf_pretrained(megatron_model, output_path)
115+
116+
# Export performs on CPU with proper distributed context
117+
with temporary_distributed_context(backend="gloo"):
118+
# Load the Megatron model
119+
megatron_model = bridge.load_megatron_model(
120+
input_path, skip_temp_dist_context=True
121+
)
122+
123+
# Save in HuggingFace format
124+
bridge.save_hf_pretrained(megatron_model, output_path)
110125

111126
# resetting mcore state
112127
import megatron.core.rerun_state_machine

0 commit comments

Comments
 (0)