16 | 16 |
17 | 17 | constexpr float epsilon = 1e-10; |
18 | 18 |
19 | | -template <typename T> |
20 | | -__global__ void quant_per_token_per_block( |
21 | | - const T *input, |
22 | | - phi::dtype::float8_e4m3fn *quanted_res, |
23 | | - float *quanted_scale, |
24 | | - const int token_num, |
25 | | - const int hidden_size, |
26 | | - const int hidden_size_scale, |
27 | | - const bool use_finegrained_range) { |
28 | | - const int bid = blockIdx.x; |
29 | | - const int tid = threadIdx.x; |
30 | | - const int warp_id = tid / 32; |
31 | | - const int lane_id = tid % 32; |
32 | | - const int num_warp = blockDim.x / 32; |
33 | | - static constexpr int NUM_PER_THREADS = 128 / 32; // 4 |
34 | | - static constexpr float MAX_VALUE = 448.f; |
35 | | - // Note(ZKK) use ceil_div!! |
36 | | - const int end_iter = (hidden_size + 127) / 128; // warp_iter_num |
37 | | - AlignedVector<T, NUM_PER_THREADS> load_vec; |
38 | | - AlignedVector<float, NUM_PER_THREADS> load_vec_float; |
39 | | - AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec; |
40 | | - for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) { |
41 | | - const T *input_now = input + token_idx * hidden_size; |
42 | | - phi::dtype::float8_e4m3fn *quanted_res_now = |
43 | | - quanted_res + token_idx * hidden_size; |
44 | | - float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale; |
45 | | - // each warp processes one 128-element block |
46 | | - for (int iter = warp_id; iter < end_iter; iter += num_warp) { |
47 | | - const int start_offset = iter * 128; |
48 | | - |
49 | | - const bool is_valid_data = |
50 | | - start_offset + lane_id * NUM_PER_THREADS < hidden_size; |
51 | | - |
52 | | - if (is_valid_data) { |
53 | | - Load<T, NUM_PER_THREADS>( |
54 | | - input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec); |
55 | | - } else { |
56 | | -#pragma unroll |
57 | | - for (int vid = 0; vid < NUM_PER_THREADS; vid++) load_vec[vid] = T(0.f); |
58 | | - } |
59 | | - // get max value per thread |
60 | | - float max_value_thread = -5e4; |
61 | | -#pragma unroll |
62 | | - for (int vid = 0; vid < NUM_PER_THREADS; vid++) { |
63 | | - load_vec_float[vid] = static_cast<float>(load_vec[vid]); |
64 | | - max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); |
65 | | - } |
66 | | - // get max value per warp |
67 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), |
68 | | - max_value_thread); |
69 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), |
70 | | - max_value_thread); |
71 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), |
72 | | - max_value_thread); |
73 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), |
74 | | - max_value_thread); |
75 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), |
76 | | - max_value_thread); |
77 | | - // broadcast max_value |
78 | | - max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0); |
79 | | - max_value_thread = max(max_value_thread, epsilon); |
80 | | - |
81 | | - if (use_finegrained_range) { |
82 | | - max_value_thread *= 7.0f; |
83 | | - } |
84 | | - |
85 | | - float scale_to_store = max_value_thread / MAX_VALUE; |
86 | | - // quant |
87 | | -#pragma unroll |
88 | | - for (int vid = 0; vid < NUM_PER_THREADS; vid++) { |
89 | | - res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>( |
90 | | - load_vec_float[vid] * MAX_VALUE / max_value_thread); |
91 | | - } |
92 | | - // store |
93 | | - if (is_valid_data) |
94 | | - Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>( |
95 | | - res_vec, |
96 | | - quanted_res_now + start_offset + lane_id * NUM_PER_THREADS); |
97 | | - if (lane_id == 0) { |
98 | | - quanted_scale_now[iter] = scale_to_store; |
99 | | - } |
100 | | - } |
101 | | - } |
102 | | -} |
103 | | - |
104 | | -std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input, |
105 | | - const int block_size) { |
106 | | - auto input_dim = input.dims(); |
107 | | - const int token_num = input_dim[0]; |
108 | | - const int hidden_size = input_dim[1]; |
109 | | - // Note(ZKK) here we use ceil_div to support 4.5T running on 8 GPUs, |
110 | | - // where moe_intermediate_size is 448, which cannot be divided by 128. |
111 | | - const int hidden_size_scale = (hidden_size + block_size - 1) / block_size; |
112 | | - |
113 | | - auto quanted_x = GetEmptyTensor( |
114 | | - {token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place()); |
115 | | - auto quanted_scale = GetEmptyTensor( |
116 | | - {token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place()); |
117 | | - const int gridx = min(132 * 8, token_num); |
118 | | - const int blockx = min(1024, hidden_size / 128 * 32); |
119 | | - |
120 | | - bool use_finegrained_range = false; |
121 | | - char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"); |
122 | | - if (env_var) { |
123 | | - use_finegrained_range = static_cast<bool>(std::stoi(env_var)); |
124 | | - } |
125 | | - |
126 | | - switch (input.dtype()) { |
127 | | - case paddle::DataType::BFLOAT16: |
128 | | - quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>( |
129 | | - input.data<paddle::bfloat16>(), |
130 | | - quanted_x.data<phi::dtype::float8_e4m3fn>(), |
131 | | - quanted_scale.data<float>(), |
132 | | - token_num, |
133 | | - hidden_size, |
134 | | - hidden_size_scale, |
135 | | - use_finegrained_range); |
136 | | - break; |
137 | | - case paddle::DataType::FLOAT16: |
138 | | - quant_per_token_per_block<<<gridx, blockx, 0, input.stream()>>>( |
139 | | - input.data<paddle::float16>(), |
140 | | - quanted_x.data<phi::dtype::float8_e4m3fn>(), |
141 | | - quanted_scale.data<float>(), |
142 | | - token_num, |
143 | | - hidden_size, |
144 | | - hidden_size_scale, |
145 | | - use_finegrained_range); |
146 | | - break; |
147 | | - default: |
148 | | - PD_THROW("Unsupported data type for PerTokenQuant"); |
149 | | - } |
150 | | - return {quanted_x, quanted_scale}; |
151 | | -} |
152 | | - |
153 | | -std::vector<std::vector<int64_t>> PerTokenQuantInferShape( |
154 | | - std::vector<int64_t> input_shape, const int block_size) { |
155 | | - const int token_num = input_shape[0]; |
156 | | - const int hidden_size = input_shape[1]; |
157 | | - const int hidden_size_scale = (hidden_size + block_size - 1) / block_size; |
158 | | - return {{token_num, hidden_size}, {token_num, hidden_size_scale}}; |
159 | | -} |
160 | | - |
161 | | -std::vector<paddle::DataType> PerTokenQuantInferDtype( |
162 | | - paddle::DataType input_dtype, const int block_size) { |
163 | | - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32}; |
164 | | -} |
165 | | - |
166 | | -template <typename T> |
167 | | -__global__ void quant_per_token_per_block_padding( |
168 | | - const T *input, |
169 | | - phi::dtype::float8_e4m3fn *quanted_res, |
170 | | - float *quanted_scale, |
171 | | - const int token_num, |
172 | | - const int padded_token_num, |
173 | | - const int hidden_size, |
174 | | - const int hidden_size_scale, |
175 | | - const bool use_finegrained_range) { |
176 | | - const int bid = blockIdx.x; |
177 | | - const int tid = threadIdx.x; |
178 | | - const int warp_id = tid / 32; |
179 | | - const int lane_id = tid % 32; |
180 | | - const int num_warp = blockDim.x / 32; |
181 | | - static constexpr int NUM_PER_THREADS = 128 / 32; // 4 |
182 | | - static constexpr float MAX_VALUE = 448.f; |
183 | | - const int end_iter = hidden_size / 128; // warp_iter_num |
184 | | - AlignedVector<T, NUM_PER_THREADS> load_vec; |
185 | | - AlignedVector<float, NUM_PER_THREADS> load_vec_float; |
186 | | - AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec; |
187 | | - for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) { |
188 | | - const T *input_now = input + token_idx * hidden_size; |
189 | | - phi::dtype::float8_e4m3fn *quanted_res_now = |
190 | | - quanted_res + token_idx * hidden_size; |
191 | | - // each warp processes one 128-element block |
192 | | - for (int iter = warp_id; iter < end_iter; iter += num_warp) { |
193 | | - float *quanted_scale_now = |
194 | | - quanted_scale + iter * padded_token_num + token_idx; |
195 | | - const int start_offset = iter * 128; |
196 | | - Load<T, NUM_PER_THREADS>( |
197 | | - input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec); |
198 | | - // get max value per thread |
199 | | - float max_value_thread = -5e4; |
200 | | -#pragma unroll |
201 | | - for (int vid = 0; vid < NUM_PER_THREADS; vid++) { |
202 | | - load_vec_float[vid] = static_cast<float>(load_vec[vid]); |
203 | | - max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); |
204 | | - } |
205 | | - // get max value per warp |
206 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), |
207 | | - max_value_thread); |
208 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), |
209 | | - max_value_thread); |
210 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), |
211 | | - max_value_thread); |
212 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), |
213 | | - max_value_thread); |
214 | | - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), |
215 | | - max_value_thread); |
216 | | - // broadcast max_value |
217 | | - max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0); |
218 | | - max_value_thread = max(max_value_thread, epsilon); |
219 | | - |
220 | | - if (use_finegrained_range) { |
221 | | - max_value_thread *= 7.0f; |
222 | | - } |
223 | | - |
224 | | - float scale_to_store = max_value_thread / MAX_VALUE; |
225 | | - // quant |
226 | | -#pragma unroll |
227 | | - for (int vid = 0; vid < NUM_PER_THREADS; vid++) { |
228 | | - res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>( |
229 | | - load_vec_float[vid] * MAX_VALUE / max_value_thread); |
230 | | - } |
231 | | - // store |
232 | | - Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>( |
233 | | - res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS); |
234 | | - if (lane_id == 0) { |
235 | | - *quanted_scale_now = scale_to_store; |
236 | | - } |
237 | | - } |
238 | | - } |
239 | | -} |
240 | | - |
241 | | -std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input, |
242 | | - const int block_size) { |
243 | | - using ScaleDtype = float; |
244 | | - |
245 | | - auto input_dim = input.dims(); |
246 | | - const int token_num = input_dim[0]; |
247 | | - const int hidden_size = input_dim[1]; |
248 | | - |
249 | | - PADDLE_ENFORCE(block_size == 128, "only block_size = 128 is supported"); |
250 | | - PADDLE_ENFORCE(hidden_size % 128 == 0, |
251 | | - "hidden_size must be divisible by 128"); |
252 | | - |
253 | | - const int hidden_size_scale = hidden_size / block_size; |
254 | | - auto quanted_x = GetEmptyTensor( |
255 | | - {token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place()); |
256 | | - |
257 | | - const int tma_alignment_bytes = 16; |
258 | | - const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype); |
259 | | - const int padded_token_num = |
260 | | - ((token_num + tma_alignment_elements - 1) / tma_alignment_elements) * |
261 | | - tma_alignment_elements; |
262 | | - auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale}, |
263 | | - {1, padded_token_num}, |
264 | | - paddle::DataType::FLOAT32, |
265 | | - input.place()); |
266 | | - const int gridx = min(132 * 8, token_num); |
267 | | - const int blockx = min(1024, hidden_size / 128 * 32); |
268 | | - |
269 | | - bool use_finegrained_range = false; |
270 | | - char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"); |
271 | | - if (env_var) { |
272 | | - use_finegrained_range = static_cast<bool>(std::stoi(env_var)); |
273 | | - } |
274 | | - |
275 | | - switch (input.dtype()) { |
276 | | - case paddle::DataType::BFLOAT16: |
277 | | - quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>( |
278 | | - input.data<paddle::bfloat16>(), |
279 | | - quanted_x.data<phi::dtype::float8_e4m3fn>(), |
280 | | - quanted_scale.data<ScaleDtype>(), |
281 | | - token_num, |
282 | | - padded_token_num, |
283 | | - hidden_size, |
284 | | - hidden_size_scale, |
285 | | - use_finegrained_range); |
286 | | - break; |
287 | | - case paddle::DataType::FLOAT16: |
288 | | - quant_per_token_per_block_padding<<<gridx, blockx, 0, input.stream()>>>( |
289 | | - input.data<paddle::float16>(), |
290 | | - quanted_x.data<phi::dtype::float8_e4m3fn>(), |
291 | | - quanted_scale.data<ScaleDtype>(), |
292 | | - token_num, |
293 | | - padded_token_num, |
294 | | - hidden_size, |
295 | | - hidden_size_scale, |
296 | | - use_finegrained_range); |
297 | | - break; |
298 | | - default: |
299 | | - PD_THROW("Unsupported data type for PerTokenQuantPadding"); |
300 | | - } |
301 | | - return {quanted_x, quanted_scale}; |
302 | | -} |
303 | | - |
304 | | -std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape( |
305 | | - std::vector<int64_t> input_shape, const int block_size) { |
306 | | - using ScaleDtype = float; |
307 | | - |
308 | | - const int token_num = input_shape[0]; |
309 | | - const int hidden_size = input_shape[1]; |
310 | | - const int hidden_size_scale = hidden_size / block_size; |
311 | | - |
312 | | - const int tma_alignment_bytes = 16; |
313 | | - const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype); |
314 | | - const int padded_token_num = |
315 | | - ((token_num + tma_alignment_elements - 1) / tma_alignment_elements) * |
316 | | - tma_alignment_elements; |
317 | | - |
318 | | - return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}}; |
319 | | -} |
320 | | - |
321 | | -std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype( |
322 | | - paddle::DataType input_dtype) { |
323 | | - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32}; |
324 | | -} |
325 | | - |
326 | 19 | template <typename T> |
327 | 20 | __global__ void masked_quant_per_token_per_block( |
328 | 21 | const T *input, |
@@ -472,22 +165,6 @@ std::vector<paddle::Tensor> MaskedPerTokenQuant( |
472 | 165 | return {quanted_x, quanted_scale}; |
473 | 166 | } |
474 | 167 |
475 | | -PD_BUILD_STATIC_OP(per_token_quant) |
476 | | - .Inputs({"input"}) |
477 | | - .Outputs({"output", "output_scale"}) |
478 | | - .Attrs({"block_size: int"}) |
479 | | - .SetKernelFn(PD_KERNEL(PerTokenQuant)) |
480 | | - .SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape)) |
481 | | - .SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype)); |
482 | | - |
483 | | -PD_BUILD_STATIC_OP(per_token_quant_padding) |
484 | | - .Inputs({"input"}) |
485 | | - .Outputs({"output", "output_scale"}) |
486 | | - .Attrs({"block_size: int"}) |
487 | | - .SetKernelFn(PD_KERNEL(PerTokenQuantPadding)) |
488 | | - .SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape)) |
489 | | - .SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype)); |
490 | | - |
491 | 168 | PD_BUILD_STATIC_OP(masked_per_token_quant) |
492 | 169 | .Inputs({"input", "recv_expert_count"}) |
493 | 170 | .Outputs({"output", "output_scale"}) |