Commit 672806f

fix: fix qwen3vlmoe state dict adapter (#1002)

1 parent a8d9ca3 commit 672806f

3 files changed: +178 -22 lines changed

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 9 additions & 4 deletions
@@ -201,7 +201,9 @@ def save_model(
         state_dict = model_state.state_dict()
 
         # Convert to HF format if using custom model implementations
-        state_dict = _maybe_adapt_state_dict_to_hf(model_state.model[0], state_dict, quantization=False)
+        state_dict = _maybe_adapt_state_dict_to_hf(
+            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+        )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
 
@@ -305,7 +307,10 @@ def load_model(
         storage_reader = self._get_storage_reader(model_path, key_mapping, is_init_step=is_init_step)
 
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=self.config.dequantize_base_checkpoint
+            model_state.model[0],
+            state_dict,
+            quantization=self.config.dequantize_base_checkpoint,
+            device_mesh=self.moe_mesh,
         )
 
         state_dict = self._do_load(state_dict, model_path, storage_reader, is_init_step=is_init_step)
 
@@ -848,14 +853,14 @@ def compute_should_use_set_data(tensor, tensor_applied):
 
 
 def _maybe_adapt_state_dict_to_hf(
-    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False
+    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False, **kwargs
 ) -> dict[str, torch.Tensor]:
     """
     Custom models use state dict adapters to convert the state dict to the Hugging Face format.
     """
     adapter = getattr(model_part, "state_dict_adapter", None)
     if adapter:
-        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization)
+        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization, **kwargs)
     return state_dict

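The change above is a plain **kwargs passthrough: the checkpointer accepts extra keyword arguments and forwards them untouched to whatever state dict adapter the model exposes, so a new adapter option (here, device_mesh) needs no new plumbing at the call sites. A minimal runnable sketch of the pattern follows; DummyAdapter and maybe_adapt are illustrative stand-ins, not names from this repository:

    import torch.nn as nn


    class DummyAdapter:
        # Illustrative stand-in for a model's state_dict_adapter.
        def to_hf(self, state_dict, exclude_key_regex=None, quantization=False, **kwargs):
            # Optional extras such as device_mesh arrive through kwargs.
            mesh = kwargs.get("device_mesh")
            print(f"quantization={quantization}, device_mesh={mesh}")
            return state_dict


    def maybe_adapt(model_part, state_dict, quantization=False, **kwargs):
        # Same shape as _maybe_adapt_state_dict_to_hf above: forward kwargs untouched.
        adapter = getattr(model_part, "state_dict_adapter", None)
        if adapter:
            return adapter.to_hf(state_dict, quantization=quantization, **kwargs)
        return state_dict


    model = nn.Linear(2, 2)
    model.state_dict_adapter = DummyAdapter()
    maybe_adapt(model, model.state_dict(), device_mesh="toy-mesh")  # prints the forwarded mesh

The upside of this design is that adapters with extra needs, such as the MoE adapter below, can receive them without every intermediate signature growing a dedicated parameter.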
nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py

Lines changed: 63 additions & 18 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
@@ -49,27 +50,74 @@ def to_hf(
         quantization: bool = False,
         **kwargs,
     ) -> dict[str, Any]:
+        self._uses_model_prefix = any(key.startswith("model.") for key in state_dict.keys())
         prefix = "model." if self._uses_model_prefix else ""
         hf_state_dict: dict[str, Any] = {}
+        device_mesh: Optional["DeviceMesh"] = kwargs.get("device_mesh")
 
         for fqn, tensor in state_dict.items():
-            if ".mlp.experts.gate_and_up_projs" in fqn:
+            if ".mlp.experts.gate_and_up_projs" in fqn or ".mlp.experts.down_projs" in fqn:
                 layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.gate_up_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            if ".mlp.experts.down_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.down_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
+                which = "gate_up_proj" if "gate_and_up_projs" in fqn else "down_proj"
+                if device_mesh is not None:
+                    n_experts = self.moe_config.n_routed_experts
+                    # Aggregate this layer's expert tensor only for the current key, then free temps.
+                    global_tensor = torch.zeros(
+                        (n_experts, tensor.shape[1], tensor.shape[2]), dtype=self.dtype, device="cpu"
+                    )
+
+                    if state_dict_utils.is_dtensor(tensor):
+                        split_weights, expert_ids = state_dict_utils.split_experts_weights_dtensor_aware(
+                            tensor, n_experts
+                        )
+                    else:
+                        start_expert, end_expert = state_dict_utils.get_expert_range_for_rank_from_mesh(
+                            device_mesh, n_experts
+                        )
+                        split_weights = [tensor[i].to(self.dtype).cpu() for i in range(tensor.shape[0])]
+                        expert_ids = list(range(start_expert, end_expert))
+
+                    # If distributed is initialized and we have an ep dimension, gather all slices.
+                    if dist.is_initialized() and "ep" in device_mesh.mesh_dim_names:
+                        try:
+                            ep_dim = device_mesh.mesh_dim_names.index("ep")
+                            ep_group = device_mesh.get_group(ep_dim)
+                        except Exception:
+                            ep_group = None
+
+                        if ep_group is not None:
+                            payload = (expert_ids, [w.cpu() for w in split_weights])
+                            gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
+                                ep_group
+                            )
+                            dist.all_gather_object(gathered, payload, group=ep_group)
+                            for ids, weights in gathered:
+                                for eid, w in zip(ids, weights):
+                                    global_tensor[eid].copy_(w.to(self.dtype).cpu())
+                        else:
+                            for weight, expert_id in zip(split_weights, expert_ids):
+                                global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    else:
+                        for weight, expert_id in zip(split_weights, expert_ids):
+                            global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    del split_weights
+                    del expert_ids
+
+                    key = f"{prefix}language_model.layers.{layer_num}.mlp.experts.{which}"
+                    hf_state_dict[key] = global_tensor
+                    del global_tensor
+                else:
+                    converted_tensors = self.convert_single_tensor_to_hf(
+                        fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                    )
+                    for key, value in converted_tensors:
+                        hf_state_dict[key] = value
+            else:
+                converted_tensors = self.convert_single_tensor_to_hf(
+                    fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
                 )
-                continue
-
-            hf_state_dict[fqn] = tensor
+                for key, value in converted_tensors:
+                    hf_state_dict[key] = value
 
         if exclude_key_regex:
             import re as _re
@@ -114,8 +162,6 @@ def from_hf(
             if match:
                 _, layer_num, which = match.groups()
                 tensor = value
-                if state_dict_utils.is_dtensor(tensor):
-                    tensor = tensor.to_local()
                 local_tensor = tensor[start_expert:end_expert].to(self.dtype)
                 native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
                 native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
@@ -155,5 +201,4 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
 
         if exclude_key_regex:
             result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
-
         return result

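The heart of the new to_hf branch is object-based gathering: every expert-parallel rank contributes an (expert_ids, weights) payload through dist.all_gather_object, and each rank then scatters all payloads into a dense [n_experts, ...] tensor, so the saved HF checkpoint carries the full expert dimension regardless of sharding. A single-process sketch of just that reassembly step; the hand-built gathered list stands in for what all_gather_object would return over the "ep" group:

    import torch

    n_experts, rows, cols = 4, 1, 2

    # Hand-built payloads standing in for what dist.all_gather_object would return
    # over the "ep" process group: each rank reports the global expert ids it owns
    # plus the matching weight slices.
    gathered = [
        ([0, 1], [torch.full((rows, cols), 1.0), torch.full((rows, cols), 2.0)]),  # rank 0
        ([2, 3], [torch.full((rows, cols), 3.0), torch.full((rows, cols), 4.0)]),  # rank 1
    ]

    # Reassemble the dense [n_experts, rows, cols] tensor on CPU, mirroring to_hf.
    global_tensor = torch.zeros((n_experts, rows, cols))
    for expert_ids, weights in gathered:
        for eid, w in zip(expert_ids, weights):
            global_tensor[eid].copy_(w)

    assert torch.all(global_tensor[2] == 3.0)  # expert 2 came from "rank 1"

On the load path, from_hf now slices the incoming tensor directly with tensor[start_expert:end_expert] instead of first converting DTensors to local shards, which is why the fake DTensor in the tests below gains a __getitem__ method.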
tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py

Lines changed: 106 additions & 0 deletions
@@ -120,6 +120,109 @@ def test_respects_exclude_regex(self, adapter):
         assert "exclude.me" not in out
 
 
+    def test_aggregates_with_device_mesh_non_dtensor(self, adapter, monkeypatch):
+        local_experts = torch.tensor(
+            [
+                [[1.0, 2.0], [3.0, 4.0]],
+                [[5.0, 6.0], [7.0, 8.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 2, 2]
+
+        # Only experts 1 and 2 live on this rank
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (1, 3),
+        )
+        # No distributed init => skip all_gather branch
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts,
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 2, 2)
+        # Experts 1 and 2 should be populated from local_experts; others remain zero
+        torch.testing.assert_close(global_gate[1:3], local_experts)
+        assert torch.all(global_gate[0] == 0)
+        assert torch.all(global_gate[3] == 0)
+
+
+    def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch):
+        local_slice = torch.tensor([[9.0, 10.0]], dtype=adapter.dtype)  # shape: [1, 2]
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda tensor: True
+        )
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware",
+            lambda weight, n_experts: ([local_slice], [2]),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.down_projs": torch.empty(1, 1, 2),
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        down_key = "model.language_model.layers.0.mlp.experts.down_proj"
+        global_down = out[down_key]
+
+        assert global_down.shape[0] == adapter.moe_config.n_routed_experts
+        torch.testing.assert_close(global_down[2], local_slice)
+
+    def test_all_gather_path_populates_global_tensor(self, adapter, monkeypatch):
+        # Local shard has experts 0 and 1; simulate another rank providing experts 2 and 3
+        local_experts = torch.tensor(
+            [
+                [[1.0]],
+                [[2.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 1, 1]
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+        device_mesh.get_group = lambda dim: "ep_group" if dim == 0 else None
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (0, 2),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
+        monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: 2)
+
+        def fake_all_gather_object(gathered, payload, group=None):
+            # payload from this rank for experts [0,1]; simulate other rank with [2,3]
+            gathered[0] = payload
+            other_weights = [torch.tensor([[3.0]], dtype=adapter.dtype), torch.tensor([[4.0]], dtype=adapter.dtype)]
+            gathered[1] = ([2, 3], other_weights)
+
+        monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object)
+
+        state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts}
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 1, 1)
+        torch.testing.assert_close(global_gate[0], torch.tensor([[1.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[1], torch.tensor([[2.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[2], torch.tensor([[3.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[3], torch.tensor([[4.0]], dtype=adapter.dtype))
+
+
 class TestFromHF:
     def test_detects_model_prefix(self, adapter):
         hf_state = {
@@ -173,6 +276,9 @@ def __init__(self, data):
             def to_local(self):
                 return self._data
 
+            def __getitem__(self, idx):
+                return self._data[idx]
+
         captured = {"locals": []}
 
         monkeypatch.setattr(

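To exercise just the new cases locally, a plausible invocation (assuming a standard pytest setup) is:

    pytest tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py -k "aggregates or all_gather"

The -k expression selects the three tests added by this commit.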