nemo_automodel/components/checkpoint/checkpointing.py (13 changes: 9 additions & 4 deletions)

@@ -201,7 +201,9 @@ def save_model(
         state_dict = model_state.state_dict()
 
         # Convert to HF format if using custom model implementations
-        state_dict = _maybe_adapt_state_dict_to_hf(model_state.model[0], state_dict, quantization=False)
+        state_dict = _maybe_adapt_state_dict_to_hf(
+            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+        )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
@@ -305,7 +307,10 @@ def load_model(
         storage_reader = self._get_storage_reader(model_path, key_mapping, is_init_step=is_init_step)
 
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=self.config.dequantize_base_checkpoint
+            model_state.model[0],
+            state_dict,
+            quantization=self.config.dequantize_base_checkpoint,
+            device_mesh=self.moe_mesh,
         )
 
         state_dict = self._do_load(state_dict, model_path, storage_reader, is_init_step=is_init_step)
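Both call sites now thread `self.moe_mesh` through to the adapter. For orientation, a rough sketch of the kind of mesh that flows through here; the 2x4 shape and the dim names are illustrative assumptions, not taken from this diff, and the snippet needs a process group (e.g. run under torchrun):

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 8-GPU layout: 2-way data parallel x 4-way expert parallel.
# The adapter below only relies on the mesh exposing an "ep" dimension.
moe_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))
ep_dim = moe_mesh.mesh_dim_names.index("ep")
ep_group = moe_mesh.get_group(ep_dim)  # process group later used for all_gather_object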
@@ -848,14 +853,14 @@ def compute_should_use_set_data(tensor, tensor_applied):


 def _maybe_adapt_state_dict_to_hf(
-    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False
+    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False, **kwargs
 ) -> dict[str, torch.Tensor]:
     """
     Custom models use state dict adapters to convert the state dict to the Hugging Face format.
     """
     adapter = getattr(model_part, "state_dict_adapter", None)
     if adapter:
-        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization)
+        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization, **kwargs)
     return state_dict
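The `**kwargs` indirection keeps this helper backward compatible: any adapter whose `to_hf` accepts `**kwargs` absorbs the new `device_mesh` argument even if it ignores it. A minimal sketch of that contract, using a hypothetical adapter class:

from typing import Any

import torch


class NoopAdapter:
    # Hypothetical adapter: accepts device_mesh via **kwargs and ignores it.
    def to_hf(self, state_dict: dict[str, torch.Tensor], **kwargs) -> dict[str, Any]:
        return dict(state_dict)


sd = {"w": torch.zeros(2)}
out = NoopAdapter().to_hf(sd, exclude_key_regex=r".*_extra_state.*", quantization=False, device_mesh=None)
assert out["w"] is sd["w"]  # the new kwarg passed through without breaking the call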



@@ -16,6 +16,7 @@
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
@@ -49,27 +50,74 @@ def to_hf(
         quantization: bool = False,
         **kwargs,
     ) -> dict[str, Any]:
         self._uses_model_prefix = any(key.startswith("model.") for key in state_dict.keys())
         prefix = "model." if self._uses_model_prefix else ""
         hf_state_dict: dict[str, Any] = {}
+        device_mesh: Optional["DeviceMesh"] = kwargs.get("device_mesh")
 
         for fqn, tensor in state_dict.items():
-            if ".mlp.experts.gate_and_up_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.gate_up_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            if ".mlp.experts.down_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.down_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            hf_state_dict[fqn] = tensor
+            if ".mlp.experts.gate_and_up_projs" in fqn or ".mlp.experts.down_projs" in fqn:
+                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
+                which = "gate_up_proj" if "gate_and_up_projs" in fqn else "down_proj"
+                if device_mesh is not None:
+                    n_experts = self.moe_config.n_routed_experts
+                    # Aggregate this layer's expert tensor only for the current key, then free temps.
+                    global_tensor = torch.zeros(
+                        (n_experts, tensor.shape[1], tensor.shape[2]), dtype=self.dtype, device="cpu"
+                    )
+
+                    if state_dict_utils.is_dtensor(tensor):
+                        split_weights, expert_ids = state_dict_utils.split_experts_weights_dtensor_aware(
+                            tensor, n_experts
+                        )
+                    else:
+                        start_expert, end_expert = state_dict_utils.get_expert_range_for_rank_from_mesh(
+                            device_mesh, n_experts
+                        )
+                        split_weights = [tensor[i].to(self.dtype).cpu() for i in range(tensor.shape[0])]
+                        expert_ids = list(range(start_expert, end_expert))
+
+                    # If distributed is initialized and we have an ep dimension, gather all slices.
+                    if dist.is_initialized() and "ep" in device_mesh.mesh_dim_names:
+                        try:
+                            ep_dim = device_mesh.mesh_dim_names.index("ep")
+                            ep_group = device_mesh.get_group(ep_dim)
+                        except Exception:
+                            ep_group = None
+
+                        if ep_group is not None:
+                            payload = (expert_ids, [w.cpu() for w in split_weights])
+                            gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
+                                ep_group
+                            )
+                            dist.all_gather_object(gathered, payload, group=ep_group)
+                            for ids, weights in gathered:
+                                for eid, w in zip(ids, weights):
+                                    global_tensor[eid].copy_(w.to(self.dtype).cpu())
+                        else:
+                            for weight, expert_id in zip(split_weights, expert_ids):
+                                global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    else:
+                        for weight, expert_id in zip(split_weights, expert_ids):
+                            global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    del split_weights
+                    del expert_ids
+
+                    key = f"{prefix}language_model.layers.{layer_num}.mlp.experts.{which}"
+                    hf_state_dict[key] = global_tensor
+                    del global_tensor
+                else:
+                    converted_tensors = self.convert_single_tensor_to_hf(
+                        fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                    )
+                    for key, value in converted_tensors:
+                        hf_state_dict[key] = value
+            else:
+                converted_tensors = self.convert_single_tensor_to_hf(
+                    fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                )
+                for key, value in converted_tensors:
+                    hf_state_dict[key] = value
 
         if exclude_key_regex:
             import re as _re
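The gather logic above is the heart of the change: each EP rank owns a slice of the experts, and `all_gather_object` lets every rank assemble the full `[n_experts, ...]` tensor before the HF-format checkpoint is written. A standalone sketch of the same pattern; the function name and arguments here are illustrative, not the module's API:

import torch
import torch.distributed as dist


def gather_expert_slabs(local_ids, local_weights, n_experts, group=None):
    """Assemble a full [n_experts, ...] tensor from per-rank expert slices.

    local_ids: global expert indices owned by this rank.
    local_weights: matching list of per-expert CPU tensors.
    """
    full = torch.zeros((n_experts, *local_weights[0].shape), dtype=local_weights[0].dtype)
    payload = (local_ids, [w.cpu() for w in local_weights])
    if dist.is_initialized() and group is not None:
        gathered = [None] * dist.get_world_size(group)
        dist.all_gather_object(gathered, payload, group=group)
    else:
        gathered = [payload]  # single-process fallback: only local experts get filled
    for ids, weights in gathered:
        for eid, w in zip(ids, weights):
            full[eid].copy_(w)
    return full


# A rank owning experts 2 and 3 of 4, run without torch.distributed:
slab = gather_expert_slabs([2, 3], [torch.ones(1, 2), 2 * torch.ones(1, 2)], n_experts=4)
assert torch.all(slab[0] == 0) and torch.all(slab[3] == 2)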
@@ -114,8 +162,6 @@ def from_hf(
             if match:
                 _, layer_num, which = match.groups()
                 tensor = value
-                if state_dict_utils.is_dtensor(tensor):
-                    tensor = tensor.to_local()
                 local_tensor = tensor[start_expert:end_expert].to(self.dtype)
                 native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
                 native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
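Dropping the `to_local()` call means `from_hf` now slices the (possibly distributed) tensor directly; the test change further down, which adds `__getitem__` to the fake DTensor, encodes the same assumption. A sketch of that invariant with a stand-in class:

import torch


class FakeDTensor:
    # Stand-in for a DTensor that supports basic indexing, as the updated test fake does.
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return self._data[idx]


weights = FakeDTensor(torch.arange(8.0).reshape(4, 2))
start_expert, end_expert = 1, 3
local = weights[start_expert:end_expert]  # slicing works without an explicit to_local()
assert local.shape == (2, 2)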
@@ -155,5 +201,4 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t

         if exclude_key_regex:
             result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
-
         return result
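`re.match` anchors at the start of the key, so the leading `.*` in the call sites' `r".*_extra_state.*"` pattern is what lets the filter hit nested keys. A quick illustration with made-up keys:

import re

result = [
    ("model.layers.0.self_attn.q_proj.weight", "w0"),
    ("model.layers.0.self_attn._extra_state", "meta"),
]
filtered = [(k, v) for k, v in result if not re.match(r".*_extra_state.*", k)]
assert [k for k, _ in filtered] == ["model.layers.0.self_attn.q_proj.weight"]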

@@ -120,6 +120,109 @@ def test_respects_exclude_regex(self, adapter):
         assert "exclude.me" not in out
 
+    def test_aggregates_with_device_mesh_non_dtensor(self, adapter, monkeypatch):
+        local_experts = torch.tensor(
+            [
+                [[1.0, 2.0], [3.0, 4.0]],
+                [[5.0, 6.0], [7.0, 8.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 2, 2]
+
+        # Only experts 1 and 2 live on this rank
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (1, 3),
+        )
+        # No distributed init => skip all_gather branch
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts,
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 2, 2)
+        # Experts 1 and 2 should be populated from local_experts; others remain zero
+        torch.testing.assert_close(global_gate[1:3], local_experts)
+        assert torch.all(global_gate[0] == 0)
+        assert torch.all(global_gate[3] == 0)
+
+    def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch):
+        local_slice = torch.tensor([[9.0, 10.0]], dtype=adapter.dtype)  # shape: [1, 2]
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda tensor: True
+        )
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware",
+            lambda weight, n_experts: ([local_slice], [2]),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.down_projs": torch.empty(1, 1, 2),
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        down_key = "model.language_model.layers.0.mlp.experts.down_proj"
+        global_down = out[down_key]
+
+        assert global_down.shape[0] == adapter.moe_config.n_routed_experts
+        torch.testing.assert_close(global_down[2], local_slice)
+
+    def test_all_gather_path_populates_global_tensor(self, adapter, monkeypatch):
+        # Local shard has experts 0 and 1; simulate another rank providing experts 2 and 3
+        local_experts = torch.tensor(
+            [
+                [[1.0]],
+                [[2.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 1, 1]
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+        device_mesh.get_group = lambda dim: "ep_group" if dim == 0 else None
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (0, 2),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
+        monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: 2)
+
+        def fake_all_gather_object(gathered, payload, group=None):
+            # payload from this rank covers experts [0, 1]; simulate other rank with [2, 3]
+            gathered[0] = payload
+            other_weights = [
+                torch.tensor([[3.0]], dtype=adapter.dtype),
+                torch.tensor([[4.0]], dtype=adapter.dtype),
+            ]
+            gathered[1] = ([2, 3], other_weights)
+
+        monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object)
+
+        state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts}
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 1, 1)
+        torch.testing.assert_close(global_gate[0], torch.tensor([[1.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[1], torch.tensor([[2.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[2], torch.tensor([[3.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[3], torch.tensor([[4.0]], dtype=adapter.dtype))
+
 
 class TestFromHF:
     def test_detects_model_prefix(self, adapter):
         hf_state = {
Expand Down Expand Up @@ -173,6 +276,9 @@ def __init__(self, data):
         def to_local(self):
             return self._data
 
+        def __getitem__(self, idx):
+            return self._data[idx]
+
     captured = {"locals": []}
 
     monkeypatch.setattr(