
Commit 1dbb888

[Bugfix] LoRA logits einsum dimension mismatch in add_lora_logits (vllm-project#1583)
### What this PR does / why we need it?

This PR fixes a tensor shape mismatch in `add_lora_logits`. Previously, `lora_a_stacked` was passed with shape `[num_loras, in_dim, rank]`, which does not match the einsum pattern `"bi, boi -> bo"` expected by `bgmv_shrink`. This causes runtime errors like:

`RuntimeError: einsum(): subscript i has size 3 for operand 1 which does not broadcast with previously seen size 4`

![image](https://github.com/user-attachments/assets/63029479-49ae-4c3c-b995-f6805d15ad06)

This fix transposes `lora_a_stacked` and `lora_b_stacked` to match the expected shapes:

- `lora_a`: `[num_loras, rank, in_dim]`
- `lora_b`: `[num_loras, out_dim, rank]`

All unit tests pass after this fix.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

```python
import torch
import pytest
from unittest.mock import patch, PropertyMock

from vllm_ascend.lora.punica_wrapper.punica_npu import PunicaWrapperNPU


@pytest.fixture
def wrapper_cpu():
    cfg = {"max_num_batched_tokens": 10, "max_batches": 2, "device": "cpu"}
    w = PunicaWrapperNPU(**cfg)
    w.is_prefill = True
    w.no_lora = False
    return w


def test_add_lora_logits(wrapper_cpu):
    batch_size = 2
    hidden_size = 4
    lora_rank = 3
    vocab_size = 5

    y = torch.zeros(batch_size, vocab_size)
    x = torch.randn(batch_size, hidden_size)

    num_loras = 1
    lora_a = torch.randn(num_loras, hidden_size, lora_rank)
    lora_b = torch.randn(num_loras, lora_rank, vocab_size)

    with patch.object(wrapper_cpu.__class__,
                      "sampler_indices",
                      new_callable=PropertyMock) as mock_idx:
        mock_idx.return_value = torch.zeros(batch_size, dtype=torch.long)
        wrapper_cpu.add_lora_logits(y, x, lora_a, lora_b, scale=1.0)

    assert y.shape == (batch_size, vocab_size)
    assert not torch.allclose(y, torch.zeros_like(y))
```

Signed-off-by: hongfugui <[email protected]>
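The shape mismatch can be reproduced with plain `torch.einsum` and the dimensions from the test above; this is a minimal sketch only (the pattern `"bi, boi -> bo"` is the one quoted from `bgmv_shrink` in the description, the kernel itself is not reproduced here):

```python
# Minimal sketch of the mismatch, assuming batch=2, hidden=4, rank=3, num_loras=1
# as in the test above. torch.einsum stands in for the kernel's einsum.
import torch

batch, in_dim, rank, num_loras = 2, 4, 3, 1
x = torch.randn(batch, in_dim)                    # "bi": i = in_dim = 4
indices = torch.zeros(batch, dtype=torch.long)    # every token uses LoRA 0

# Layout passed in before the fix: [num_loras, in_dim, rank]
lora_a_bad = torch.randn(num_loras, in_dim, rank)
# torch.einsum("bi,boi->bo", x, lora_a_bad[indices])
# -> RuntimeError: subscript i has size 3 for operand 1 ... seen size 4

# Layout the pattern expects: [num_loras, rank, in_dim]
lora_a_good = lora_a_bad.transpose(1, 2)
buffer = torch.einsum("bi,boi->bo", x, lora_a_good[indices])
print(buffer.shape)  # torch.Size([2, 3]) == [batch, rank]
```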
1 parent d80b0cc commit 1dbb888

1 file changed: +20 −11 lines changed

vllm_ascend/lora/punica_wrapper/punica_npu.py

Lines changed: 20 additions & 11 deletions
```diff
@@ -322,7 +322,7 @@ def add_lora_logits(self,
                         **kwargs) -> None:
         """
         Applies lora specifically for LogitsProcessorWithLoRA.
-
+
         Semantics:
         buffer = (x @ lora_a_stacked) * scale
         y += buffer @ lora_b_stacked
@@ -338,18 +338,27 @@ def add_lora_logits(self,
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        r = lora_b_stacked.size(-1)
+
+        if lora_a_stacked.dim() == 2:
+            lora_a_stacked = lora_a_stacked.unsqueeze(0)
+        if lora_b_stacked.dim() == 2:
+            lora_b_stacked = lora_b_stacked.unsqueeze(0)
+
+        r = lora_a_stacked.size(-1)
+
         if buffer is None:
-            # We set the buffer to be float32 by default, consistent with the
-            # triton op
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
-        # LogitsProcessorWithLoRA always using bgmv.
-        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
-        bgmv_expand(buffer,
-                    lora_b_stacked,
-                    y,
-                    self.sampler_indices,
-                    add_inputs=True)
+
+        indices = self.sampler_indices
+        if indices.max() >= lora_a_stacked.size(0):
+            indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
+
+        lora_a_reshaped = lora_a_stacked.transpose(1, 2)
+        lora_b_reshaped = lora_b_stacked.transpose(1, 2)
+
+        bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
+        bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
+
         y = y.view_as(y_org)
```
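For readers without NPU access, a rough pure-PyTorch sketch of the post-fix data flow is shown below; `torch.einsum` stands in for the `bgmv_shrink`/`bgmv_expand` kernels, and the helper name `add_lora_logits_reference` is hypothetical. It mirrors only the shapes and semantics of the patched method, not the kernels' actual signatures:

```python
# Hypothetical reference for the patched add_lora_logits data flow.
# Assumes the incoming layouts from the test: a = [num_loras, in_dim, rank],
# b = [num_loras, rank, out_dim]; torch.einsum stands in for the NPU bgmv kernels.
import torch

def add_lora_logits_reference(y, x, lora_a_stacked, lora_b_stacked,
                              sampler_indices, scale):
    # Promote 2-D weights to a stacked 3-D layout, as the patch does.
    if lora_a_stacked.dim() == 2:
        lora_a_stacked = lora_a_stacked.unsqueeze(0)
    if lora_b_stacked.dim() == 2:
        lora_b_stacked = lora_b_stacked.unsqueeze(0)

    # Clamp out-of-range indices to valid LoRA slots, as the patch does.
    indices = torch.clamp(sampler_indices, 0, lora_a_stacked.size(0) - 1)

    lora_a = lora_a_stacked.transpose(1, 2)   # [num_loras, rank, in_dim]
    lora_b = lora_b_stacked.transpose(1, 2)   # [num_loras, out_dim, rank]

    buffer = torch.einsum("bi,boi->bo", x, lora_a[indices]) * scale   # shrink: [batch, rank]
    y += torch.einsum("bi,boi->bo", buffer, lora_b[indices])          # expand: [batch, out_dim]
    return y
```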
