Skip to content
19 changes: 17 additions & 2 deletions nemo_rl/models/megatron/community_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,24 @@ def export_model_from_megatron(
f"HF checkpoint already exists at {output_path}. Delete it to run or set overwrite=True."
)

try:
from megatron.bridge.training.model_load_save import (
temporary_distributed_context,
)
except ImportError:
raise ImportError("megatron.bridge.training is not available.")

bridge = AutoBridge.from_hf_pretrained(hf_model_name, trust_remote_code=True)
megatron_model = bridge.load_megatron_model(input_path)
bridge.save_hf_pretrained(megatron_model, output_path)

# Export performs on CPU with proper distributed context
with temporary_distributed_context(backend="gloo"):
# Load the Megatron model
megatron_model = bridge.load_megatron_model(
input_path, skip_temp_dist_context=True
)

# Save in HuggingFace format
bridge.save_hf_pretrained(megatron_model, output_path)

# resetting mcore state
import megatron.core.rerun_state_machine
Expand Down
Loading