@@ -357,16 +357,16 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale) {
357
357
constexpr int32_t kThreadsPerBlock = 1024 ;
358
358
dim3 threads = std::min<int32_t >(kThreadsPerBlock , X1.size (1 ) / 8 );
359
359
dim3 blocks = X1.size (0 );
360
- silu_mul_quantize_i8_kernel<<<
360
+ FBGEMM_LAUNCH_KERNEL (
361
+ (silu_mul_quantize_i8_kernel),
361
362
blocks,
362
363
threads,
363
364
0 ,
364
- at::cuda::getCurrentCUDAStream ()>>>(
365
+ at::cuda::getCurrentCUDAStream (),
365
366
X1.packed_accessor64 <at::BFloat16, 2 , at::RestrictPtrTraits>(),
366
367
X2.packed_accessor64 <at::BFloat16, 2 , at::RestrictPtrTraits>(),
367
368
Y.packed_accessor64 <int8_t , 2 , at::RestrictPtrTraits>(),
368
369
inv_scale);
369
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
370
370
return Y;
371
371
}
372
372
@@ -436,7 +436,7 @@ DEVICE_INLINE float stochastic_rounding_scalar_fp8(
436
436
}
437
437
438
438
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
439
- __global__ void scaleMatrix (
439
+ __global__ void scaleMatrix1 (
440
440
T_OUT* const output,
441
441
T_S const * const input_scale,
442
442
T_IN const * const input,
@@ -450,7 +450,7 @@ __global__ void scaleMatrix(
450
450
}
451
451
452
452
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
453
- __global__ void scaleMatrix (
453
+ __global__ void scaleMatrix2 (
454
454
T_OUT* const output,
455
455
T_S const * const input_scale,
456
456
T_IN const * const input,
@@ -485,7 +485,7 @@ __global__ void scaleMatrix(
485
485
}
486
486
487
487
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
488
- __global__ void scaleMatrixRowwise (
488
+ __global__ void scaleMatrixRowwise2 (
489
489
T_OUT* const output,
490
490
T_S const * const input_scale,
491
491
T_IN const * const input,
@@ -526,7 +526,7 @@ __global__ void scaleMatrixRowwise(
526
526
}
527
527
528
528
template <bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
529
- __global__ void scaleMatrixRowwise (
529
+ __global__ void scaleMatrixRowwise1 (
530
530
T_OUT* const output,
531
531
T_S const * const input_scale,
532
532
T_IN const * const input,
@@ -563,7 +563,7 @@ void invokeQuantizeMatrix(
563
563
const int64_t numel,
564
564
const int64_t lda,
565
565
bool stochastic_rounding,
566
- const cudaStream_t stream) {
566
+ const c10::cuda::CUDAStream stream) {
567
567
constexpr dim3 grid (1024 );
568
568
const dim3 block (CTA_SIZE);
569
569
if (stochastic_rounding) {
@@ -572,13 +572,30 @@ void invokeQuantizeMatrix(
572
572
std::lock_guard<std::mutex> lock (gen.mutex ());
573
573
rng_engine_inputs =
574
574
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state (4 );
575
- scaleMatrix<true ><<<grid, block, 0 , stream>>> (
576
- output, input_scale, input, numel, lda, rng_engine_inputs);
577
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
575
+ FBGEMM_LAUNCH_KERNEL (
576
+ (scaleMatrix2<true , T_OUT, T_S, T_IN>),
577
+ grid,
578
+ block,
579
+ 0 ,
580
+ stream,
581
+ output,
582
+ input_scale,
583
+ input,
584
+ numel,
585
+ lda,
586
+ rng_engine_inputs);
578
587
} else {
579
- scaleMatrix<true >
580
- <<<grid, block, 0 , stream>>> (output, input_scale, input, numel, lda);
581
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
588
+ FBGEMM_LAUNCH_KERNEL (
589
+ (scaleMatrix1<true , T_OUT, T_S, T_IN>),
590
+ grid,
591
+ block,
592
+ 0 ,
593
+ stream,
594
+ output,
595
+ input_scale,
596
+ input,
597
+ numel,
598
+ lda);
582
599
}
583
600
}
584
601
@@ -590,7 +607,7 @@ void invokeQuantizeMatrixRowwise(
590
607
const int64_t numel,
591
608
const int64_t lda,
592
609
bool stochastic_rounding,
593
- const cudaStream_t stream) {
610
+ const c10::cuda::CUDAStream stream) {
594
611
constexpr dim3 grid (1024 );
595
612
const dim3 block (CTA_SIZE);
596
613
@@ -601,14 +618,30 @@ void invokeQuantizeMatrixRowwise(
601
618
rng_engine_inputs =
602
619
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state (4 );
603
620
604
- scaleMatrixRowwise<true ><<<grid, block, 0 , stream>>> (
605
- output, input_scale, input, numel, lda, rng_engine_inputs);
606
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
607
-
621
+ FBGEMM_LAUNCH_KERNEL (
622
+ (scaleMatrixRowwise2<true , T_OUT, T_S, T_IN>),
623
+ grid,
624
+ block,
625
+ 0 ,
626
+ stream,
627
+ output,
628
+ input_scale,
629
+ input,
630
+ numel,
631
+ lda,
632
+ rng_engine_inputs);
608
633
} else {
609
- scaleMatrixRowwise<true >
610
- <<<grid, block, 0 , stream>>> (output, input_scale, input, numel, lda);
611
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
634
+ FBGEMM_LAUNCH_KERNEL (
635
+ (scaleMatrixRowwise1<true , T_OUT, T_S, T_IN>),
636
+ grid,
637
+ block,
638
+ 0 ,
639
+ stream,
640
+ output,
641
+ input_scale,
642
+ input,
643
+ numel,
644
+ lda);
612
645
}
613
646
}
614
647
@@ -1075,7 +1108,7 @@ void invokeComputeScalesAndQuantizeMatrix(
1075
1108
const int64_t lda,
1076
1109
const float * scale_ub,
1077
1110
bool stochastic_rounding,
1078
- cudaStream_t stream) {
1111
+ const c10::cuda::CUDAStream stream) {
1079
1112
dim3 grid (numel / lda);
1080
1113
#ifdef USE_ROCM
1081
1114
bool use_shmem = true ;
0 commit comments