
Commit 540336e

Add Custom Kernels For LoRA Performance (vllm-project#1884)
### What this PR does / why we need it?
Add two custom kernels (bgmv_shrink and bgmv_expand) to improve LoRA performance.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
We added unit tests for the custom AscendC kernels; see vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_shrink.py and vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_expand.py. In tests with the Qwen2.5 7B model on vllm-ascend v0.9.2.rc1, TTFT, TPOT, and throughput improved by about 70%.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@40d86ee

---------

Signed-off-by: taoxudonghaha <[email protected]>
1 parent 2da281e commit 540336e

File tree

8 files changed (+946, -3 lines changed)

csrc/kernels/bgmv_expand.cpp

Lines changed: 369 additions & 0 deletions
Large diffs are not rendered by default.

csrc/kernels/bgmv_shrink.cpp

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel_operator.h"
#include "types.h"

template <typename scalar_t>
class BGMVShrink {
public:
    using X_T = scalar_t;
    using W_T = scalar_t;
    using Y_T = float;

    static constexpr uint64_t BUFFER_NUM = 1;
    static constexpr uint64_t TILE_LENGTH = 11776; // optimal performance tile length

public:
    __aicore__ inline BGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
    __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, __gm__ void *y,
                                uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
                                uint32_t maxLoRARank, float scale)
    {
        batchSize_ = batchSize;
        numTokensPerCore_ = numTokensPerCore;
        inputHiddenDim_ = inputHiddenDim;
        maxLoRARank_ = maxLoRARank;
        scale_ = scale;
        singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
        incremental_ = inputHiddenDim_ > TILE_LENGTH;

        xGm_.SetGlobalBuffer((__gm__ X_T *)x);
        yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
        wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
        indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices);

        pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
        pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
        pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
        pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));

        pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
        pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
    }

    __aicore__ inline void Process()
    {
        int64_t blockIdx = AscendC::GetBlockIdx();
        int64_t startIdx = blockIdx * numTokensPerCore_;
        int64_t endIdx = startIdx + numTokensPerCore_;
        if (endIdx > batchSize_) {
            endIdx = batchSize_;
        }
        for (int64_t idx = startIdx; idx < endIdx; idx++) {
            // set up LoRA index
            CopyInIndex(idx);
            if (reqLoRAIndex_ < 0) {
                continue;
            }
            reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;

            if (incremental_) {
                ProcessImpl<true>(idx);
            } else {
                ProcessImpl<false>(idx);
            }

            ScaleOutput();
            CopyOut(idx);
        }
    }

private:
    template <bool INCREMENTAL_MODE>
    __aicore__ inline void ProcessImpl(const int64_t idx)
    {
        AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
        if constexpr (!INCREMENTAL_MODE) {
            CopyInX(idx, 0, inputHiddenDim_);
            AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
            AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
            Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
            pipe_barrier(PIPE_V);
            inQueueX_.FreeTensor(xLocal);
        }

        for (int i = 0; i < maxLoRARank_; i++) {
            float acc(0);
            for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
                if constexpr (INCREMENTAL_MODE) {
                    CopyInX(idx, j);
                }
                CopyInW(i, j);
                Compute<INCREMENTAL_MODE>(acc);
            }
            CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
            yOutLocal.SetValue(i, acc);
        }
    }

    __aicore__ inline void CopyInIndex(const int64_t idx)
    {
        // look up the LoRA index
        reqLoRAIndex_ = indicesGm_.GetValue(idx);
    }

    __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
    {
        AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
        DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
        inQueueX_.EnQue(xLocal);
    }

    __aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
    {
        AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
        DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
        inQueueW_.EnQue(wLocal);
    }

    template <bool INCREMENTAL_MODE>
    __aicore__ inline void Compute(float &acc, int32_t numElements = TILE_LENGTH)
    {
        AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
        AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
        AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();

        if constexpr (INCREMENTAL_MODE) {
            AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
            Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
            Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
            pipe_barrier(PIPE_V);
            inQueueX_.FreeTensor(xLocal);
            inQueueW_.FreeTensor(wLocal);
        } else {
            Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
            pipe_barrier(PIPE_V);
            inQueueW_.FreeTensor(wLocal);
        }
        // element-wise product of one tile of X and W
        Mul(wTmpTensor, xTmpTensor, wTmpTensor, numElements);
        pipe_barrier(PIPE_V);
        // reduce-sum produces one number: the partial dot product for this tile
        ReduceSum<float>(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
        pipe_barrier(PIPE_V);

        acc += wTmpTensor.GetValue(0);
    }

    template <bool INCREMENTAL_MODE>
    __aicore__ inline void CopyAndComputeLastIteration(const int64_t idx, int32_t rowIdx, float &acc)
    {
        int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
        int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
        if (remaining == 0) {
            return;
        }
        if constexpr (INCREMENTAL_MODE) {
            CopyInX(idx, colIdx, remaining);
        }
        CopyInW(rowIdx, colIdx, remaining);
        Compute<INCREMENTAL_MODE>(acc, remaining);
    }

    __aicore__ inline void ScaleOutput()
    {
        AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
        AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();

        Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
        pipe_barrier(PIPE_V);

        outQueueY_.EnQue<Y_T>(yOutLocal);
    }

    __aicore__ inline void CopyOut(const int64_t idx)
    {
        AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
        DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
        outQueueY_.FreeTensor(yOutLocal);
    }

private:
    AscendC::TPipe *pipe_;
    AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
    AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
    AscendC::GlobalTensor<X_T> xGm_;
    AscendC::GlobalTensor<W_T> wGm_;
    AscendC::GlobalTensor<int64_t> indicesGm_;
    AscendC::GlobalTensor<Y_T> yOutGm_;
    uint32_t batchSize_;
    uint32_t numTokensPerCore_;
    uint32_t inputHiddenDim_;
    uint32_t maxLoRARank_;
    float scale_;
    uint32_t singleLoRAWeightLen_;
    int64_t reqLoRAIndex_;
    uint64_t reqLoRAWeightOffset_;
    bool incremental_;
};

#define BGMV_SHRINK_TYPE_DECLARE(TYPE) \
extern "C" __global__ __aicore__ void bgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, __gm__ void* indices, \
                                                         __gm__ void* y, uint32_t batchSize, \
                                                         uint32_t numTokensPerCore, uint32_t inputHiddenDim, \
                                                         uint32_t maxLoRARank, float scale) \
{ \
    AscendC::TPipe pipe; \
    BGMVShrink<TYPE> op(&pipe); \
    op.Init(x, weight, indices, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \
    op.Process(); \
}

// declare kernels for all supported dtypes
BGMV_SHRINK_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
#endif

namespace vllm_ascend {
extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices,
                             void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
                             uint32_t maxLoRARank, float scale)
{
    uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
    if (type == AscendType::FP16) {
        bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
                                                        inputHiddenDim, maxLoRARank, scale);
    } else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
        bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
                                                              inputHiddenDim, maxLoRARank, scale);
#endif
    } else {
        return;
    }
}

} // namespace vllm_ascend
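For readers checking the kernel against a host-side reference: a minimal NumPy sketch of what BGMVShrink computes per token, as read from the code above (per-rank dot products accumulated in fp32, then scaled; tokens with a negative LoRA index are skipped). The function name and NumPy framing are illustrative and not part of this PR; the real kernel simply leaves skipped rows unwritten rather than zeroing them.

import numpy as np

def bgmv_shrink_reference(x, weight, indices, scale):
    """Host-side sketch of the BGMVShrink kernel above (illustrative only).

    x:       [batch_size, input_hidden_dim] activations (fp16/bf16)
    weight:  [num_loras, max_lora_rank, input_hidden_dim] LoRA A matrices
    indices: [batch_size] int64 LoRA index per token; < 0 means "no LoRA"
    Returns y: [batch_size, max_lora_rank] in fp32, matching Y_T = float.
    """
    batch_size = x.shape[0]
    max_lora_rank = weight.shape[1]
    y = np.zeros((batch_size, max_lora_rank), dtype=np.float32)
    for idx in range(batch_size):
        lora = int(indices[idx])
        if lora < 0:  # the kernel skips tokens without a LoRA adapter
            continue
        # per-rank dot products in fp32, then scaling (ScaleOutput)
        y[idx] = scale * (weight[lora].astype(np.float32) @ x[idx].astype(np.float32))
    return y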

csrc/ops.h

Lines changed: 28 additions & 0 deletions
@@ -60,4 +60,32 @@ namespace vllm_ascend {
    auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
    return new_tensor;
}

extern void bgmv_shrink_impl(
    AscendType type,
    void *stream,
    void *x,
    void *weight,
    void *indices,
    void *y,
    uint32_t batch_size,
    uint32_t num_tokens_per_core,
    uint32_t input_hidden_dim,
    uint32_t lora_rank,
    float scale);

extern void bgmv_expand_impl(
    AscendType type,
    void *stream,
    void *x,
    void *weight,
    void *indices,
    void *y,
    void *y_out,
    uint32_t batch_size,
    uint32_t num_tokens_per_core,
    uint32_t lora_rank,
    uint32_t output_hidden_dim,
    uint32_t slice_offset,
    uint32_t output_full_dim);
}
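The expand kernel itself lives in csrc/kernels/bgmv_expand.cpp, whose diff is not rendered above, so the following is only a plausible host-side sketch of what bgmv_expand_impl is declared to do, inferred from its parameters and the shape checks in torch_binding.cpp. The accumulate-into-slice behaviour and the weight layout are assumptions based on vLLM's usual bgmv_expand convention.

import numpy as np

def bgmv_expand_reference(x, weight, indices, y, slice_offset, slice_size):
    """Illustrative sketch only; the actual kernel body is not shown in this diff.

    x:       [batch_size, lora_rank] shrink output
    weight:  [num_loras, slice_size, lora_rank] LoRA B matrices (assumed layout)
    indices: [batch_size] int64 LoRA index per token; < 0 means "no LoRA"
    y:       [batch_size, output_full_dim] base output (half/bf16)
    """
    y_out = y.copy()
    for idx in range(x.shape[0]):
        lora = int(indices[idx])
        if lora < 0:
            continue
        # assumed: accumulate the LoRA delta into the target column slice of y
        delta = weight[lora].astype(np.float32) @ x[idx].astype(np.float32)
        y_out[idx, slice_offset:slice_offset + slice_size] += delta.astype(y.dtype)
    return y_out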

csrc/torch_binding.cpp

Lines changed: 92 additions & 0 deletions
@@ -199,6 +199,90 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    cmd.Run();
    return {masked_input, mask};
}

void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be the same");
    TORCH_CHECK(x.size(1) > y.size(1), "hidden_in should be greater than hidden_out");
    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    void* y_ptr = y.data_ptr();
    int batch_size = x.size(0);
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_shrink");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token,
                          lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core,
                         input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
    return;
}

at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
                       int64_t slice_offset, int64_t slice_size)
{
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be the same");
    TORCH_CHECK(x.size(1) <= slice_size, "hidden_in should be no greater than slice_size");
    TORCH_CHECK(slice_offset >= 0, "slice_offset should be no smaller than 0");
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should be smaller than the second dimension of y");

    at::Tensor y_out = y;
    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    void* y_ptr = y.data_ptr();
    void* y_out_ptr = y_out.data_ptr();
    int batch_size = x.size(0);
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_expand");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank,
                          slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size,
                         num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
} // namespace vllm_ascend

TORCH_LIBRARY_EXPAND(_C, ops)
@@ -223,6 +307,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
            " int added_vocab_start_index, "
            " int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);

    ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);

    ops.def(
        "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
}

REGISTER_EXTENSION(_C)
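After this registration the two ops are callable from Python through the torch.ops mechanism, which is presumably how the referenced e2e tests exercise them. Below is a minimal sketch of the call pattern; the torch.ops._C namespace follows TORCH_LIBRARY_EXPAND(_C, ops) above but depends on how that macro is defined in this repo, the device string and imports assume an Ascend setup with torch_npu, and the concrete shapes and scale are made up for illustration.

import torch
import torch_npu  # noqa: F401  (assumed: provides the "npu" device on Ascend)
# Loading vllm_ascend's compiled _C extension is assumed to register the ops.

batch_size, hidden_in, rank, hidden_out, num_loras = 4, 4096, 16, 4096, 2

x = torch.rand((batch_size, hidden_in), dtype=torch.float16, device="npu")
lora_a = torch.rand((num_loras, rank, hidden_in), dtype=torch.float16, device="npu")
lora_b = torch.rand((num_loras, hidden_out, rank), dtype=torch.float16, device="npu")
indices = torch.zeros((batch_size,), dtype=torch.int64, device="npu")

# shrink: y holds one fp32 column per LoRA rank and is written in place
y = torch.zeros((batch_size, rank), dtype=torch.float32, device="npu")
torch.ops._C.bgmv_shrink(x, lora_a, indices, y, 0.5)

# expand: updates the [slice_offset, slice_offset + slice_size) slice of the output
out = torch.zeros((batch_size, hidden_out), dtype=torch.float16, device="npu")
out = torch.ops._C.bgmv_expand(y, lora_b, indices, out, 0, hidden_out)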
