Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/conversion/hf_megatron_roundtrip_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from rich.table import Table

from megatron.bridge import AutoBridge
from megatron.bridge.models.conversion import weights_verification_table
from megatron.bridge.models.decorators import torchrun_main
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo

Expand Down Expand Up @@ -136,11 +137,9 @@ def main(

# Export (Megatron -> HF weights iteration only)
export_start = perf_counter()
for _ in bridge.export_hf_weights(
megatron_model,
show_progress=show_progress and _is_rank_zero(),
):
pass
table = weights_verification_table(bridge, megatron_model)
if _is_rank_zero():
console.print(table)
_sync_cuda()
_maybe_barrier()
export_duration = perf_counter() - export_start
Expand Down
9 changes: 8 additions & 1 deletion src/megatron/bridge/models/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,19 @@ def weights_verification_table(bridge, megatron_model) -> Table:
# Check each weight against the original HF-model
for name, param in bridge.export_hf_weights(megatron_model, show_progress=True):
original_param = bridge.hf_pretrained.state[name]
param_for_comparison = param.to(dtype=original_param.dtype) if param.dtype != original_param.dtype else param
table.add_row(
name,
str(tuple(param.shape)),
str(param.dtype).replace("torch.", ""),
str(param.device),
"✅" if torch.allclose(param, original_param.to(param.device), atol=1e-6) else "❌",
(
f"{param_for_comparison.shape} != {original_param.shape}"
if param_for_comparison.shape != original_param.shape
else (
"✅" if torch.allclose(param_for_comparison, original_param.to(param.device), atol=1e-6) else "❌"
)
),
)

return table
Expand Down
4 changes: 2 additions & 2 deletions src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def maybe_modify_converted_hf_weight(
if inv_freq is None:
rotary_dim = self.hf_config.qk_rope_head_dim
rotary_base = self.hf_config.rope_theta
inv_freq = 1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
inv_freq = 1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 1, dtype=torch.float32) / rotary_dim))
self._deepseek_inv_freq = inv_freq

if converted_weights_dict:
reference_tensor = next(iter(converted_weights_dict.values()))
if inv_freq.device != reference_tensor.device:
inv_freq = inv_freq.to(device=reference_tensor.device)
inv_freq = inv_freq.to(device=reference_tensor.device, dtype=reference_tensor.dtype)
self._deepseek_inv_freq = inv_freq

converted_weights_dict[inv_freq_key] = inv_freq
Expand Down
200 changes: 200 additions & 0 deletions tests/functional_tests/models/deepseek/test_moonlight_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import subprocess
import sys
from pathlib import Path

import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers.dynamic_module_utils import get_class_from_dynamic_module


# HF repo ids for Moonlight. The Instruct variant is tried first; the base
# model is the fallback when the Instruct repo cannot be loaded.
MOONLIGHT_HF_MODEL_ID_PRIMARY = "moonshotai/Moonlight-16B-A3B-Instruct"
MOONLIGHT_HF_MODEL_ID_FALLBACK = "moonshotai/Moonlight-16B-A3B"

# Keep this toy config intentionally small to make instantiation cheap.
# Moonlight config keys may differ slightly across revisions; we apply overrides
# opportunistically and only assert on keys we set.
MOONLIGHT_OVERRIDES = {
    # Core size reductions
    "num_hidden_layers": 2,
    "hidden_size": 2048,
    "intermediate_size": 6144,
    "num_attention_heads": 32,
    "num_key_value_heads": 4,
    "vocab_size": 32000,
    "max_position_embeddings": 4096,
    # MoE-ish knobs (if present)
    "n_group": 2,
    "n_routed_experts": 4,
    "n_shared_experts": 1,
    "num_experts_per_tok": 2,
    "moe_intermediate_size": 768,
    # Common DeepSeek-like knobs (if present)
    "hidden_act": "silu",
    "initializer_range": 0.02,
    "first_k_dense_replace": 1,
    "topk_group": 2,
}


def _try_load_config():
    """Load the Moonlight HF config, preferring the Instruct repo.

    Falls back to the base model repo when the Instruct config cannot be
    fetched for any reason; a failure on the fallback propagates.

    Returns:
        Tuple of ``(model_id, config)`` for the first repo that loads.
    """
    try:
        model_id = MOONLIGHT_HF_MODEL_ID_PRIMARY
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    except Exception:
        model_id = MOONLIGHT_HF_MODEL_ID_FALLBACK
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    return model_id, config


class TestMoonlightConversion:
    """Functional tests for Moonlight toy conversion paths."""

    @pytest.fixture(scope="class")
    def moonlight_toy_model_path(self, tmp_path_factory):
        """Build and save a tiny, randomly-initialized Moonlight model.

        Only the HF config and the repo's remote modeling code are fetched —
        never the real 16B weights. The config is shrunk via
        MOONLIGHT_OVERRIDES, the model class is resolved from the repo's
        ``auto_map`` dynamic module, and the resulting toy model is saved to a
        class-scoped temporary directory.

        Returns:
            str: Filesystem path of the saved toy model directory.
        """
        temp_dir = tmp_path_factory.mktemp("moonlight_toy_model")
        model_dir = temp_dir / "moonlight_toy"

        hf_model_id, config = _try_load_config()

        # Shrink the config. Keys not present in a given revision are simply
        # set anyway; the test only asserts on keys that survive round-trip.
        for key, value in MOONLIGHT_OVERRIDES.items():
            setattr(config, key, value)

        # Some configs ship a quantization_config that isn't JSON-serializable.
        if hasattr(config, "quantization_config"):
            delattr(config, "quantization_config")

        # Moonlight distributes its modeling code via trust_remote_code;
        # resolve the concrete causal-LM class from the config's auto_map.
        auto_map = getattr(config, "auto_map", None) or {}
        model_class_ref = auto_map.get("AutoModelForCausalLM")
        if model_class_ref is None:
            raise RuntimeError(f"Expected config.auto_map['AutoModelForCausalLM'] for {hf_model_id}, got: {auto_map}")

        # NOTE(review): resume_download and use_auth_token are deprecated in
        # recent huggingface_hub/transformers releases — confirm these kwargs
        # against the pinned transformers version.
        model_class = get_class_from_dynamic_module(
            class_reference=model_class_ref,
            pretrained_model_name_or_path=hf_model_id,
            cache_dir=None,
            force_download=False,
            resume_download=True,
            proxies=None,
            use_auth_token=None,
            revision=None,
            local_files_only=False,
            repo_id=hf_model_id,
        )

        # Random init from the toy config; cast to bf16 when supported.
        model = model_class(config)
        model = model.bfloat16() if hasattr(model, "bfloat16") else model

        # Save a tokenizer (lightweight compatible tokenizer is fine for conversion flows).
        try:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            tokenizer.save_pretrained(model_dir)
        except Exception:
            pass

        # Save model and config
        model.save_pretrained(model_dir, safe_serialization=True)
        # Copy the remote modeling .py next to the weights so the saved dir is
        # loadable with trust_remote_code without re-fetching the repo.
        modeling_filepath = os.path.abspath(sys.modules[model_class.__module__].__file__)
        shutil.copy(modeling_filepath, model_dir)

        # Ensure config.json exists with expected keys
        config_path = model_dir / "config.json"
        with open(config_path, "w") as f:
            json.dump(model.config.to_dict(), f, indent=2)

        return str(model_dir)

    @pytest.mark.run_only_on("GPU")
    @pytest.mark.parametrize(
        "tp,pp,ep,test_name",
        [
            (2, 1, 1, "TP"),
            (1, 2, 1, "PP"),
            (1, 1, 2, "EP"),
        ],
    )
    def test_moonlight_conversion_parallelism(self, moonlight_toy_model_path, tmp_path, tp, pp, ep, test_name):
        """Round-trip the toy Moonlight model through the example conversion
        script under 2-rank tensor/pipeline/expert parallelism and verify the
        exported HF artifacts exist and preserve the overridden config values.
        """
        test_output_dir = tmp_path / f"moonlight_{test_name}"
        test_output_dir.mkdir(exist_ok=True)

        # Launch the conversion example under torchrun (2 ranks) with coverage.
        # NOTE(review): the coverage paths assume the /opt/Megatron-Bridge
        # container layout, and "python" (rather than sys.executable) assumes
        # the test interpreter is first on PATH — confirm for local runs.
        cmd = [
            "python",
            "-m",
            "torch.distributed.run",
            "--nproc_per_node=2",
            "--nnodes=1",
            "-m",
            "coverage",
            "run",
            "--data-file=/opt/Megatron-Bridge/.coverage",
            "--source=/opt/Megatron-Bridge/",
            "--parallel-mode",
            "examples/conversion/hf_megatron_roundtrip_multi_gpu.py",
            "--hf-model-id",
            moonlight_toy_model_path,
            "--output-dir",
            str(test_output_dir),
            "--tp",
            str(tp),
            "--pp",
            str(pp),
            "--ep",
            str(ep),
        ]

        # Run from the repo root (five levels up from this test file).
        result = subprocess.run(
            cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent
        )

        # Surface subprocess output before failing so CI logs are actionable.
        if result.returncode != 0:
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")
        assert result.returncode == 0, f"Moonlight {test_name} conversion failed with {result.returncode}"

        # Verify outputs
        model_name = Path(moonlight_toy_model_path).name
        converted_dir = test_output_dir / model_name
        assert converted_dir.exists()

        config_file = converted_dir / "config.json"
        assert config_file.exists()

        # Accept either single-file or sharded checkpoints, in safetensors or
        # legacy pytorch_model.bin form.
        weights_file_safetensors = converted_dir / "model.safetensors"
        weights_file_pytorch = converted_dir / "pytorch_model.bin"
        weights_found = weights_file_safetensors.exists() or weights_file_pytorch.exists()
        if not weights_found:
            shards_st = list(converted_dir.glob("model-*-of-*.safetensors"))
            shards_pt = list(converted_dir.glob("pytorch_model-*-of-*.bin"))
            weights_found = len(shards_st) > 0 or len(shards_pt) > 0
        assert weights_found

        with open(config_file) as f:
            saved = json.load(f)

        # Assert the values we explicitly set in the toy config are preserved.
        for k, v in MOONLIGHT_OVERRIDES.items():
            if k in saved:
                assert saved[k] == v, f"Expected {k}={v}, got {saved[k]}"

        assert "model_type" in saved

        print(f"SUCCESS: Moonlight {test_name} conversion test completed successfully")
        print(f"Converted model saved at: {converted_dir}")
2 changes: 1 addition & 1 deletion tests/unit_tests/models/deepseek/test_deepseek_bridges.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_export_injects_inv_freq_for_layer(self, mock_pretrained_v3):
expected = 1.0 / (
mock_pretrained_v3.config.rope_theta
** (
torch.arange(0, mock_pretrained_v3.config.qk_rope_head_dim, 2, dtype=torch.float32)
torch.arange(0, mock_pretrained_v3.config.qk_rope_head_dim, 1, dtype=torch.float32)
/ mock_pretrained_v3.config.qk_rope_head_dim
)
)
Expand Down
Loading