
Commit 67005a0

jeejeelee and Yard1 authored
[Bugfix] Add fully sharded layer for QKVParallelLinearWithLora (vllm-project#5665)
Co-authored-by: Antoni Baum <[email protected]>
1 parent c35e4a3 commit 67005a0

File tree

tests/lora/test_baichuan.py
tests/lora/test_layers.py
vllm/lora/fully_sharded_layers.py
vllm/lora/layers.py
vllm/lora/utils.py

5 files changed: +93 −26 lines changed


tests/lora/test_baichuan.py

Lines changed: 9 additions & 5 deletions
@@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):
 
 
 @pytest.mark.skip("Requires multiple GPUs")
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
+@pytest.mark.parametrize("fully_sharded", [True, False])
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
     # Cannot use as it will initialize torch.cuda too early...
     # if torch.cuda.device_count() < 4:
     #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=1,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
 
     del llm_tp1
@@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=2,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
 
     del llm_tp2
@@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
                        max_loras=4,
                        max_lora_rank=64,
                        tensor_parallel_size=4,
-                       trust_remote_code=True)
+                       trust_remote_code=True,
+                       fully_sharded_loras=fully_sharded)
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
 
     del llm_tp4
     cleanup()
 
-    assert output_tp1 == output_tp4
+    assert output_tp1 == output_tp4
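
For context, the configuration this parametrized test exercises can be reproduced outside pytest roughly as below. This is a minimal sketch, not part of the commit; the model name, adapter path, and enable_lora=True are assumptions taken from the surrounding test setup rather than shown in the diff.

# Hypothetical usage sketch: fully_sharded_loras=True selects the fully sharded
# LoRA layer variants, including the new QKVParallelLinearWithShardedLora.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "baichuan-inc/Baichuan-7B",   # assumed model; the test loads baichuan_lora_files
    enable_lora=True,             # assumed from the test's existing setup
    max_loras=4,
    max_lora_rank=64,
    tensor_parallel_size=2,
    trust_remote_code=True,
    fully_sharded_loras=True,
)
outputs = llm.generate(
    ["Hello, my name is"],
    lora_request=LoRARequest("baichuan-lora", 1, "/path/to/baichuan_lora"),
)
print(outputs[0].outputs[0].text)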

tests/lora/test_layers.py

Lines changed: 5 additions & 2 deletions
@@ -12,7 +12,8 @@
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -684,7 +685,9 @@ def create_column_parallel_packed_layer():
                                   bias=False,
                                   params_dtype=torch.float16)
         linear.weight.data = torch.rand_like(linear.weight.data)
-        lora_linear = QKVParallelLinearWithLora(linear)
+        lora_linear = QKVParallelLinearWithLora(
+            linear
+        ) if not fully_shard else QKVParallelLinearWithShardedLora(linear)
 
         @dataclass
         class FakeConfig:

vllm/lora/fully_sharded_layers.py

Lines changed: 55 additions & 3 deletions
@@ -12,6 +12,7 @@
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
+                              QKVParallelLinearWithLora,
                               RowParallelLinearWithLoRA)
 from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
 
@@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module,
 def _mcp_apply(x, bias, layer):
     """
     MergedColumnParallelLinearWithShardedLoRA and
-    QKVParallelLinearWithShardedLora share the same
+    MergedQKVParallelLinearWithShardedLora share the same
     LoRa weight application method.
 
     The main difference is the step by shard_size for lora_b which can
-    vary for QKVParallelLinearWithShardedLora but is constant for
+    vary for MergedQKVParallelLinearWithShardedLora but is constant for
     MergedColumnParallelLinearWithShardedLoRA.
     """
     # expecting 2 for column parallel and 3 for qkv
@@ -167,14 +168,65 @@ def can_replace_layer(cls, source_layer: nn.Module,
         )
 
 
-class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
     """
     Differs from QKVParallelLinearWithLora by slicing the
     LoRA A's also.
 
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
+    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.lora_a_stacked.shape[2]
+        start_idx = tp_rank * shard_size
+        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        return lora_a
+
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        x = x.view(-1, x.shape[-1])
+        output, out_orig_shape = output.view(-1,
+                                             output.shape[-1]), output.shape
+        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
+                             dtype=torch.float32,
+                             device=x.device)
+
+        bgmv(buffer, x, self.lora_a_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        buffer = tensor_model_parallel_all_gather(buffer)
+        bgmv(output, buffer, self.lora_b_stacked,
+             self.indices[:self.indices_len[0]], 0, 1.0)
+        # now have column partitioned output
+
+        output = output.view(*out_orig_shape)
+        return output
+
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        # specifying kwargs so they can be easily accessed in decorator
+        return super().can_replace_layer(
+            source_layer=source_layer,
+            lora_config=lora_config,
+            packed_modules_list=packed_modules_list,
+            model_config=model_config,
+            decorate=False,
+        )
+
+
+class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+    """
+    Differs from MergedQKVParallelLinearWithLora by slicing the
+    LoRA A's also.
+
+    Based on S-LoRA, slicing happens along the rank dim.
+    """
+
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
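
The new QKVParallelLinearWithShardedLora.apply follows the same S-LoRA pattern as the other fully sharded layers: each tensor-parallel rank holds a slice of LoRA A along the rank dimension and a column slice of LoRA B, computes a small partial buffer, all-gathers it, and then produces its own column shard of the output. Below is a minimal single-process sketch of that equivalence, with assumed toy dimensions; torch.chunk stands in for the per-rank shards and torch.cat for bgmv plus tensor_model_parallel_all_gather, so this is an illustration of the math, not the library's kernels.

# Single-process sketch of rank-dimension sharding (assumed toy dims).
import torch

tokens, hidden, rank, out_dim, tp = 4, 16, 8, 32, 2
x = torch.randn(tokens, hidden)
lora_a = torch.randn(hidden, rank)       # full LoRA A
lora_b = torch.randn(rank, out_dim)      # full LoRA B
reference = x @ lora_a @ lora_b          # unsharded LoRA delta

a_shards = lora_a.chunk(tp, dim=1)       # slice_lora_a: shard A along the rank dim
b_shards = lora_b.chunk(tp, dim=1)       # B is column-sharded (full rank per rank)

# Each rank computes x @ A_shard into a small (tokens, rank // tp) buffer ...
buffers = [x @ a for a in a_shards]
# ... the all-gather concatenates the buffers, recovering the full x @ A ...
gathered = torch.cat(buffers, dim=1)
# ... and each rank applies its column shard of B to get its output partition.
out_shards = [gathered @ b for b in b_shards]

assert torch.allclose(torch.cat(out_shards, dim=1), reference, atol=1e-5)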

vllm/lora/layers.py

Lines changed: 21 additions & 15 deletions
@@ -641,6 +641,24 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                    self.base_layer.head_size)
 
+    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+        tp_rank = get_tensor_model_parallel_rank()
+        self.q_shard_id = tp_rank
+        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        lora_b_q = lora_b[:, self.q_proj_shard_size *
+                          self.q_shard_id:self.q_proj_shard_size *
+                          (self.q_shard_id + 1)]
+        k_offset = self.q_proj_total_size
+        lora_b_k = lora_b[:, k_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        v_offset = k_offset + self.kv_proj_total_size
+        lora_b_v = lora_b[:, v_offset +
+                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+        return lora_b
+
     def set_lora(
         self,
         index: int,
@@ -650,21 +668,8 @@ def set_lora(
     ):
         self.reset_lora(index)
         if self.tp_size > 1:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.q_shard_id = tp_rank
-            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-            lora_b_q = lora_b[:, self.q_proj_shard_size *
-                              self.q_shard_id:self.q_proj_shard_size *
-                              (self.q_shard_id + 1)]
-            k_offset = self.q_proj_total_size
-            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:k_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            v_offset = k_offset + self.kv_proj_total_size
-            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
-                              self.kv_shard_id:v_offset +
-                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -674,6 +679,7 @@ def set_lora(
                 lora_b.T, non_blocking=True)
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
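
The refactored slice_lora_b keeps the original q/k/v slicing arithmetic unchanged; it is only moved out of set_lora so the fully sharded subclass can reuse set_lora and override just the slicing. A small worked example of the column offsets it computes is sketched below; the dimensions are assumed illustrative values (32 heads of size 128, MHA, tp_size=2), not sizes fixed by the diff.

# Worked example of the q/k/v column offsets in slice_lora_b (assumed dims).
head_size, total_num_heads, total_num_kv_heads, tp_size = 128, 32, 32, 2
num_kv_head_replicas = 1   # no KV replication when kv heads >= tp_size
tp_rank = 1

q_proj_shard_size = total_num_heads // tp_size * head_size      # 2048
kv_proj_shard_size = total_num_kv_heads // tp_size * head_size  # 2048
q_proj_total_size = total_num_heads * head_size                 # 4096
kv_proj_total_size = total_num_kv_heads * head_size             # 4096

q_shard_id = tp_rank
kv_shard_id = tp_rank // num_kv_head_replicas

# Column ranges taken from the packed [q | k | v] LoRA B matrix on this rank:
q_cols = (q_proj_shard_size * q_shard_id,
          q_proj_shard_size * (q_shard_id + 1))                  # (2048, 4096)
k_offset = q_proj_total_size
k_cols = (k_offset + kv_proj_shard_size * kv_shard_id,
          k_offset + kv_proj_shard_size * (kv_shard_id + 1))     # (6144, 8192)
v_offset = k_offset + kv_proj_total_size
v_cols = (v_offset + kv_proj_shard_size * kv_shard_id,
          v_offset + kv_proj_shard_size * (kv_shard_id + 1))     # (10240, 12288)
print(q_cols, k_cols, v_cols)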

vllm/lora/utils.py

Lines changed: 3 additions & 1 deletion
@@ -8,7 +8,8 @@
 from vllm.lora.fully_sharded_layers import (
     ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA)
 # being imported for _all_lora_classes below
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +36,7 @@
     RowParallelLinearWithLoRA,
     LogitsProcessorWithLoRA,
     ColumnParallelLinearWithShardedLoRA,
+    QKVParallelLinearWithShardedLora,
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLora,
     RowParallelLinearWithShardedLoRA,
