@@ -29,43 +29,62 @@ struct __align__(sizeof(T) * VecSize) VecType {
   }
 };
 
-template <int VecSize>
-__device__ void BlockLoad(const phi::bfloat16* input,
+template <typename InT, int VecSize>
+__device__ void BlockLoad(const InT* input,
+                          const float* input_scales,
                           __nv_bfloat16 x[8][4],
-                          size_t K) {
+                          size_t K,
+                          size_t k_scaled) {
+  constexpr bool need_dequant = std::is_same_v<InT, phi::dtype::float8_e4m3fn>;
+
+  #pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16;
-    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
-    size_t offset = off_m * K + off_k;
+    const uint32_t local_off_M = threadIdx.y + i * 16;
+    const uint32_t off_m = blockIdx.x * 128 + local_off_M;
+    const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
+    const size_t offset = off_m * K + off_k;
+
+    float scale;
+    if constexpr (need_dequant) {
+      const uint32_t m_base = blockIdx.x * 128;
+      const uint32_t m_stride = k_scaled;
+      scale = input_scales[off_m * m_stride + blockIdx.y];
+    }
 
+    #pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
-      if (off_k + j * 32 < K) {
-        size_t idx = offset + j * 32;
-        using LoadT = VecType<__nv_bfloat16, VecSize>;
-        LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
-        for (uint32_t k = 0; k < VecSize; k++) {
-          x[i][j + k] = data[k];
+      const size_t idx = offset + j * 32;
+      using LoadT = VecType<InT, VecSize>;
+      LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
+      #pragma unroll
+      for (uint32_t k = 0; k < VecSize; k++) {
+        if constexpr (need_dequant) {
+          x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
+        } else {
+          x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
         }
       }
     }
   }
 }
-
 template <bool Pow2Scales>
 __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
-                                 float col_scale[128],
+                                 float scales[128],
                                  __nv_bfloat16* shm) {
   // reduce [(8), 16, 32, 4] => [16, 32, 4]
   __nv_bfloat16 warp_max[4];
+  #pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+    #pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
-      __nv_bfloat16 t = BF16_ABS(x[i][j]);
+      const __nv_bfloat16 t = BF16_ABS(x[i][j]);
       warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t);
     }
   }
 
   // reduce [(16), 32, 4] => [8, 32, 4]
   if (threadIdx.y >= 8) {
+    #pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
       shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j];
     }
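
Note on the new dequant path in `BlockLoad`: when `InT` is FP8, each loaded value is rescaled to bf16 with one `float` scale per 1x128 slice of a row, read as `input_scales[off_m * m_stride + blockIdx.y]` with `m_stride = k_scaled`. Below is a minimal standalone sketch of that lookup; the row-major `[M, ceil(K / 128)]` scale layout is inferred from the indexing above, `LoadOneDequant` is a hypothetical helper (not part of this PR), and the raw CUDA `__nv_fp8_e4m3` type stands in for `phi::dtype::float8_e4m3fn` to keep the snippet self-contained.

```cuda
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <type_traits>

// Hypothetical helper illustrating the per-tile dequant lookup used above.
// Assumes input_scales has shape [M, k_scaled] with k_scaled = ceil(K / 128).
template <typename InT>
__device__ __nv_bfloat16 LoadOneDequant(const InT* input,
                                        const float* input_scales,
                                        size_t off_m,
                                        size_t off_k,
                                        size_t K,
                                        size_t k_scaled) {
  constexpr bool need_dequant = std::is_same_v<InT, __nv_fp8_e4m3>;
  const InT v = input[off_m * K + off_k];
  if constexpr (need_dequant) {
    // One scale per 1x128 slice of the row; off_k / 128 plays the role of
    // blockIdx.y in the kernel above.
    const float s = input_scales[off_m * k_scaled + off_k / 128];
    return __float2bfloat16(static_cast<float>(v) * s);
  } else {
    return v;  // input is already bf16, no rescaling needed
  }
}
```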
@@ -75,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
   // reduce [(8), 32, 4] => [32, 4]
   for (uint32_t offset = 8; offset > 0; offset /= 2) {
     if (threadIdx.y < offset) {
+      #pragma unroll
       for (uint32_t j = 0; j < 4; j++) {
-        __nv_bfloat16 other =
+        const __nv_bfloat16 other =
             offset == 8
                 ? warp_max[j]
                 : shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32];
@@ -85,7 +105,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
         if (offset > 1) {
           shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val;
         } else {
-          col_scale[threadIdx.x + j * 32] =
+          scales[threadIdx.x + j * 32] =
               ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>(
                   static_cast<float>(next_val), 0.0f);
         }
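
`ComputeScale` itself lives elsewhere in the repo and is not shown in this diff. For orientation only, a column amax is typically mapped to an FP8 E4M3 quantization multiplier along the lines of the sketch below; this is an assumed approximation of that behavior, not the PR's implementation (note that step 5 of the kernel multiplies `x` by this value, and `BlockStoreScale` stores its reciprocal as the output scale):

```cuda
// Assumed behavior of a ComputeScale-style helper for bf16 -> e4m3:
// map the column amax to a multiplier so the largest value lands near the
// FP8 E4M3 max (448), optionally rounding the multiplier down to a power
// of two so the stored reciprocal scale is an exact exponent shift.
__device__ inline float ColumnScaleFromAmax(float amax, bool pow2_scales) {
  constexpr float kFp8E4M3Max = 448.0f;
  float mult = (amax == 0.0f) ? 1.0f : kFp8E4M3Max / amax;
  if (pow2_scales) {
    mult = exp2f(floorf(log2f(mult)));  // round down: never exceed FP8 range
  }
  return mult;
}
```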
@@ -98,7 +118,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
 template <typename OutT, int VecSize>
 __device__ void BlockStoreScale(float* scale,
                                 size_t off_m,
-                                float col_scale[128],
+                                float scales[128],
                                 size_t K) {
   if (threadIdx.y < 4) {
     uint32_t off = threadIdx.y * 32 + threadIdx.x;
@@ -107,10 +127,10 @@ __device__ void BlockStoreScale(float* scale,
     } else if constexpr (VecSize == 2) {
       off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2;
     }
-    float scale_out = 1.0f / col_scale[off];
-    size_t idx_y = blockIdx.x - off_m / 128;
-    size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
-    size_t idx = idx_y * K + idx_x;
+    float scale_out = 1.0f / scales[off];
+    const size_t idx_y = blockIdx.x - off_m / 128;
+    const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
+    const size_t idx = idx_y * K + idx_x;
     if (idx_x < K) {
       scale[idx] = scale_out;
     }
@@ -123,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out,
                               size_t cur_tokens,
                               const OutT shm[128][129],
                               size_t K) {
+  #pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
-    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
-    size_t idx = idx_k * cur_tokens + (idx_m - off_m);
+    const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
+    const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
+    const size_t idx = idx_k * cur_tokens + (idx_m - off_m);
 
     if (idx_k < K) {
       using StoreT = VecType<OutT, VecSize>;
       StoreT data;
+      #pragma unroll
       for (uint32_t j = 0; j < VecSize; j++) {
         data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j];
       }
@@ -139,23 +161,27 @@ __device__ void BlockStoreOut(OutT* out,
   }
 }
 
-template <typename OutT, bool Pow2Scales, int VecSize>
+template <typename InT, typename OutT, bool Pow2Scales, int VecSize>
 __global__ void __launch_bounds__(512)
-    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input,
+    FusedTransposeSplitQuantKernel(const InT* __restrict__ input,
+                                   const float* __restrict__ input_scales,
                                    int64_t* __restrict__ meta,
                                    size_t num_experts,
-                                   size_t K) {
+                                   size_t K,
+                                   size_t k_scaled) {
   __shared__ OutT shm[128][129];
+  __shared__ size_t expert_info[2];
+  __shared__ float scales[128];  // May be reused? Is it worthy?
+
   int64_t* tokens_per_expert = meta;
   OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
   float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);
 
   // 1. Load 128x128 elements from input
   __nv_bfloat16 x[8][4];
-  BlockLoad<VecSize>(input, x, K);
+  BlockLoad<InT, VecSize>(input, input_scales, x, K, k_scaled);
 
   // 2. Get expert index and offset of the current block
-  __shared__ size_t expert_info[2];
   if (threadIdx.x == 0 && threadIdx.y == 0) {
     size_t idx_m = blockIdx.x * size_t(128);
     size_t off_m = 0, next_off_m = 0;
@@ -172,21 +198,23 @@ __global__ void __launch_bounds__(512)
   }
 
   // 3. Calculate scale along the column
-  __shared__ float col_scale[128];
   BlockColumnScale<Pow2Scales>(
-      x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm));
+      x, scales, reinterpret_cast<__nv_bfloat16*>(shm));
 
   // 4. Store scale
   const size_t expert_idx = expert_info[0];
   const size_t off_m = expert_info[1];
-  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, col_scale, K);
+  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, scales, K);
 
-  // 5. Scale x and save into shared memory with transposed layout
+  // 5. Scale x and save into shared memory with transposed layout
+  #pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+    #pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
+      #pragma unroll
       for (uint32_t k = 0; k < VecSize; k++) {
         float x_fp32 = static_cast<float>(x[i][j + k]);
-        float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32];
+        float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32];
         shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] =
             static_cast<OutT>(x_scaled);
       }
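
The transposed store in step 5 works because `shm` is declared `[128][129]` rather than `[128][128]`: the extra column pads the row stride so that column-wise accesses land in different shared-memory banks. Below is a minimal standalone illustration of the same padding trick, using a plain 32x32 `float` tile transpose rather than this kernel's layout (assumes `blockDim = (32, 32)` and `n` a multiple of 32):

```cuda
// Classic padded-shared-memory transpose: the +1 column keeps the column-wise
// read on distinct banks; with [32][32] every thread of a warp would hit the
// same bank.
__global__ void TransposeTile32(const float* in, float* out, int n) {
  __shared__ float tile[32][33];  // 33 = 32 + 1 padding column
  int x = blockIdx.x * 32 + threadIdx.x;
  int y = blockIdx.y * 32 + threadIdx.y;
  tile[threadIdx.y][threadIdx.x] = in[y * n + x];  // coalesced row-major read
  __syncthreads();
  x = blockIdx.y * 32 + threadIdx.x;  // swap block coordinates for the write
  y = blockIdx.x * 32 + threadIdx.y;
  out[y * n + x] = tile[threadIdx.x][threadIdx.y];  // conflict-free column read
}
```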
@@ -204,10 +232,11 @@ template <typename T, typename Context>
 void FusedTransposeSplitQuantKernel(
     const Context& dev_ctx,
     const DenseTensor& x,
+    const paddle::optional<DenseTensor>& input_scales,
     const std::vector<int64_t>& tokens_per_expert,
     bool pow_2_scales,
     std::vector<DenseTensor*> outs,
-    std::vector<DenseTensor*> scales) {
+    std::vector<DenseTensor*> output_scales) {
   auto x_dims = x.dims();
   const int64_t M = x_dims[0];
   const int64_t K = x_dims[1];
@@ -221,8 +250,8 @@ void FusedTransposeSplitQuantKernel(
     if (outs[i] != nullptr) {
       dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(outs[i]);
     }
-    if (scales[i] != nullptr) {
-      dev_ctx.template Alloc<float>(scales[i]);
+    if (output_scales[i] != nullptr) {
+      dev_ctx.template Alloc<float>(output_scales[i]);
     }
   }
 
@@ -245,8 +274,8 @@ void FusedTransposeSplitQuantKernel(
 
   for (size_t i = 0; i < num_experts; i++) {
     meta_ptr[num_experts * 2 + i] =
-        scales[i] != nullptr
-            ? reinterpret_cast<int64_t>(scales[i]->data<float>())
+        output_scales[i] != nullptr
+            ? reinterpret_cast<int64_t>(output_scales[i]->data<float>())
             : 0;
   }
 
@@ -255,23 +284,35 @@ void FusedTransposeSplitQuantKernel(
 
   auto stream = dev_ctx.stream();
 
-  dim3 grid(M / 128, (K + 127) / 128);
+  // pre-compute on CPU to reduce size_t division cost in kernel
+  const size_t k_scaled = (K + 127) / 128;
+  dim3 grid(M / 128, k_scaled);
   dim3 block(32, 16);
 
-#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                        \
-  FusedTransposeSplitQuantKernel<phi::dtype::float8_e4m3fn,          \
-                                 POW_2_SCALES,                       \
-                                 VEC_SIZE>                           \
-      <<<grid, block, 0, stream>>>(x.data<phi::dtype::bfloat16>(),   \
-                                   meta_gpu.data<int64_t>(),         \
-                                   num_experts,                      \
-                                   K);
+#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
+#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE)                         \
+  FusedTransposeSplitQuantKernel<T,                                      \
+                                 phi::dtype::float8_e4m3fn,              \
+                                 POW_2_SCALES,                           \
+                                 VEC_SIZE><<<grid, block, 0, stream>>>(  \
+      x.data<T>(),                                                       \
+      input_scales ? input_scales.get_ptr()->data<float>() : nullptr,    \
+      meta_gpu.data<int64_t>(),                                          \
+      num_experts,                                                       \
+      K,                                                                 \
+      k_scaled);
+#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE)                        \
+  if (DTYPE_CASE(x.dtype(), BFLOAT16)) {                                 \
+    LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE);                \
+  } else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) {                     \
+    LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE);           \
+  }
 
 #define LAUNCH_KERNEL_PARTIAL(VEC_SIZE)   \
   if (pow_2_scales) {                     \
-    LAUNCH_KERNEL(true, VEC_SIZE);        \
+    DISPATCH_DATATYPE(true, VEC_SIZE);    \
   } else {                                \
-    LAUNCH_KERNEL(false, VEC_SIZE);       \
+    DISPATCH_DATATYPE(false, VEC_SIZE);   \
   }
 
   if (K % 4 == 0) {
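
For readability, here is roughly the structure the `LAUNCH_KERNEL_PARTIAL` / `DISPATCH_DATATYPE` / `LAUNCH_KERNEL` macros produce once expanded: the runtime `pow_2_scales` flag and the runtime input dtype are folded into compile-time template arguments before the launch. This is a simplified, standalone sketch with placeholder names (`DType`, `BF16Tag`, `FP8E4M3Tag`, `LaunchFor`, `Dispatch`), not code from the PR:

```cuda
#include <cstdio>

// Placeholder element types standing in for phi::bfloat16 / phi::float8_e4m3fn.
struct BF16Tag {};
struct FP8E4M3Tag {};

enum class DType { kBFloat16, kFloat8E4M3 };

// Stand-in for the templated kernel launch; the real code would launch
// FusedTransposeSplitQuantKernel<InT, OutT, Pow2Scales, VecSize><<<...>>> here.
template <typename InT, bool Pow2Scales, int VecSize>
void LaunchFor() {
  std::printf("launch: pow2=%d vec=%d\n", int(Pow2Scales), VecSize);
}

template <int VecSize>
void Dispatch(DType dtype, bool pow_2_scales) {
  // Same nesting as the macros: pow_2_scales branch outside, dtype branch inside.
  if (pow_2_scales) {
    if (dtype == DType::kBFloat16) {
      LaunchFor<BF16Tag, true, VecSize>();
    } else if (dtype == DType::kFloat8E4M3) {
      LaunchFor<FP8E4M3Tag, true, VecSize>();
    }
  } else {
    if (dtype == DType::kBFloat16) {
      LaunchFor<BF16Tag, false, VecSize>();
    } else if (dtype == DType::kFloat8E4M3) {
      LaunchFor<FP8E4M3Tag, false, VecSize>();
    }
  }
}

int main() {
  Dispatch<4>(DType::kFloat8E4M3, /*pow_2_scales=*/true);
  return 0;
}
```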
@@ -296,7 +337,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   phi::dtype::float8_e4m3fn) {
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
   kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
 }