/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel_operator.h"
#include "types.h"

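// BGMV "shrink" kernel for LoRA: for each token in the batch, the weight block selected
// by `indices` is applied to that token's input row, producing
//     y[token][r] = scale * dot(x[token][0:inputHiddenDim], W[lora][r][0:inputHiddenDim])
// for every r < maxLoRARank, with the result stored as float32. Tokens whose index is
// negative have no LoRA adapter and are skipped.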
template <typename scalar_t>
class BGMVShrink {
public:
    using X_T = scalar_t;
    using W_T = scalar_t;
    using Y_T = float;

    static constexpr uint64_t BUFFER_NUM = 1;
    static constexpr uint64_t TILE_LENGTH = 11776;  // optimal performance tile length

public:
    __aicore__ inline BGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
    __aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *indices, __gm__ void *y,
                                uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
                                uint32_t maxLoRARank, float scale)
    {
        batchSize_ = batchSize;
        numTokensPerCore_ = numTokensPerCore;
        inputHiddenDim_ = inputHiddenDim;
        maxLoRARank_ = maxLoRARank;
        scale_ = scale;
        singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
        incremental_ = inputHiddenDim_ > TILE_LENGTH;

        xGm_.SetGlobalBuffer((__gm__ X_T *)x);
        yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
        wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
        indicesGm_.SetGlobalBuffer((__gm__ int64_t *)indices);

        pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
        pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
        pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
        pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));

        pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
        pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
    }

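    // Each AI Core walks its own contiguous slice of numTokensPerCore_ tokens,
    // clamped to batchSize_, and handles one token per loop iteration.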
    __aicore__ inline void Process()
    {
        int64_t blockIdx = AscendC::GetBlockIdx();
        int64_t startIdx = blockIdx * numTokensPerCore_;
        int64_t endIdx = startIdx + numTokensPerCore_;
        if (endIdx > batchSize_) {
            endIdx = batchSize_;
        }
        for (int64_t idx = startIdx; idx < endIdx; idx++) {
            // set up LoRA index
            CopyInIndex(idx);
            if (reqLoRAIndex_ < 0) {
                continue;
            }
            reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;

            if (incremental_) {
                ProcessImpl<true>(idx);
            } else {
                ProcessImpl<false>(idx);
            }

            ScaleOutput();
            CopyOut(idx);
        }
    }

private:
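    // INCREMENTAL_MODE is used when inputHiddenDim_ exceeds TILE_LENGTH: x is then
    // re-streamed tile by tile for every rank row. Otherwise the whole x row fits in
    // one tile, is cast to float once up front, and only weight tiles are streamed.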
    template <bool INCREMENTAL_MODE>
    __aicore__ inline void ProcessImpl(const int64_t idx)
    {
        AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
        if constexpr (!INCREMENTAL_MODE) {
            CopyInX(idx, 0, inputHiddenDim_);
            AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
            AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
            Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
            pipe_barrier(PIPE_V);
            inQueueX_.FreeTensor(xLocal);
        }

        for (int i = 0; i < maxLoRARank_; i++) {
            float acc(0);
            for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
                if constexpr (INCREMENTAL_MODE) {
                    CopyInX(idx, j);
                }
                CopyInW(i, j);
                Compute<INCREMENTAL_MODE>(acc);
            }
            CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
            yOutLocal.SetValue(i, acc);
        }
    }

    __aicore__ inline void CopyInIndex(const int64_t idx)
    {
        // look up the LoRA index
        reqLoRAIndex_ = indicesGm_.GetValue(idx);
    }

    __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
    {
        AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
        DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
        inQueueX_.EnQue(xLocal);
    }

    __aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
    {
        AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
        DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
        inQueueW_.EnQue(wLocal);
    }

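    // Accumulates one tile of the dot product: the staged tiles are cast to float
    // (x and w in incremental mode, w only otherwise, since x was already cast),
    // multiplied elementwise, reduced to a single sum, and added to acc.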
133
+ template <bool INCREMENTAL_MODE>
134
+ __aicore__ inline void Compute (float &acc, int32_t numElements = TILE_LENGTH)
135
+ {
136
+ AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue <W_T>();
137
+ AscendC::LocalTensor<float > xTmpTensor = tmpBufferX_.Get <float >();
138
+ AscendC::LocalTensor<float > wTmpTensor = tmpBufferW_.Get <float >();
139
+
140
+ if constexpr (INCREMENTAL_MODE) {
141
+ AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue <X_T>();
142
+ Cast (xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
143
+ Cast (wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
144
+ pipe_barrier (PIPE_V);
145
+ inQueueX_.FreeTensor (xLocal);
146
+ inQueueW_.FreeTensor (wLocal);
147
+ } else {
148
+ Cast (wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
149
+ pipe_barrier (PIPE_V);
150
+ inQueueW_.FreeTensor (wLocal);
151
+ }
152
+ // dot product of the one tile of X and W
153
+ Mul (wTmpTensor, xTmpTensor, wTmpTensor, numElements);
154
+ pipe_barrier (PIPE_V);
155
+ // reduce sum generate one number, which is the summation of all the dot product
156
+ ReduceSum<float >(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
157
+ pipe_barrier (PIPE_V);
158
+
159
+ acc += wTmpTensor.GetValue (0 );
160
+ }
161
+
162
+ template <bool INCREMENTAL_MODE>
163
+ __aicore__ inline void CopyAndComputeLastIteration (const int64_t idx, int32_t rowIdx, float &acc)
164
+ {
165
+ int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
166
+ int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
167
+ if (remaining == 0 ) {
168
+ return ;
169
+ }
170
+ if constexpr (INCREMENTAL_MODE) {
171
+ CopyInX (idx, colIdx, remaining);
172
+ }
173
+ CopyInW (rowIdx, colIdx, remaining);
174
+ Compute<INCREMENTAL_MODE>(acc, remaining);
175
+ }
176
+
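    // ScaleOutput applies scale_ to the accumulated per-rank sums and stages them in the
    // output queue; CopyOut writes the maxLoRARank_ results for the token to global memory.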
    __aicore__ inline void ScaleOutput()
    {
        AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
        AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();

        Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
        pipe_barrier(PIPE_V);

        outQueueY_.EnQue<Y_T>(yOutLocal);
    }

    __aicore__ inline void CopyOut(const int64_t idx)
    {
        AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
        DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
        outQueueY_.FreeTensor(yOutLocal);
    }

private:
    AscendC::TPipe *pipe_;
    AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
    AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
    AscendC::GlobalTensor<X_T> xGm_;
    AscendC::GlobalTensor<W_T> wGm_;
    AscendC::GlobalTensor<int64_t> indicesGm_;
    AscendC::GlobalTensor<Y_T> yOutGm_;
    uint32_t batchSize_;
    uint32_t numTokensPerCore_;
    uint32_t inputHiddenDim_;
    uint32_t maxLoRARank_;
    float scale_;
    uint32_t singleLoRAWeightLen_;
    int64_t reqLoRAIndex_;
    uint64_t reqLoRAWeightOffset_;
    bool incremental_;
};

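// Declares an extern "C" kernel entry point for the given element type.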
#define BGMV_SHRINK_TYPE_DECLARE(TYPE)                                                                              \
extern "C" __global__ __aicore__ void bgmv_shrink_##TYPE(__gm__ void *x, __gm__ void *weight, __gm__ void *indices,\
                                                         __gm__ void *y, uint32_t batchSize,                       \
                                                         uint32_t numTokensPerCore, uint32_t inputHiddenDim,       \
                                                         uint32_t maxLoRARank, float scale)                        \
{                                                                                                                   \
    AscendC::TPipe pipe;                                                                                            \
    BGMVShrink<TYPE> op(&pipe);                                                                                     \
    op.Init(x, weight, indices, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale);                \
    op.Process();                                                                                                   \
}

// declare kernels for all supported dtypes
BGMV_SHRINK_TYPE_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
#endif

namespace vllm_ascend {
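// Host-side dispatcher: launches one block per group of numTokensPerCore tokens
// (rounded up to cover the whole batch) with the kernel matching the requested dtype.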
extern void bgmv_shrink_impl(AscendType type, void *stream, void *x, void *weight, void *indices,
                             void *y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
                             uint32_t maxLoRARank, float scale)
{
    uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
    if (type == AscendType::FP16) {
        bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
                                                        inputHiddenDim, maxLoRARank, scale);
    } else if (type == AscendType::BF16) {
#if (__CCE_AICORE__ >= 220)
        bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, y, batchSize, numTokensPerCore,
                                                              inputHiddenDim, maxLoRARank, scale);
#endif
    } else {
        return;
    }
}

} // namespace vllm_ascend