Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 1f56742

Browse files
authored
[Kernel] Add punica dimension for Qwen2 LoRA (vllm-project#5441)
1 parent b12518d commit 1f56742

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

csrc/punica/bgmv/bgmv_config.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,33 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
1616
f(in_T, out_T, W_T, narrow, 512) \
1717
f(in_T, out_T, W_T, narrow, 640) \
1818
f(in_T, out_T, W_T, narrow, 768) \
19+
f(in_T, out_T, W_T, narrow, 896) \
1920
f(in_T, out_T, W_T, narrow, 1024) \
2021
f(in_T, out_T, W_T, narrow, 1152) \
22+
f(in_T, out_T, W_T, narrow, 1216) \
2123
f(in_T, out_T, W_T, narrow, 1280) \
2224
f(in_T, out_T, W_T, narrow, 1536) \
2325
f(in_T, out_T, W_T, narrow, 1664) \
2426
f(in_T, out_T, W_T, narrow, 1728) \
2527
f(in_T, out_T, W_T, narrow, 1792) \
2628
f(in_T, out_T, W_T, narrow, 2048) \
29+
f(in_T, out_T, W_T, narrow, 2240) \
2730
f(in_T, out_T, W_T, narrow, 2304) \
31+
f(in_T, out_T, W_T, narrow, 2368) \
32+
f(in_T, out_T, W_T, narrow, 2432) \
2833
f(in_T, out_T, W_T, narrow, 2560) \
2934
f(in_T, out_T, W_T, narrow, 2752) \
3035
f(in_T, out_T, W_T, narrow, 2816) \
3136
f(in_T, out_T, W_T, narrow, 3072) \
3237
f(in_T, out_T, W_T, narrow, 3328) \
3338
f(in_T, out_T, W_T, narrow, 3456) \
3439
f(in_T, out_T, W_T, narrow, 3584) \
40+
f(in_T, out_T, W_T, narrow, 3712) \
3541
f(in_T, out_T, W_T, narrow, 4096) \
42+
f(in_T, out_T, W_T, narrow, 4480) \
3643
f(in_T, out_T, W_T, narrow, 4608) \
44+
f(in_T, out_T, W_T, narrow, 4736) \
45+
f(in_T, out_T, W_T, narrow, 4864) \
3746
f(in_T, out_T, W_T, narrow, 5120) \
3847
f(in_T, out_T, W_T, narrow, 5504) \
3948
f(in_T, out_T, W_T, narrow, 5632) \
@@ -43,24 +52,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
4352
f(in_T, out_T, W_T, narrow, 6848) \
4453
f(in_T, out_T, W_T, narrow, 6912) \
4554
f(in_T, out_T, W_T, narrow, 7168) \
55+
f(in_T, out_T, W_T, narrow, 7424) \
4656
f(in_T, out_T, W_T, narrow, 8192) \
57+
f(in_T, out_T, W_T, narrow, 8960) \
4758
f(in_T, out_T, W_T, narrow, 9216) \
59+
f(in_T, out_T, W_T, narrow, 9472) \
4860
f(in_T, out_T, W_T, narrow, 10240) \
4961
f(in_T, out_T, W_T, narrow, 11008) \
5062
f(in_T, out_T, W_T, narrow, 11264) \
5163
f(in_T, out_T, W_T, narrow, 12288) \
5264
f(in_T, out_T, W_T, narrow, 13696) \
5365
f(in_T, out_T, W_T, narrow, 13824) \
5466
f(in_T, out_T, W_T, narrow, 14336) \
67+
f(in_T, out_T, W_T, narrow, 14784) \
68+
f(in_T, out_T, W_T, narrow, 14848) \
5569
f(in_T, out_T, W_T, narrow, 15360) \
5670
f(in_T, out_T, W_T, narrow, 16384) \
71+
f(in_T, out_T, W_T, narrow, 18944) \
5772
f(in_T, out_T, W_T, narrow, 20480) \
5873
f(in_T, out_T, W_T, narrow, 22016) \
5974
f(in_T, out_T, W_T, narrow, 22528) \
6075
f(in_T, out_T, W_T, narrow, 24576) \
6176
f(in_T, out_T, W_T, narrow, 27392) \
6277
f(in_T, out_T, W_T, narrow, 27648) \
6378
f(in_T, out_T, W_T, narrow, 28672) \
79+
f(in_T, out_T, W_T, narrow, 29568) \
80+
f(in_T, out_T, W_T, narrow, 29696) \
6481
f(in_T, out_T, W_T, narrow, 32000) \
6582
f(in_T, out_T, W_T, narrow, 32256) \
6683
f(in_T, out_T, W_T, narrow, 32512) \
@@ -85,34 +102,43 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
85102
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
86103
// and vllm/tests/lora/test_punica.py
87104

88-
// Used for defining kernels going from the variety of
105+
// Used for defining kernels going from the variety of
89106
// dim in to the narrow dim out
90-
// Using it for the fully sharded column
107+
// Using it for the fully sharded column
91108
// parallel LoRA A which splits the rank dim
92109
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
93110
f(in_T, out_T, W_T, 128, narrow) \
94111
f(in_T, out_T, W_T, 256, narrow) \
95112
f(in_T, out_T, W_T, 512, narrow) \
96113
f(in_T, out_T, W_T, 640, narrow) \
97114
f(in_T, out_T, W_T, 768, narrow) \
115+
f(in_T, out_T, W_T, 896, narrow) \
98116
f(in_T, out_T, W_T, 1024, narrow) \
99117
f(in_T, out_T, W_T, 1152, narrow) \
118+
f(in_T, out_T, W_T, 1216, narrow) \
100119
f(in_T, out_T, W_T, 1280, narrow) \
101120
f(in_T, out_T, W_T, 1536, narrow) \
102121
f(in_T, out_T, W_T, 1664, narrow) \
103122
f(in_T, out_T, W_T, 1728, narrow) \
104123
f(in_T, out_T, W_T, 1792, narrow) \
105124
f(in_T, out_T, W_T, 2048, narrow) \
125+
f(in_T, out_T, W_T, 2240, narrow) \
106126
f(in_T, out_T, W_T, 2304, narrow) \
127+
f(in_T, out_T, W_T, 2368, narrow) \
128+
f(in_T, out_T, W_T, 2432, narrow) \
107129
f(in_T, out_T, W_T, 2560, narrow) \
108130
f(in_T, out_T, W_T, 2752, narrow) \
109131
f(in_T, out_T, W_T, 2816, narrow) \
110132
f(in_T, out_T, W_T, 3072, narrow) \
111133
f(in_T, out_T, W_T, 3328, narrow) \
112134
f(in_T, out_T, W_T, 3456, narrow) \
113135
f(in_T, out_T, W_T, 3584, narrow) \
136+
f(in_T, out_T, W_T, 3712, narrow) \
114137
f(in_T, out_T, W_T, 4096, narrow) \
138+
f(in_T, out_T, W_T, 4480, narrow) \
115139
f(in_T, out_T, W_T, 4608, narrow) \
140+
f(in_T, out_T, W_T, 4736, narrow) \
141+
f(in_T, out_T, W_T, 4864, narrow) \
116142
f(in_T, out_T, W_T, 5120, narrow) \
117143
f(in_T, out_T, W_T, 5504, narrow) \
118144
f(in_T, out_T, W_T, 5632, narrow) \
@@ -122,24 +148,32 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
122148
f(in_T, out_T, W_T, 6848, narrow) \
123149
f(in_T, out_T, W_T, 6912, narrow) \
124150
f(in_T, out_T, W_T, 7168, narrow) \
151+
f(in_T, out_T, W_T, 7424, narrow) \
125152
f(in_T, out_T, W_T, 8192, narrow) \
153+
f(in_T, out_T, W_T, 8960, narrow) \
126154
f(in_T, out_T, W_T, 9216, narrow) \
155+
f(in_T, out_T, W_T, 9472, narrow) \
127156
f(in_T, out_T, W_T, 10240, narrow) \
128157
f(in_T, out_T, W_T, 11008, narrow) \
129158
f(in_T, out_T, W_T, 11264, narrow) \
130159
f(in_T, out_T, W_T, 12288, narrow) \
131160
f(in_T, out_T, W_T, 13696, narrow) \
132161
f(in_T, out_T, W_T, 13824, narrow) \
133162
f(in_T, out_T, W_T, 14336, narrow) \
163+
f(in_T, out_T, W_T, 14784, narrow) \
164+
f(in_T, out_T, W_T, 14848, narrow) \
134165
f(in_T, out_T, W_T, 15360, narrow) \
135166
f(in_T, out_T, W_T, 16384, narrow) \
167+
f(in_T, out_T, W_T, 18944, narrow) \
136168
f(in_T, out_T, W_T, 20480, narrow) \
137169
f(in_T, out_T, W_T, 22016, narrow) \
138170
f(in_T, out_T, W_T, 22528, narrow) \
139171
f(in_T, out_T, W_T, 24576, narrow) \
140172
f(in_T, out_T, W_T, 27392, narrow) \
141173
f(in_T, out_T, W_T, 27648, narrow) \
142174
f(in_T, out_T, W_T, 28672, narrow) \
175+
f(in_T, out_T, W_T, 29568, narrow) \
176+
f(in_T, out_T, W_T, 29696, narrow) \
143177
f(in_T, out_T, W_T, 32000, narrow) \
144178
f(in_T, out_T, W_T, 32256, narrow) \
145179
f(in_T, out_T, W_T, 32512, narrow) \

tests/lora/test_punica.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,30 @@ def _lora_ref_impl(
4949
128,
5050
256,
5151
512,
52+
896,
5253
1024,
5354
1152,
55+
1216,
5456
1280,
5557
1536,
5658
1664,
5759
2048,
60+
2240,
5861
2304,
62+
2368,
63+
2432,
5964
2560,
6065
2752,
6166
3072,
6267
3328,
6368
3456,
6469
3584,
70+
3712,
6571
4096,
72+
4480,
6673
4608,
74+
4736,
75+
4864,
6776
5120,
6877
5504,
6978
5632,
@@ -73,19 +82,27 @@ def _lora_ref_impl(
7382
6848,
7483
6912,
7584
7168,
85+
7424,
7686
8192,
87+
8960,
7788
9216,
89+
9472,
7890
10240,
7991
11008,
8092
11264,
8193
13824,
8294
14336,
95+
14784,
96+
14848,
8397
15360,
98+
18944,
8499
22016,
85100
22528,
86101
24576,
87102
27392,
88103
27648,
104+
29568,
105+
29696,
89106
32000,
90107
32256,
91108
32512,

0 commit comments

Comments
 (0)