
Commit dd2a6a8

[Bugfix] Fix internlm2 tensor parallel inference (vllm-project#8055)
1 parent 4ca65a9


vllm/model_executor/models/internlm2.py

Lines changed: 34 additions & 13 deletions
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -7,7 +8,10 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -70,20 +74,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
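
For reference, a minimal standalone sketch of the per-rank head bookkeeping introduced in the hunk above; the function name and the configuration values are hypothetical and not taken from this commit.

def per_rank_heads(total_num_heads: int, total_num_kv_heads: int,
                   tp_size: int):
    # Mirrors the __init__ logic: attention heads must divide evenly across
    # tensor-parallel ranks; KV heads are either partitioned or replicated.
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # More KV heads than ranks: partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate them across ranks.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

print(per_rank_heads(32, 8, 4))  # (8, 2): KV heads partitioned
print(per_rank_heads(32, 2, 4))  # (8, 1): KV heads replicated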
@@ -122,11 +127,27 @@ def __init__(
                               quant_config=quant_config)
 
     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v
 
     def forward(
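
To make the gather-and-regroup step in the hunk above easier to follow, here is a standalone toy simulation of the chunk reordering that the new split_qkv performs; it is not vLLM code, the sizes are made up, and tensor_model_parallel_all_gather is stood in for by a plain torch.cat over fake per-rank outputs.

import torch

tp_size, seq_len = 2, 3
q_size, kv_size = 8, 4  # hypothetical per-rank projection sizes

# Fake per-rank wqkv outputs, each laid out as [q_r | k_r | v_r]; the fill
# value encodes (rank, chunk) so the reordering is easy to see.
per_rank = [
    torch.cat([torch.full((seq_len, q_size), 10.0 * r + 0),
               torch.full((seq_len, kv_size), 10.0 * r + 1),
               torch.full((seq_len, kv_size), 10.0 * r + 2)], dim=-1)
    for r in range(tp_size)
]
gathered = torch.cat(per_rank, dim=-1)  # stands in for the all-gather

qkv_map = [q_size, kv_size, kv_size] * tp_size
chunks = torch.split(gathered, qkv_map, dim=-1)
# Rank-interleaved (q0, k0, v0, q1, k1, v1) -> grouped (q0, q1, k0, k1, v0, v1).
regrouped = torch.cat(chunks[::3] + chunks[1::3] + chunks[2::3], dim=-1)
print(regrouped[0].tolist())
# -> eight 0.0s, eight 10.0s, four 1.0s, four 11.0s, four 2.0s, four 12.0s

# The final step mirrors split_tensor_along_last_dim: each rank keeps only
# its own contiguous slice of the reassembled q (and likewise k and v).
full_q = regrouped[:, :q_size * tp_size]
rank = 1
q_local = torch.chunk(full_q, tp_size, dim=-1)[rank]
print(q_local.shape)  # torch.Size([3, 8]), all values 10.0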
