|
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
@@ -49,27 +50,74 @@ def to_hf( |
         quantization: bool = False,
         **kwargs,
     ) -> dict[str, Any]:
+        self._uses_model_prefix = any(key.startswith("model.") for key in state_dict.keys())
         prefix = "model." if self._uses_model_prefix else ""
         hf_state_dict: dict[str, Any] = {}
+        device_mesh: Optional["DeviceMesh"] = kwargs.get("device_mesh")
 
         for fqn, tensor in state_dict.items():
-            if ".mlp.experts.gate_and_up_projs" in fqn:
+            if ".mlp.experts.gate_and_up_projs" in fqn or ".mlp.experts.down_projs" in fqn:
                 layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.gate_up_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
-                )
-                continue
-
-            if ".mlp.experts.down_projs" in fqn:
-                layer_num = re.search(r"layers\.(\d+)", fqn).group(1)
-                hf_state_dict[f"{prefix}language_model.layers.{layer_num}.mlp.experts.down_proj"] = torch.empty(
-                    (self.moe_config.n_routed_experts, tensor.shape[1], tensor.shape[2]),
-                    dtype=self.dtype,
+                which = "gate_up_proj" if "gate_and_up_projs" in fqn else "down_proj"
+                if device_mesh is not None:
+                    n_experts = self.moe_config.n_routed_experts
+                    # Aggregate this layer's expert tensor only for the current key, then free temps.
+                    global_tensor = torch.zeros(
+                        (n_experts, tensor.shape[1], tensor.shape[2]), dtype=self.dtype, device="cpu"
+                    )
+
+                    if state_dict_utils.is_dtensor(tensor):
+                        split_weights, expert_ids = state_dict_utils.split_experts_weights_dtensor_aware(
+                            tensor, n_experts
+                        )
+                    else:
+                        start_expert, end_expert = state_dict_utils.get_expert_range_for_rank_from_mesh(
+                            device_mesh, n_experts
+                        )
+                        split_weights = [tensor[i].to(self.dtype).cpu() for i in range(tensor.shape[0])]
+                        expert_ids = list(range(start_expert, end_expert))
+
+                    # If distributed is initialized and we have an ep dimension, gather all slices.
+                    if dist.is_initialized() and "ep" in device_mesh.mesh_dim_names:
+                        try:
+                            ep_dim = device_mesh.mesh_dim_names.index("ep")
+                            ep_group = device_mesh.get_group(ep_dim)
+                        except Exception:
+                            ep_group = None
+
+                        if ep_group is not None:
+                            payload = (expert_ids, [w.cpu() for w in split_weights])
+                            gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
+                                ep_group
+                            )
+                            dist.all_gather_object(gathered, payload, group=ep_group)
+                            for ids, weights in gathered:
+                                for eid, w in zip(ids, weights):
+                                    global_tensor[eid].copy_(w.to(self.dtype).cpu())
+                        else:
+                            for weight, expert_id in zip(split_weights, expert_ids):
+                                global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    else:
+                        for weight, expert_id in zip(split_weights, expert_ids):
+                            global_tensor[expert_id].copy_(weight.to(self.dtype).cpu())
+                    del split_weights
+                    del expert_ids
+
+                    key = f"{prefix}language_model.layers.{layer_num}.mlp.experts.{which}"
+                    hf_state_dict[key] = global_tensor
+                    del global_tensor
+                else:
+                    converted_tensors = self.convert_single_tensor_to_hf(
+                        fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+                    )
+                    for key, value in converted_tensors:
+                        hf_state_dict[key] = value
+            else:
+                converted_tensors = self.convert_single_tensor_to_hf(
+                    fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
                 )
-                continue
-
-            hf_state_dict[fqn] = tensor
+                for key, value in converted_tensors:
+                    hf_state_dict[key] = value
 
         if exclude_key_regex:
             import re as _re
@@ -114,8 +162,6 @@ def from_hf( |
             if match:
                 _, layer_num, which = match.groups()
                 tensor = value
-                if state_dict_utils.is_dtensor(tensor):
-                    tensor = tensor.to_local()
                 local_tensor = tensor[start_expert:end_expert].to(self.dtype)
                 native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
                 native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
@@ -155,5 +201,4 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t |
 
         if exclude_key_regex:
             result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
-
         return result
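
For reference, a minimal standalone sketch of the expert-parallel gather pattern the `to_hf` hunk relies on: each rank contributes its local expert slices and their global ids via `dist.all_gather_object`, and every rank reassembles the full `(n_experts, ...)` tensor on CPU. This assumes `torch.distributed` is already initialized; the helper name `gather_expert_weights` and its signature are illustrative, not part of this change.

```python
import torch
import torch.distributed as dist


def gather_expert_weights(
    local_experts: list[torch.Tensor],   # this rank's expert slices, each (in_dim, out_dim)
    local_expert_ids: list[int],         # global expert indices owned by this rank
    n_experts: int,
    group: dist.ProcessGroup | None = None,
) -> torch.Tensor:
    """Reassemble the full (n_experts, in_dim, out_dim) tensor from per-rank slices."""
    full = torch.zeros((n_experts, *local_experts[0].shape), dtype=local_experts[0].dtype)

    # all_gather_object pickles arbitrary Python objects, so tensors are moved to CPU
    # first to avoid shipping CUDA storage through the object store.
    payload = (local_expert_ids, [w.cpu() for w in local_experts])
    gathered = [None] * dist.get_world_size(group)
    dist.all_gather_object(gathered, payload, group=group)

    # Every rank now holds every (ids, weights) pair and can fill in the full tensor.
    for expert_ids, weights in gathered:
        for eid, w in zip(expert_ids, weights):
            full[eid].copy_(w)
    return full
```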