 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -70,20 +74,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -122,11 +127,27 @@ def __init__(
                               quant_config=quant_config)

     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v

     def forward(
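
Note on the reshuffle in the new split_qkv: each rank's fused QKV projection output is laid out as [Q_shard | K_shard | V_shard] (this is what the qkv_map implies), the all-gather concatenates those per-rank blocks along the last dimension, and the strided split/reorder regroups them into [full Q | full K | full V] before the head-wise view. Below is a minimal single-process sketch of just that gather-and-reorder step; the collective is simulated with torch.cat, and the sizes (tp_size=2, q_size=4, kv_size=2, seq_len=3) are made up purely for illustration.

# Single-process sketch of the all-gather + reorder in split_qkv (illustrative sizes).
import torch

tp_size, seq_len = 2, 3
q_size, kv_size = 4, 2  # per-rank widths of the fused QKV projection output

# Each rank's QKV projection output is laid out as [Q_shard | K_shard | V_shard];
# rank r is filled with 10*r, 10*r+1, 10*r+2 so the blocks stay recognizable.
shards = [torch.cat([torch.full((seq_len, q_size), 10. * r),
                     torch.full((seq_len, kv_size), 10. * r + 1),
                     torch.full((seq_len, kv_size), 10. * r + 2)], dim=-1)
          for r in range(tp_size)]

# tensor_model_parallel_all_gather concatenates the rank outputs along the last
# dim, giving [Q0 | K0 | V0 | Q1 | K1 | V1]; here it is simulated with torch.cat.
gathered = torch.cat(shards, dim=-1)

# The split plus strided reindexing regroups the chunks so all Q shards come
# first, then all K shards, then all V shards: [Q0 | Q1 | K0 | K1 | V0 | V1].
qkv_map = [q_size, kv_size, kv_size] * tp_size
chunks = torch.split(gathered, qkv_map, dim=-1)
chunks = chunks[::3] + chunks[1::3] + chunks[2::3]
reordered = torch.cat(chunks, dim=-1)

assert reordered.shape == (seq_len, (q_size + 2 * kv_size) * tp_size)
print(reordered[0])  # Q blocks (0s, 10s), then K blocks (1s, 11s), then V blocks (2s, 12s)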