@@ -22,7 +22,8 @@ __global__ void quant_per_token_per_block(const T *input,
                                           float *quanted_scale,
                                           const int token_num,
                                           const int hidden_size,
-                                          const int hidden_size_scale) {
+                                          const int hidden_size_scale,
+                                          const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -58,6 +59,11 @@ __global__ void quant_per_token_per_block(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
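Reviewer note on the new branch above: multiplying the per-block maximum by 7 before deriving the scale makes the stored scale 7x larger, so the quantized magnitudes shrink by the same factor. Assuming MAX_VALUE is 448.0f (the FP8 E4M3 maximum; the constant itself is defined outside these hunks), quantized values then stay within +-64 instead of +-448. A minimal sketch of that arithmetic:

    // Sketch of the scale arithmetic; MAX_VALUE = 448.0f (FP8 E4M3 max) is an
    // assumption here, the real constant is defined elsewhere in this file.
    float max_abs = 3.5f;                          // example per-block max |x|
    float scale_default = max_abs / 448.0f;        // |q| = |x| / scale reaches 448
    float scale_fine = (7.0f * max_abs) / 448.0f;  // |q| now tops out at 448 / 7 = 64

Presumably the point is to confine the encoded values to a narrower band of the E4M3 grid at the cost of headroom; the patch itself does not state the motivation.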
@@ -89,6 +95,13 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
                                 input.place());
   const int gridx = min(132 * 8, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -97,7 +110,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
           quanted_scale.data<float>(),
           token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     case paddle::DataType::FLOAT16:
@@ -107,7 +121,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
           quanted_scale.data<float>(),
           token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     default:
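The environment lookup added above is repeated verbatim in PerTokenQuantPadding and MaskedPerTokenQuant below. A possible follow-up, not part of this patch, is to hoist it into one helper; this sketch keeps the patch's exact semantics, including the fact that std::stoi throws on non-numeric values such as "true":

    #include <cstdlib>
    #include <string>

    // Hypothetical helper (not in the patch): parse the toggle in one place.
    // As in the patched code, std::stoi throws std::invalid_argument when the
    // variable holds a non-numeric string, so "1"/"0" are the supported values.
    static bool UseFinegrainedRange() {
      const char *env_var = std::getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
      return env_var != nullptr && std::stoi(env_var) != 0;
    }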
@@ -124,7 +139,8 @@ __global__ void quant_per_token_per_block_padding(const T *input,
                                                   const int token_num,
                                                   const int padded_token_num,
                                                   const int hidden_size,
-                                                  const int hidden_size_scale) {
+                                                  const int hidden_size_scale,
+                                                  const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -160,6 +176,11 @@ __global__ void quant_per_token_per_block_padding(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
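These hunks only show the broadcast of lane 0's result; the maximum itself comes from a warp reduction earlier in the kernel, outside the visible context. For orientation, the usual shuffle pattern looks like the sketch below, which reconstructs code this diff does not show and may differ from the file in detail:

    // Assumed shape of the warp-level max reduction feeding the broadcast above.
    __device__ float warp_max_broadcast(float v) {
    #pragma unroll
      for (int offset = 16; offset > 0; offset >>= 1) {
        // fold the upper half of the warp into the lower half, keeping the max
        v = fmaxf(v, __shfl_down_sync(0xFFFFFFFF, v, offset));
      }
      return __shfl_sync(0xFFFFFFFF, v, 0);  // every lane reads lane 0's max
    }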
@@ -198,6 +219,13 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
                                 input.place());
   const int gridx = min(132 * 8, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -207,7 +235,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
           token_num,
           padded_token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     case paddle::DataType::FLOAT16:
@@ -218,7 +247,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
           token_num,
           padded_token_num,
           hidden_size,
-          hidden_size_scale
+          hidden_size_scale,
+          use_finegrained_range
       );
       break;
     default:
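On the launch configuration visible in these wrappers: blockx assigns one 32-lane warp per 128-element quantization block and caps the block at 1024 threads, while gridx is capped at 132 * 8 here (132 * 2 in MaskedPerTokenQuant below). 132 matches the SM count of an H100, so the cap reads like a persistent-grid sizing choice, though the patch does not say so. A worked instance with assumed sizes:

    // Example of the launch arithmetic with assumed values (not from the patch).
    const int hidden_size = 7168;                          // assumed
    const int token_num = 4096;                            // assumed
    const int blockx = min(1024, hidden_size / 128 * 32);  // 56 warps -> capped at 1024
    const int gridx = min(132 * 8, token_num);             // min(1056, 4096) = 1056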
@@ -236,7 +266,8 @@ __global__ void masked_quant_per_token_per_block(const T *input,
                                                  const int token_num,
                                                  const int hidden_size,
                                                  const int hidden_size_scale,
-                                                 const int num_max_tokens_per_expert) {
+                                                 const int num_max_tokens_per_expert,
+                                                 const bool use_finegrained_range) {
   const int bid = blockIdx.x;
   const int tid = threadIdx.x;
   const int warp_id = tid / 32;
@@ -281,6 +312,11 @@ __global__ void masked_quant_per_token_per_block(const T *input,
   // broadcast max_value
   max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
   max_value_thread = max(max_value_thread, epsilon);
+
+  if (use_finegrained_range) {
+    max_value_thread *= 7.0f;
+  }
+
   float scale_to_store = max_value_thread / MAX_VALUE;
   // quant
 #pragma unroll
@@ -317,6 +353,12 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
   const int gridx = min(132 * 2, token_num);
   const int blockx = min(1024, hidden_size / 128 * 32);
 
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
   switch (input.dtype()) {
     case paddle::DataType::BFLOAT16:
       masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -327,7 +369,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
           token_num,
           hidden_size,
           hidden_size_scale,
-          num_max_tokens_per_expert
+          num_max_tokens_per_expert,
+          use_finegrained_range
      );
      break;
    case paddle::DataType::FLOAT16:
@@ -339,7 +382,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
           token_num,
           hidden_size,
           hidden_size_scale,
-          num_max_tokens_per_expert
+          num_max_tokens_per_expert,
+          use_finegrained_range
      );
      break;
    default:
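Because each wrapper re-reads the variable on every call, the toggle can be flipped per process at runtime with no recompilation. A minimal usage sketch, assuming a POSIX environment:

    #include <cstdlib>

    int main() {
      // "1" enables the narrower range for all later PerTokenQuant* calls in
      // this process; leaving it unset or set to "0" keeps the default behavior.
      setenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE", "1", /*overwrite=*/1);
      // ... construct tensors and invoke the quant ops here ...
      return 0;
    }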