
Commit f1bd72f (parent: 6c38eae)

fix state dict adapter

Signed-off-by: HuiyingLi <[email protected]>

2 files changed: +74 -22 lines changed

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 11 additions & 4 deletions
@@ -201,7 +201,9 @@ def save_model(
         state_dict = model_state.state_dict()
 
         # Convert to HF format if using custom model implementations
-        state_dict = _maybe_adapt_state_dict_to_hf(model_state.model[0], state_dict, quantization=False)
+        state_dict = _maybe_adapt_state_dict_to_hf(
+            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+        )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
 
@@ -305,7 +307,10 @@ def load_model(
         storage_reader = self._get_storage_reader(model_path, key_mapping, is_init_step=is_init_step)
 
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=self.config.dequantize_base_checkpoint
+            model_state.model[0],
+            state_dict,
+            quantization=self.config.dequantize_base_checkpoint,
+            device_mesh=self.moe_mesh,
         )
 
         state_dict = self._do_load(state_dict, model_path, storage_reader, is_init_step=is_init_step)
@@ -848,14 +853,16 @@ def compute_should_use_set_data(tensor, tensor_applied):
 
 
 def _maybe_adapt_state_dict_to_hf(
-    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False
+    model_part: nn.Module, state_dict: dict[str, torch.Tensor], quantization: bool = False, **kwargs
 ) -> dict[str, torch.Tensor]:
     """
     Custom models use state dict adapters to convert the state dict to the Hugging Face format.
     """
     adapter = getattr(model_part, "state_dict_adapter", None)
     if adapter:
-        return adapter.to_hf(state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization)
+        return adapter.to_hf(
+            state_dict, exclude_key_regex=r".*_extra_state.*", quantization=quantization, **kwargs
+        )
     return state_dict
 
 
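The checkpointing side is pure plumbing: _maybe_adapt_state_dict_to_hf grows a **kwargs pass-through, and both save_model and load_model use it to hand self.moe_mesh to the adapter as device_mesh. Adapters that never read the extra keyword are unaffected. A minimal sketch of the pattern (ToyAdapter and _gather_sharded are illustrative, not the real NeMo Automodel types):

class ToyAdapter:
    def to_hf(self, state_dict, exclude_key_regex=None, quantization=False, **kwargs):
        # device_mesh arrives via **kwargs, so existing call sites and adapters
        # that never read it keep working unchanged.
        device_mesh = kwargs.get("device_mesh")
        if device_mesh is None:
            return state_dict  # single-process path: nothing to gather
        return self._gather_sharded(state_dict, device_mesh)  # hypothetical helper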

nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py

Lines changed: 63 additions & 18 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
@@ -49,27 +50,74 @@ def to_hf(
         quantization: bool = False,
         **kwargs,
     ) -> dict[str, Any]:
+        self._uses_model_prefix = any(key.startswith("model.") for key in state_dict.keys())
         prefix = "model." if self._uses_model_prefix else ""
         hf_state_dict: dict[str, Any] = {}
+        device_mesh: Optional["DeviceMesh"] = kwargs.get("device_mesh")
 
         for fqn, tensor in state_dict.items():
-            if ".mlp.experts.gate_and_up_projs" in fqn:
+            if ".mlp.experts.gate_and_up_projs" in fqn or ".mlp.experts.down_projs" in fqn:
                 layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.gate_up_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            if ".mlp.experts.down_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.down_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
+                which = "gate_up_proj" if "gate_and_up_projs" in fqn else "down_proj"
+                if device_mesh is not None:
+                    n_experts = self.moe_config.n_routed_experts
+                    # Aggregate this layer's expert tensor only for the current key, then free temps.
+                    global_tensor = torch.zeros(
+                        (n_experts, tensor.shape[1], tensor.shape[2]), dtype=self.dtype, device="cpu"
+                    )
+
+                    if state_dict_utils.is_dtensor(tensor):
+                        split_weights, expert_ids = state_dict_utils.split_experts_weights_dtensor_aware(
+                            tensor, n_experts
+                        )
+                    else:
+                        start_expert, end_expert = state_dict_utils.get_expert_range_for_rank_from_mesh(
+                            device_mesh, n_experts
+                        )
+                        split_weights = [tensor[i].to(self.dtype).cpu() for i in range(tensor.shape[0])]
+                        expert_ids = list(range(start_expert, end_expert))
+
+                    # If distributed is initialized and we have an ep dimension, gather all slices.
+                    if dist.is_initialized() and "ep" in device_mesh.mesh_dim_names:
+                        try:
+                            ep_dim = device_mesh.mesh_dim_names.index("ep")
+                            ep_group = device_mesh.get_group(ep_dim)
+                        except Exception:
+                            ep_group = None
+
+                        if ep_group is not None:
+                            payload = (expert_ids, [w.cpu() for w in split_weights])
+                            gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
+                                ep_group
+                            )
+                            dist.all_gather_object(gathered, payload, group=ep_group)
+                            for ids, weights in gathered:
+                                for eid, w in zip(ids, weights):
+                                    global_tensor[eid].copy_(w.to(self.dtype).cpu())
+                        else:
+                            for weight, expert_id in zip(split_weights, expert_ids):
+                                global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    else:
+                        for weight, expert_id in zip(split_weights, expert_ids):
+                            global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    del split_weights
+                    del expert_ids
+
+                    key = f"{prefix}language_model.layers.{layer_num}.mlp.experts.{which}"
+                    hf_state_dict[key] = global_tensor
+                    del global_tensor
+                else:
+                    converted_tensors = self.convert_single_tensor_to_hf(
+                        fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                    )
+                    for key, value in converted_tensors:
+                        hf_state_dict[key] = value
+            else:
+                converted_tensors = self.convert_single_tensor_to_hf(
+                    fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
                 )
-                continue
-
-            hf_state_dict[fqn] = tensor
+                for key, value in converted_tensors:
+                    hf_state_dict[key] = value
 
         if exclude_key_regex:
             import re as _re
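This is the substance of the fix. Previously to_hf emitted torch.empty placeholders for the fused expert weights; now, when a device_mesh is supplied, each expert-parallel rank contributes its local (expert_ids, weights) slices and the full (n_experts, dim1, dim2) tensor is assembled by all-gathering Python objects over the "ep" process group. A self-contained sketch of that gather, assuming torch.distributed is already initialized and ep_group is the expert-parallel group (the function name is illustrative):

import torch
import torch.distributed as dist

def gather_full_expert_tensor(
    local_slices: list[torch.Tensor],  # this rank's expert weights, on CPU
    expert_ids: list[int],             # global ids of those experts
    n_experts: int,
    ep_group,
) -> torch.Tensor:
    out = torch.zeros((n_experts, *local_slices[0].shape), dtype=local_slices[0].dtype)
    payload = (expert_ids, local_slices)
    gathered = [None] * dist.get_world_size(ep_group)
    # all_gather_object pickles arbitrary Python objects: acceptable for
    # checkpoint-time consolidation, far too slow for a training hot path.
    dist.all_gather_object(gathered, payload, group=ep_group)
    for ids, weights in gathered:
        for eid, w in zip(ids, weights):
            out[eid].copy_(w)  # place each expert slice at its global index
    return out

Because every rank materializes the full tensor, peak memory is one complete layer's expert weights per rank; the del statements in the diff keep only the current layer's global_tensor alive at a time.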
@@ -114,8 +162,6 @@ def from_hf(
             if match:
                 _, layer_num, which = match.groups()
                 tensor = value
-                if state_dict_utils.is_dtensor(tensor):
-                    tensor = tensor.to_local()
                 local_tensor = tensor[start_expert:end_expert].to(self.dtype)
                 native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
                 native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
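On the load side, the explicit DTensor unwrap is dropped, presumably because the HF-side values reaching this branch are plain tensors; from_hf simply slices the full HF expert tensor with tensor[start_expert:end_expert] to keep this rank's experts. A simplified stand-in for what get_expert_range_for_rank_from_mesh plausibly computes, assuming experts divide evenly across the "ep" mesh dimension (not the real helper):

from torch.distributed.device_mesh import DeviceMesh

def expert_range_for_rank(device_mesh: DeviceMesh, n_experts: int) -> tuple[int, int]:
    ep_rank = device_mesh.get_local_rank("ep")  # this rank's index on the ep axis
    ep_size = device_mesh["ep"].size()          # number of expert-parallel ranks
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank

# Each rank then keeps only its slice of the full HF tensor:
# local_tensor = hf_tensor[start_expert:end_expert].to(dtype)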
@@ -155,5 +201,4 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t
 
         if exclude_key_regex:
             result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
-
         return result
