diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py
index 4dbf6e838..b166e0339 100644
--- a/nemo_automodel/components/checkpoint/checkpointing.py
+++ b/nemo_automodel/components/checkpoint/checkpointing.py
@@ -201,7 +201,9 @@ def save_model(
         state_dict = model_state.state_dict()
 
         # Convert to HF format if using custom model implementations
-        state_dict = _maybe_adapt_state_dict_to_hf(model_state.model[0], state_dict, quantization=False)
+        state_dict = _maybe_adapt_state_dict_to_hf(
+            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+        )
 
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
@@ -305,7 +307,10 @@ def load_model(
         storage_reader = self._get_storage_reader(model_path, key_mapping, is_init_step=is_init_step)
 
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=self.config.dequantize_base_checkpoint
+            model_state.model[0],
+            state_dict,
+            quantization=self.config.dequantize_base_checkpoint,
+            device_mesh=self.moe_mesh,
         )
 
         state_dict = self._do_load(state_dict, model_path, storage_reader, is_init_step=is_init_step)
@@ -848,14 +853,14 @@ def compute_should_use_set_data(tensor, tensor_applied):
 
 
 def _maybe_adapt_state_dict_to_hf(
-    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False
+    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False, **kwargs
 ) -> dict[str, torch.Tensor]:
     """
     Custom models use state dict adapters to convert the state dict to the Hugging Face format.
     """
     adapter = getattr(model_part, "state_dict_adapter", None)
     if adapter:
-        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization)
+        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization, **kwargs)
     return state_dict
 
 
diff --git a/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py b/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py
index 731fd5044..2248c2bff 100644
--- a/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py
+++ b/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py
@@ -16,6 +16,7 @@
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
@@ -49,27 +50,74 @@ def to_hf(
         quantization: bool = False,
         **kwargs,
     ) -> dict[str, Any]:
+        self._uses_model_prefix = any(key.startswith("model.") for key in state_dict.keys())
         prefix = "model." if self._uses_model_prefix else ""
         hf_state_dict: dict[str, Any] = {}
+        device_mesh: Optional["DeviceMesh"] = kwargs.get("device_mesh")
 
         for fqn, tensor in state_dict.items():
-            if ".mlp.experts.gate_and_up_projs" in fqn:
+            if ".mlp.experts.gate_and_up_projs" in fqn or ".mlp.experts.down_projs" in fqn:
                 layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.gate_up_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            if ".mlp.experts.down_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.down_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
+                which = "gate_up_proj" if "gate_and_up_projs" in fqn else "down_proj"
+                if device_mesh is not None:
+                    n_experts = self.moe_config.n_routed_experts
+                    # Aggregate this layer's expert tensor only for the current key, then free temps.
+                    global_tensor = torch.zeros(
+                        (n_experts, tensor.shape[1], tensor.shape[2]), dtype=self.dtype, device="cpu"
+                    )
+
+                    if state_dict_utils.is_dtensor(tensor):
+                        split_weights, expert_ids = state_dict_utils.split_experts_weights_dtensor_aware(
+                            tensor, n_experts
+                        )
+                    else:
+                        start_expert, end_expert = state_dict_utils.get_expert_range_for_rank_from_mesh(
+                            device_mesh, n_experts
+                        )
+                        split_weights = [tensor[i].to(self.dtype).cpu() for i in range(tensor.shape[0])]
+                        expert_ids = list(range(start_expert, end_expert))
+
+                    # If distributed is initialized and we have an ep dimension, gather all slices.
+                    if dist.is_initialized() and "ep" in device_mesh.mesh_dim_names:
+                        try:
+                            ep_dim = device_mesh.mesh_dim_names.index("ep")
+                            ep_group = device_mesh.get_group(ep_dim)
+                        except Exception:
+                            ep_group = None
+
+                        if ep_group is not None:
+                            payload = (expert_ids, [w.cpu() for w in split_weights])
+                            gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
+                                ep_group
+                            )
+                            dist.all_gather_object(gathered, payload, group=ep_group)
+                            for ids, weights in gathered:
+                                for eid, w in zip(ids, weights):
+                                    global_tensor[eid].copy_(w.to(self.dtype).cpu())
+                        else:
+                            for weight, expert_id in zip(split_weights, expert_ids):
+                                global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    else:
+                        for weight, expert_id in zip(split_weights, expert_ids):
+                            global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    del split_weights
+                    del expert_ids
+
+                    key = f"{prefix}language_model.layers.{layer_num}.mlp.experts.{which}"
+                    hf_state_dict[key] = global_tensor
+                    del global_tensor
+                else:
+                    converted_tensors = self.convert_single_tensor_to_hf(
+                        fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                    )
+                    for key, value in converted_tensors:
+                        hf_state_dict[key] = value
+            else:
+                converted_tensors = self.convert_single_tensor_to_hf(
+                    fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
                 )
-                continue
-
-            hf_state_dict[fqn] = tensor
+                for key, value in converted_tensors:
+                    hf_state_dict[key] = value
 
         if exclude_key_regex:
             import re as _re
@@ -114,8 +162,6 @@ def from_hf(
             if match:
                 _, layer_num, which = match.groups()
                 tensor = value
-                if state_dict_utils.is_dtensor(tensor):
-                    tensor = tensor.to_local()
                 local_tensor = tensor[start_expert:end_expert].to(self.dtype)
                 native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs" @@ -155,5 +201,4 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t if exclude_key_regex: result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)] - return result diff --git a/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py b/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py index 1442f20cd..36fcfce23 100644 --- a/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py +++ b/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py @@ -120,6 +120,109 @@ def test_respects_exclude_regex(self, adapter): assert "exclude.me" not in out + def test_aggregates_with_device_mesh_non_dtensor(self, adapter, monkeypatch): + local_experts = torch.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + ], + dtype=adapter.dtype, + ) # shape: [2, 2, 2] + + # Only experts 1 and 2 live on this rank + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh", + lambda mesh, n_experts: (1, 3), + ) + # No distributed init => skip all_gather branch + monkeypatch.setattr("torch.distributed.is_initialized", lambda: False) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep"] + + state_dict = { + "model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts, + } + + out = adapter.to_hf(state_dict, device_mesh=device_mesh) + gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + global_gate = out[gate_key] + + assert global_gate.shape == (adapter.moe_config.n_routed_experts, 2, 2) + # Experts 1 and 2 should be populated from local_experts; others remain zero + torch.testing.assert_close(global_gate[1:3], local_experts) + assert torch.all(global_gate[0] == 0) + assert torch.all(global_gate[3] == 0) + + + def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch): + local_slice = torch.tensor([[9.0, 10.0]], dtype=adapter.dtype) # shape: [1, 2] + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda tensor: True + ) + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware", + lambda weight, n_experts: ([local_slice], [2]), + ) + monkeypatch.setattr("torch.distributed.is_initialized", lambda: False) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep"] + + state_dict = { + "model.language_model.layers.0.mlp.experts.down_projs": torch.empty(1, 1, 2), + } + + out = adapter.to_hf(state_dict, device_mesh=device_mesh) + down_key = "model.language_model.layers.0.mlp.experts.down_proj" + global_down = out[down_key] + + assert global_down.shape[0] == adapter.moe_config.n_routed_experts + torch.testing.assert_close(global_down[2], local_slice) + + def test_all_gather_path_populates_global_tensor(self, adapter, monkeypatch): + # Local shard has experts 0 and 1; simulate another rank providing experts 2 and 3 + local_experts = torch.tensor( + [ + [[1.0]], + [[2.0]], + ], + dtype=adapter.dtype, + ) # shape: [2, 1, 1] + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep"] + device_mesh.get_group = lambda dim: "ep_group" if dim == 0 else None + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh", + lambda mesh, n_experts: (0, 2), + ) + monkeypatch.setattr("torch.distributed.is_initialized", lambda: True) + 
monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: 2) + + def fake_all_gather_object(gathered, payload, group=None): + # payload from this rank for experts [0,1]; simulate other rank with [2,3] + gathered[0] = payload + other_weights = [torch.tensor([[3.0]], dtype=adapter.dtype), torch.tensor([[4.0]], dtype=adapter.dtype)] + gathered[1] = ([2, 3], other_weights) + + monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object) + + state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts} + out = adapter.to_hf(state_dict, device_mesh=device_mesh) + + gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + global_gate = out[gate_key] + + assert global_gate.shape == (adapter.moe_config.n_routed_experts, 1, 1) + torch.testing.assert_close(global_gate[0], torch.tensor([[1.0]], dtype=adapter.dtype)) + torch.testing.assert_close(global_gate[1], torch.tensor([[2.0]], dtype=adapter.dtype)) + torch.testing.assert_close(global_gate[2], torch.tensor([[3.0]], dtype=adapter.dtype)) + torch.testing.assert_close(global_gate[3], torch.tensor([[4.0]], dtype=adapter.dtype)) + + class TestFromHF: def test_detects_model_prefix(self, adapter): hf_state = { @@ -173,6 +276,9 @@ def __init__(self, data): def to_local(self): return self._data + def __getitem__(self, idx): + return self._data[idx] + captured = {"locals": []} monkeypatch.setattr(