Commit c68b5c6 (parent: fced756)

[Misc] fix OLMoE model layers that can't load when TP > 1 (vllm-project#18828)

Signed-off-by: rongfu.leng <[email protected]>

File tree: 1 file changed (+25, -4 lines)

vllm/model_executor/models/olmoe.py: 25 additions & 4 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -165,14 +172,28 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
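
What the diff does: under tensor parallelism each rank holds only its slice of the q and k heads after qkv_proj, so the old RMSNorm sized to hidden_size no longer matches the per-rank shards once TP > 1 (the commit title notes the layers fail to load). The new _apply_qk_norm all-gathers q and k, applies norms sized to the full total_num_heads * head_dim (and kv equivalent), and splits the result back to the local shard. Below is a minimal single-process sketch of that gather -> norm -> re-split round trip; fake_all_gather, split_along_last_dim, rms_norm and all shapes are illustrative stand-ins, not vLLM's actual distributed primitives.

    # Single-process sketch of the TP-aware QK-norm pattern (illustrative only).
    from functools import partial

    import torch


    def fake_all_gather(shards: list[torch.Tensor]) -> torch.Tensor:
        # Stand-in for tensor_model_parallel_all_gather: concatenate the
        # per-rank q (or k) shards along the last (head) dimension.
        return torch.cat(shards, dim=-1)


    def split_along_last_dim(t: torch.Tensor, num_partitions: int):
        # Stand-in for vllm.distributed.utils.split_tensor_along_last_dim.
        return torch.chunk(t, num_partitions, dim=-1)


    def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # RMSNorm computes its statistics over the full last dimension,
        # which is why every rank must see all heads before normalizing.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


    total_num_heads, head_dim, tp_size, num_tokens = 8, 16, 2, 4

    # After qkv_proj, each TP rank holds only its slice of the heads.
    shards = [torch.randn(num_tokens, total_num_heads * head_dim // tp_size)
              for _ in range(tp_size)]

    # Gather the full q, normalize over total_num_heads * head_dim, then hand
    # each rank its own shard back (here we just look at "rank 0").
    full_q = rms_norm(fake_all_gather(shards))
    splitter = partial(split_along_last_dim, num_partitions=tp_size)
    local_q = splitter(full_q)[0]
    assert local_q.shape == shards[0].shape

The sketch also shows why the gather is needed at all: the RMS statistic is taken over the full last dimension, so normalizing a per-rank shard in isolation would not match normalizing the gathered tensor.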
