Commit 2c8ef86

q10 authored and facebook-github-bot committed
Migrate GenAI quantize kernels to FBGEMM_LAUNCH_KERNEL, pt 1 (#4834)
Summary:
Pull Request resolved: #4834
X-link: facebookresearch/FBGEMM#1861

- Migrate GenAI quantize kernels to `FBGEMM_LAUNCH_KERNEL`, pt 1

Reviewed By: ionuthristodorescu
Differential Revision: D79978899
fbshipit-source-id: 69a11e082b633b476cde2620049f85801f441cf2
1 parent 7204685 commit 2c8ef86

1 file changed: +14 -12 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 14 additions & 12 deletions
@@ -39,6 +39,7 @@
 #include "fbgemm_gpu/utils/cuda_block_count.h"
 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
 #include "fbgemm_gpu/utils/device_sort.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/stochastic_rounding.cuh"
 
 #if !( \
@@ -165,8 +166,8 @@ struct __align__(8) i8x8 {
 };
 
 __global__ void per_tensor_quantize_i8_kernel(
-    at::PackedTensorAccessor64<at::BFloat16, 1, at::RestrictPtrTraits> X,
-    at::PackedTensorAccessor64<int8_t, 1, at::RestrictPtrTraits> XQ,
+    pta::PackedTensorAccessor64<at::BFloat16, 1, at::RestrictPtrTraits> X,
+    pta::PackedTensorAccessor64<int8_t, 1, at::RestrictPtrTraits> XQ,
     at::BFloat16* scale_device,
     float inv_scale) {
   auto N = X.size(0);
@@ -237,16 +238,17 @@ at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale) {
   dim3 threads = kThreadsPerBlock;
   dim3 blocks =
       cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock);
-  per_tensor_quantize_i8_kernel<<<
+
+  FBGEMM_LAUNCH_KERNEL(
+      (per_tensor_quantize_i8_kernel),
       blocks,
       threads,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      X.packed_accessor64<at::BFloat16, 1, at::RestrictPtrTraits>(),
-      XQ.packed_accessor64<int8_t, 1, at::RestrictPtrTraits>(),
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(X, at::BFloat16, 1, 64),
+      PTA_B(XQ, int8_t, 1, 64),
       nullptr,
       inv_scale);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return XQ;
 }
 
@@ -265,16 +267,16 @@ std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(
   dim3 blocks =
       cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock);
 
-  per_tensor_quantize_i8_kernel<<<
+  FBGEMM_LAUNCH_KERNEL(
+      (per_tensor_quantize_i8_kernel),
       blocks,
       threads,
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      X.packed_accessor64<at::BFloat16, 1, at::RestrictPtrTraits>(),
-      XQ.packed_accessor64<int8_t, 1, at::RestrictPtrTraits>(),
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(X, at::BFloat16, 1, 64),
+      PTA_B(XQ, int8_t, 1, 64),
       scale.data_ptr<at::BFloat16>(),
       0.0);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return {XQ, scale};
 }
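
The migration is mechanical at each call site: a raw triple-chevron launch followed by a separate C10_CUDA_KERNEL_LAUNCH_CHECK() collapses into a single FBGEMM_LAUNCH_KERNEL(kernel, grid, block, shared_mem, stream, args...) call, and each inline X.packed_accessor64<T, N, at::RestrictPtrTraits>() argument becomes a PTA_B(X, T, N, 64) macro invocation, which is why the kernel signature switches from at::PackedTensorAccessor64 to pta::PackedTensorAccessor64. Below is a minimal self-contained sketch of the launch-and-check pattern; LAUNCH_KERNEL_CHECKED and scale_kernel are hypothetical stand-ins for illustration, not the actual FBGEMM machinery defined in fbgemm_gpu/utils/kernel_launcher.cuh.

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for FBGEMM_LAUNCH_KERNEL: launch with the given
// configuration, then surface any launch error immediately, so a call site
// cannot forget the post-launch check.
#define LAUNCH_KERNEL_CHECKED(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)      \
  do {                                                                     \
    KERNEL<<<(GRID), (BLOCK), (SMEM), (STREAM)>>>(__VA_ARGS__);            \
    cudaError_t err_ = cudaGetLastError();                                 \
    if (err_ != cudaSuccess) {                                             \
      std::fprintf(                                                        \
          stderr, "kernel launch failed: %s\n", cudaGetErrorString(err_)); \
      std::abort();                                                        \
    }                                                                      \
  } while (0)

// Toy kernel standing in for per_tensor_quantize_i8_kernel.
__global__ void scale_kernel(const float* x, float* y, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = x[i] * s;
  }
}

void launch_scale(
    const float* x, float* y, float s, int n, cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int blocks = (n + kThreadsPerBlock - 1) / kThreadsPerBlock;
  // Before: scale_kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(x, y, s, n);
  //         followed by a separate C10_CUDA_KERNEL_LAUNCH_CHECK();
  // After:  one statement that both launches and checks.
  LAUNCH_KERNEL_CHECKED(
      scale_kernel, blocks, kThreadsPerBlock, 0, stream, x, y, s, n);
}

Beyond removing the easy-to-forget error check, a launcher macro of this shape can also capture the kernel name and source location for diagnostics; and since the PTA_B arguments carry the tensor, its dtype, dimensionality, and index width, accessor dtype/bounds checking is presumably part of what the real macro adds on top of this sketch.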

Comments (0)