diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu index 72b36adefe..1baaac152c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu @@ -1214,9 +1214,16 @@ void invokeComputeScalesAndQuantizeMatrixCol( dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream)); C10_CUDA_KERNEL_LAUNCH_CHECK(); - computeFP8QuantizeScaleColwise<<>>( - quant_ptr, input, numel, lda); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + FBGEMM_LAUNCH_KERNEL( + (computeFP8QuantizeScaleColwise), + grid, + block, + 0, + stream, + quant_ptr, + input, + numel, + lda); invokeQuantizeMatrixColwise(output, quant_ptr, input, numel, lda, stream); } @@ -1639,7 +1646,12 @@ void invokeFP4Quantization( // Launch the cvt kernel. if (useUE8M0) { - cvt_fp16_to_fp4<<>>( + FBGEMM_LAUNCH_KERNEL( + (cvt_fp16_to_fp4), + grid, + block, + 0, + stream, m, n, input, @@ -1647,7 +1659,12 @@ void invokeFP4Quantization( reinterpret_cast(output), reinterpret_cast(SFOuput)); } else { - cvt_fp16_to_fp4<<>>( + FBGEMM_LAUNCH_KERNEL( + (cvt_fp16_to_fp4), + grid, + block, + 0, + stream, m, n, input, @@ -1924,10 +1941,17 @@ void fp4_fused_amax_quantize( const dim3 block(blocksize, blocks_per_cta); const int blocks = ceil_div(numel, blocksize * blocks_per_cta); - compute_amax_and_quantize_kernel<__nv_bfloat16, 16, 4> - <<>>(x, y, numel, blocksize, global_amax_ptr); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); + FBGEMM_LAUNCH_KERNEL( + (compute_amax_and_quantize_kernel<__nv_bfloat16, 16, 4>), + blocks, + block, + 0, + stream, + x, + y, + numel, + blocksize, + global_amax_ptr); } template @@ -1974,7 +1998,12 @@ void invokeComputeFP4GlobalAmax( constexpr dim3 grid(1024); int64_t numel_scale = numel; C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream)); - computeFP4GlobalAmax<<>>( + FBGEMM_LAUNCH_KERNEL( + (computeFP4GlobalAmax), + grid, + block, + 0, + stream, quant_ptr, input, numel_scale, @@ -1982,7 +2011,6 @@ void invokeComputeFP4GlobalAmax( total_elements_per_slice, bs, scale_ub); - C10_CUDA_KERNEL_LAUNCH_CHECK(); } std::vector fake_quantize_nvfp4_per_tensor(