 #include "fbgemm_gpu/utils/cuda_block_count.h"
 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
 #include "fbgemm_gpu/utils/device_sort.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/stochastic_rounding.cuh"
 
 #if !( \
@@ -165,8 +166,8 @@ struct __align__(8) i8x8 {
 };
 
 __global__ void per_tensor_quantize_i8_kernel(
-    at::PackedTensorAccessor64<at::BFloat16, 1, at::RestrictPtrTraits> X,
-    at::PackedTensorAccessor64<int8_t, 1, at::RestrictPtrTraits> XQ,
+    pta::PackedTensorAccessor64<at::BFloat16, 1, at::RestrictPtrTraits> X,
+    pta::PackedTensorAccessor64<int8_t, 1, at::RestrictPtrTraits> XQ,
     at::BFloat16* scale_device,
     float inv_scale) {
   auto N = X.size(0);
@@ -237,16 +238,17 @@ at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale) {
   dim3 threads = kThreadsPerBlock;
   dim3 blocks =
       cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock);
-  per_tensor_quantize_i8_kernel<<<
+
+  FBGEMM_LAUNCH_KERNEL(
+      (per_tensor_quantize_i8_kernel),
       blocks,
       threads,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      X.packed_accessor64<at::BFloat16, 1, at::RestrictPtrTraits>(),
-      XQ.packed_accessor64<int8_t, 1, at::RestrictPtrTraits>(),
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(X, at::BFloat16, 1, 64),
+      PTA_B(XQ, int8_t, 1, 64),
       nullptr,
       inv_scale);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return XQ;
 }
 
@@ -265,16 +267,16 @@ std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(
   dim3 blocks =
       cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock);
 
-  per_tensor_quantize_i8_kernel<<<
+  FBGEMM_LAUNCH_KERNEL(
+      (per_tensor_quantize_i8_kernel),
       blocks,
       threads,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      X.packed_accessor64<at::BFloat16, 1, at::RestrictPtrTraits>(),
-      XQ.packed_accessor64<int8_t, 1, at::RestrictPtrTraits>(),
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(X, at::BFloat16, 1, 64),
+      PTA_B(XQ, int8_t, 1, 64),
       scale.data_ptr<at::BFloat16>(),
       0.0);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return {XQ, scale};
 }
 
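For readers unfamiliar with the launcher API, here is a minimal sketch of the same FBGEMM_LAUNCH_KERNEL / PTA_B pattern applied to a hypothetical 1-D elementwise kernel. The names scale_inplace_kernel and scale_inplace are illustrative only and are not part of this file, and the sketch assumes the launcher performs the post-launch error check that the removed explicit C10_CUDA_KERNEL_LAUNCH_CHECK() calls used to do.

// Illustrative sketch only (not part of this commit). Assumes the
// fbgemm_gpu launcher and packed-tensor-accessor utilities are
// available via kernel_launcher.cuh; depending on the build, PTA_B
// and the pta:: accessors may require an additional accessor header.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

__global__ void scale_inplace_kernel(
    pta::PackedTensorAccessor64<float, 1, at::RestrictPtrTraits> x,
    float s) {
  const auto i =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < x.size(0)) {
    x[i] *= s;
  }
}

void scale_inplace(at::Tensor x, float s) {
  constexpr int kThreads = 256;
  const auto blocks = (x.numel() + kThreads - 1) / kThreads;
  // Same grid/block/smem/stream arguments as a raw <<<...>>> launch,
  // followed by the kernel arguments; no explicit
  // C10_CUDA_KERNEL_LAUNCH_CHECK() follows the call.
  FBGEMM_LAUNCH_KERNEL(
      (scale_inplace_kernel),
      blocks,
      kThreads,
      0,
      at::cuda::getCurrentCUDAStream(),
      PTA_B(x, float, 1, 64),
      s);
}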