@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ def __init__(
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
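Note on the sizing change above: `q_norm`/`k_norm` were built with `hidden_size`, but they are applied to q and k, whose full widths are `total_num_heads * head_dim` and `total_num_kv_heads * head_dim`. Under tensor parallelism each rank's shard is only 1/`tp_size` of that width, so the norms are now sized for the full, all-gathered tensors. With illustrative numbers (not OLMoE's actual config): at total_num_heads=16, head_dim=64, tp_size=2, each rank holds a 512-wide q shard, the gathered q is 16 * 64 = 1024 wide, and q_norm must be RMSNorm(1024).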
@@ -165,14 +172,28 @@ def __init__(
                               quant_config=quant_config,
                               prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
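The `_apply_qk_norm` helper added above exists because QK-norm spans all heads while each tensor-parallel rank holds only `num_heads / tp_size` of them: the shards are all-gathered along the last dimension, normalized once at full width, then split back so each rank keeps its own slice. Below is a minimal single-process sketch of that gather-normalize-split flow; `torch.nn.RMSNorm`, `torch.cat`, and `Tensor.chunk` stand in for vLLM's `RMSNorm`, `tensor_model_parallel_all_gather`, and `split_tensor_along_last_dim`, and the helper name `qk_norm_across_tp` is made up for illustration.

import torch
from torch import nn

def qk_norm_across_tp(shards: list[torch.Tensor], norm: nn.Module,
                      tp_rank: int) -> torch.Tensor:
    # "All-gather": concatenate the per-rank head shards along the last
    # dim so the norm sees the full total_num_heads * head_dim width.
    full = torch.cat(shards, dim=-1)
    full = norm(full)
    # Re-split into one chunk per rank and keep only this rank's slice,
    # like split_tensor_along_last_dim(full, num_partitions=tp_size)[tp_rank].
    return full.chunk(len(shards), dim=-1)[tp_rank]

# Two "ranks", 4 heads of dim 8 in total -> each rank holds 2 heads (16 dims).
tp_size, num_tokens = 2, 3
norm = nn.RMSNorm(4 * 8, eps=1e-5)  # sized for the full gathered width
q = torch.randn(num_tokens, 4 * 8)
shards = list(q.chunk(tp_size, dim=-1))  # what each rank would hold locally
out = qk_norm_across_tp(shards, norm, tp_rank=0)
assert out.shape == (num_tokens, 4 * 8 // tp_size)

One design note: gathering before the norm keeps the normalization statistics identical across tp_size settings, at the cost of two extra collectives per layer; a per-rank norm would avoid the communication but change numerics whenever the normalization width spans more than one rank's shard.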