Skip to content

Commit e39159f

Browse files
Add switch to apply fine-grained per token quant fp8 (#3192)
Co-authored-by: yuanxiaolan <[email protected]>
1 parent 88596c0 commit e39159f

File tree

1 file changed: +53 additions, −9 deletions

custom_ops/gpu_ops/per_token_quant_fp8.cu

Lines changed: 53 additions & 9 deletions
Columns: original file line number | diff line number | diff line change
@@ -22,7 +22,8 @@ __global__ void quant_per_token_per_block(const T *input,
2222
float *quanted_scale,
2323
const int token_num,
2424
const int hidden_size,
25-
const int hidden_size_scale) {
25+
const int hidden_size_scale,
26+
const bool use_finegrained_range) {
2627
const int bid = blockIdx.x;
2728
const int tid = threadIdx.x;
2829
const int warp_id = tid / 32;
@@ -58,6 +59,11 @@ __global__ void quant_per_token_per_block(const T *input,
5859
// broadcast max_value
5960
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
6061
max_value_thread = max(max_value_thread, epsilon);
62+
63+
if (use_finegrained_range) {
64+
max_value_thread *= 7.0f;
65+
}
66+
6167
float scale_to_store = max_value_thread / MAX_VALUE;
6268
// quant
6369
#pragma unroll
@@ -89,6 +95,13 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
8995
input.place());
9096
const int gridx = min(132 * 8, token_num);
9197
const int blockx = min(1024, hidden_size / 128 * 32);
98+
99+
bool use_finegrained_range = false;
100+
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
101+
if (env_var) {
102+
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
103+
}
104+
92105
switch (input.dtype()) {
93106
case paddle::DataType::BFLOAT16:
94107
quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -97,7 +110,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
97110
quanted_scale.data<float>(),
98111
token_num,
99112
hidden_size,
100-
hidden_size_scale
113+
hidden_size_scale,
114+
use_finegrained_range
101115
);
102116
break;
103117
case paddle::DataType::FLOAT16:
@@ -107,7 +121,8 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
107121
quanted_scale.data<float>(),
108122
token_num,
109123
hidden_size,
110-
hidden_size_scale
124+
hidden_size_scale,
125+
use_finegrained_range
111126
);
112127
break;
113128
default:
@@ -124,7 +139,8 @@ __global__ void quant_per_token_per_block_padding(const T *input,
124139
const int token_num,
125140
const int padded_token_num,
126141
const int hidden_size,
127-
const int hidden_size_scale) {
142+
const int hidden_size_scale,
143+
const bool use_finegrained_range) {
128144
const int bid = blockIdx.x;
129145
const int tid = threadIdx.x;
130146
const int warp_id = tid / 32;
@@ -160,6 +176,11 @@ __global__ void quant_per_token_per_block_padding(const T *input,
160176
// broadcast max_value
161177
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
162178
max_value_thread = max(max_value_thread, epsilon);
179+
180+
if (use_finegrained_range) {
181+
max_value_thread *= 7.0f;
182+
}
183+
163184
float scale_to_store = max_value_thread / MAX_VALUE;
164185
// quant
165186
#pragma unroll
@@ -198,6 +219,13 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
198219
input.place());
199220
const int gridx = min(132 * 8, token_num);
200221
const int blockx = min(1024, hidden_size / 128 * 32);
222+
223+
bool use_finegrained_range = false;
224+
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
225+
if (env_var) {
226+
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
227+
}
228+
201229
switch (input.dtype()) {
202230
case paddle::DataType::BFLOAT16:
203231
quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>(
@@ -207,7 +235,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
207235
token_num,
208236
padded_token_num,
209237
hidden_size,
210-
hidden_size_scale
238+
hidden_size_scale,
239+
use_finegrained_range
211240
);
212241
break;
213242
case paddle::DataType::FLOAT16:
@@ -218,7 +247,8 @@ std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
218247
token_num,
219248
padded_token_num,
220249
hidden_size,
221-
hidden_size_scale
250+
hidden_size_scale,
251+
use_finegrained_range
222252
);
223253
break;
224254
default:
@@ -236,7 +266,8 @@ __global__ void masked_quant_per_token_per_block(const T *input,
236266
const int token_num,
237267
const int hidden_size,
238268
const int hidden_size_scale,
239-
const int num_max_tokens_per_expert) {
269+
const int num_max_tokens_per_expert,
270+
const bool use_finegrained_range) {
240271
const int bid = blockIdx.x;
241272
const int tid = threadIdx.x;
242273
const int warp_id = tid / 32;
@@ -281,6 +312,11 @@ __global__ void masked_quant_per_token_per_block(const T *input,
281312
// broadcast max_value
282313
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
283314
max_value_thread = max(max_value_thread, epsilon);
315+
316+
if (use_finegrained_range) {
317+
max_value_thread *= 7.0f;
318+
}
319+
284320
float scale_to_store = max_value_thread / MAX_VALUE;
285321
// quant
286322
#pragma unroll
@@ -317,6 +353,12 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
317353
const int gridx = min(132 * 2, token_num);
318354
const int blockx = min(1024, hidden_size / 128 * 32);
319355

356+
bool use_finegrained_range = false;
357+
char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
358+
if (env_var) {
359+
use_finegrained_range = static_cast<bool>(std::stoi(env_var));
360+
}
361+
320362
switch (input.dtype()) {
321363
case paddle::DataType::BFLOAT16:
322364
masked_quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>(
@@ -327,7 +369,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
327369
token_num,
328370
hidden_size,
329371
hidden_size_scale,
330-
num_max_tokens_per_expert
372+
num_max_tokens_per_expert,
373+
use_finegrained_range
331374
);
332375
break;
333376
case paddle::DataType::FLOAT16:
@@ -339,7 +382,8 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant(paddle::Tensor& input,
339382
token_num,
340383
hidden_size,
341384
hidden_size_scale,
342-
num_max_tokens_per_expert
385+
num_max_tokens_per_expert,
386+
use_finegrained_range
343387
);
344388
break;
345389
default:

Commit comments (0)