Skip to content

Commit 60fc073

Browse files
q10 authored and facebook-github-bot committed
Migrate GenAI quantize kernels to FBGEMM_LAUNCH_KERNEL, pt 2 (#4849)
Summary: Pull Request resolved: #4849 - Migrate GenAI quantize kernels to `FBGEMM_LAUNCH_KERNEL`, pt 2 Reviewed By: cthi Differential Revision: D81538270 fbshipit-source-id: dfb76d833b55a908dd4bfbe2222b0fdd88f4ea61
1 parent 517b712 commit 60fc073

File tree

1 file changed

+56
-23
lines changed

1 file changed

+56
-23
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 56 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -357,16 +357,16 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale) {
357357
constexpr int32_t kThreadsPerBlock = 1024;
358358
dim3 threads = std::min<int32_t>(kThreadsPerBlock, X1.size(1) / 8);
359359
dim3 blocks = X1.size(0);
360-
silu_mul_quantize_i8_kernel<<<
360+
FBGEMM_LAUNCH_KERNEL(
361+
(silu_mul_quantize_i8_kernel),
361362
blocks,
362363
threads,
363364
0,
364-
at::cuda::getCurrentCUDAStream()>>>(
365+
at::cuda::getCurrentCUDAStream(),
365366
X1.packed_accessor64<at::BFloat16, 2, at::RestrictPtrTraits>(),
366367
X2.packed_accessor64<at::BFloat16, 2, at::RestrictPtrTraits>(),
367368
Y.packed_accessor64<int8_t, 2, at::RestrictPtrTraits>(),
368369
inv_scale);
369-
C10_CUDA_KERNEL_LAUNCH_CHECK();
370370
return Y;
371371
}
372372

@@ -436,7 +436,7 @@ DEVICE_INLINE float stochastic_rounding_scalar_fp8(
436436
}
437437

438438
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
439-
__global__ void scaleMatrix(
439+
__global__ void scaleMatrix1(
440440
T_OUT* const output,
441441
T_S const* const input_scale,
442442
T_IN const* const input,
@@ -450,7 +450,7 @@ __global__ void scaleMatrix(
450450
}
451451

452452
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
453-
__global__ void scaleMatrix(
453+
__global__ void scaleMatrix2(
454454
T_OUT* const output,
455455
T_S const* const input_scale,
456456
T_IN const* const input,
@@ -485,7 +485,7 @@ __global__ void scaleMatrix(
485485
}
486486

487487
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
488-
__global__ void scaleMatrixRowwise(
488+
__global__ void scaleMatrixRowwise2(
489489
T_OUT* const output,
490490
T_S const* const input_scale,
491491
T_IN const* const input,
@@ -526,7 +526,7 @@ __global__ void scaleMatrixRowwise(
526526
}
527527

528528
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
529-
__global__ void scaleMatrixRowwise(
529+
__global__ void scaleMatrixRowwise1(
530530
T_OUT* const output,
531531
T_S const* const input_scale,
532532
T_IN const* const input,
@@ -563,7 +563,7 @@ void invokeQuantizeMatrix(
563563
const int64_t numel,
564564
const int64_t lda,
565565
bool stochastic_rounding,
566-
const cudaStream_t stream) {
566+
const c10::cuda::CUDAStream stream) {
567567
constexpr dim3 grid(1024);
568568
const dim3 block(CTA_SIZE);
569569
if (stochastic_rounding) {
@@ -572,13 +572,30 @@ void invokeQuantizeMatrix(
572572
std::lock_guard<std::mutex> lock(gen.mutex());
573573
rng_engine_inputs =
574574
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(4);
575-
scaleMatrix<true><<<grid, block, 0, stream>>>(
576-
output, input_scale, input, numel, lda, rng_engine_inputs);
577-
C10_CUDA_KERNEL_LAUNCH_CHECK();
575+
FBGEMM_LAUNCH_KERNEL(
576+
(scaleMatrix2<true, T_OUT, T_S, T_IN>),
577+
grid,
578+
block,
579+
0,
580+
stream,
581+
output,
582+
input_scale,
583+
input,
584+
numel,
585+
lda,
586+
rng_engine_inputs);
578587
} else {
579-
scaleMatrix<true>
580-
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
581-
C10_CUDA_KERNEL_LAUNCH_CHECK();
588+
FBGEMM_LAUNCH_KERNEL(
589+
(scaleMatrix1<true, T_OUT, T_S, T_IN>),
590+
grid,
591+
block,
592+
0,
593+
stream,
594+
output,
595+
input_scale,
596+
input,
597+
numel,
598+
lda);
582599
}
583600
}
584601

@@ -590,7 +607,7 @@ void invokeQuantizeMatrixRowwise(
590607
const int64_t numel,
591608
const int64_t lda,
592609
bool stochastic_rounding,
593-
const cudaStream_t stream) {
610+
const c10::cuda::CUDAStream stream) {
594611
constexpr dim3 grid(1024);
595612
const dim3 block(CTA_SIZE);
596613

@@ -601,14 +618,30 @@ void invokeQuantizeMatrixRowwise(
601618
rng_engine_inputs =
602619
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(4);
603620

604-
scaleMatrixRowwise<true><<<grid, block, 0, stream>>>(
605-
output, input_scale, input, numel, lda, rng_engine_inputs);
606-
C10_CUDA_KERNEL_LAUNCH_CHECK();
607-
621+
FBGEMM_LAUNCH_KERNEL(
622+
(scaleMatrixRowwise2<true, T_OUT, T_S, T_IN>),
623+
grid,
624+
block,
625+
0,
626+
stream,
627+
output,
628+
input_scale,
629+
input,
630+
numel,
631+
lda,
632+
rng_engine_inputs);
608633
} else {
609-
scaleMatrixRowwise<true>
610-
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
611-
C10_CUDA_KERNEL_LAUNCH_CHECK();
634+
FBGEMM_LAUNCH_KERNEL(
635+
(scaleMatrixRowwise1<true, T_OUT, T_S, T_IN>),
636+
grid,
637+
block,
638+
0,
639+
stream,
640+
output,
641+
input_scale,
642+
input,
643+
numel,
644+
lda);
612645
}
613646
}
614647

@@ -1075,7 +1108,7 @@ void invokeComputeScalesAndQuantizeMatrix(
10751108
const int64_t lda,
10761109
const float* scale_ub,
10771110
bool stochastic_rounding,
1078-
cudaStream_t stream) {
1111+
const c10::cuda::CUDAStream stream) {
10791112
dim3 grid(numel / lda);
10801113
#ifdef USE_ROCM
10811114
bool use_shmem = true;

0 commit comments

Comments (0)