
Commit 3b0269b

q10 authored and facebook-github-bot committed
Migrate GenAI quantize kernels to FBGEMM_LAUNCH_KERNEL, pt 4
Summary:
X-link: facebookresearch/FBGEMM#1885

- Migrate GenAI quantize kernels to `FBGEMM_LAUNCH_KERNEL`, pt 4

Reviewed By: cthi

Differential Revision: D81540820

fbshipit-source-id: e18a894c626d7b030bfc7c9da62ec5be24f158ab
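The migration replaces raw triple-chevron launches followed by a manual C10_CUDA_KERNEL_LAUNCH_CHECK() with a single FBGEMM_LAUNCH_KERNEL call that takes the kernel, the launch configuration (grid, block, shared-memory bytes, stream), and then the kernel arguments. The macro's actual definition is not part of this diff; the sketch below only illustrates the shape of such a wrapper, inferred from the call sites:

#include <c10/cuda/CUDAException.h>

// Illustrative sketch only -- FBGEMM's real macro likely also records the
// source location and performs extra validation around the launch.
#define LAUNCH_KERNEL_SKETCH(kernel, grid, block, smem, stream, ...)   \
  do {                                                                 \
    kernel<<<(grid), (block), (smem), (stream)>>>(__VA_ARGS__);        \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                    \
  } while (0)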
1 parent 270be2e commit 3b0269b

File tree

1 file changed: +39 additions, -11 deletions

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 39 additions & 11 deletions
@@ -1214,9 +1214,16 @@ void invokeComputeScalesAndQuantizeMatrixCol(
   dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
   C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream));
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-  computeFP8QuantizeScaleColwise<<<grid, block, 0, stream>>>(
-      quant_ptr, input, numel, lda);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  FBGEMM_LAUNCH_KERNEL(
+      (computeFP8QuantizeScaleColwise<T_S, T_IN>),
+      grid,
+      block,
+      0,
+      stream,
+      quant_ptr,
+      input,
+      numel,
+      lda);
   invokeQuantizeMatrixColwise(output, quant_ptr, input, numel, lda, stream);
 }
 
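Note the extra parentheses around computeFP8QuantizeScaleColwise<T_S, T_IN>: the preprocessor splits macro arguments on top-level commas, so without them the comma in the template argument list would be read as a macro argument separator. A parenthesized template-id stays a single argument, and (kernel)<<<...>>>(...) is still a valid launch expression. A minimal demonstration (TAKES_ONE and k are illustrative, not FBGEMM code):

#define TAKES_ONE(f) // expects exactly one macro argument

template <typename A, typename B>
__global__ void k(A* a, B* b) {}

// TAKES_ONE(k<float, int>);   // error: macro "TAKES_ONE" passed 2 arguments
TAKES_ONE((k<float, int>));    // OK: the parentheses group the comma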

@@ -1639,15 +1646,25 @@ void invokeFP4Quantization(
 
   // Launch the cvt kernel.
   if (useUE8M0) {
-    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (cvt_fp16_to_fp4<T, true>),
+        grid,
+        block,
+        0,
+        stream,
         m,
         n,
         input,
         SFScale,
         reinterpret_cast<uint32_t*>(output),
         reinterpret_cast<uint32_t*>(SFOuput));
   } else {
-    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
+    FBGEMM_LAUNCH_KERNEL(
+        (cvt_fp16_to_fp4<T, false>),
+        grid,
+        block,
+        0,
+        stream,
         m,
         n,
         input,
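The useUE8M0 branch above selects between two instantiations of the same kernel, with the scale-factor format baked in as a non-type template parameter, so the format check costs nothing per element inside the kernel. A self-contained sketch of that dispatch pattern (the names and trivial kernel body are illustrative, not FBGEMM's):

#include <cuda_runtime.h>

template <typename T, bool kUseUE8M0>
__global__ void cvt_sketch(const T* in, T* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if constexpr (kUseUE8M0) {
    out[i] = in[i];  // stand-in for the UE8M0 (power-of-two) scale path
  } else {
    out[i] = in[i];  // stand-in for the default FP8 (E4M3) scale path
  }
}

template <typename T>
void launch_cvt_sketch(bool useUE8M0, const T* in, T* out, int n,
                       cudaStream_t stream) {
  const dim3 block(256);
  const dim3 grid((n + 255) / 256);
  // The runtime flag picks the compile-time instantiation.
  if (useUE8M0) {
    cvt_sketch<T, true><<<grid, block, 0, stream>>>(in, out, n);
  } else {
    cvt_sketch<T, false><<<grid, block, 0, stream>>>(in, out, n);
  }
}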
@@ -1924,10 +1941,17 @@ void fp4_fused_amax_quantize(
   const dim3 block(blocksize, blocks_per_cta);
   const int blocks = ceil_div(numel, blocksize * blocks_per_cta);
 
-  compute_amax_and_quantize_kernel<__nv_bfloat16, 16, 4>
-      <<<blocks, block, 0, stream>>>(x, y, numel, blocksize, global_amax_ptr);
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  FBGEMM_LAUNCH_KERNEL(
+      (compute_amax_and_quantize_kernel<__nv_bfloat16, 16, 4>),
+      blocks,
+      block,
+      0,
+      stream,
+      x,
+      y,
+      numel,
+      blocksize,
+      global_amax_ptr);
 }
 
 template <typename T_S, typename T_W>
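The launch configuration in this hunk sizes a 2D block of blocksize x blocks_per_cta threads and rounds the grid up with ceil_div so the tail elements are covered. A worked example with illustrative values:

#include <cstdint>

constexpr int ceil_div(int64_t a, int64_t b) {
  return static_cast<int>((a + b - 1) / b);
}

// With numel = 1'000'000, blocksize = 256, blocks_per_cta = 4:
//   elements per CTA = 256 * 4 = 1024
//   blocks = ceil_div(1'000'000, 1024) = 977  (976 full CTAs + one tail CTA)
static_assert(ceil_div(1'000'000, 256 * 4) == 977, "grid must cover the tail");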
@@ -1974,15 +1998,19 @@ void invokeComputeFP4GlobalAmax(
   constexpr dim3 grid(1024);
   int64_t numel_scale = numel;
   C10_CUDA_CHECK(cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream));
-  computeFP4GlobalAmax<<<grid, block, 0, stream>>>(
+  FBGEMM_LAUNCH_KERNEL(
+      (computeFP4GlobalAmax<T_S, T_IN>),
+      grid,
+      block,
+      0,
+      stream,
       quant_ptr,
       input,
       numel_scale,
       lda,
       total_elements_per_slice,
       bs,
       scale_ub);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 std::vector<at::Tensor> fake_quantize_nvfp4_per_tensor(
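Both this hunk and the fp4_fused_amax_quantize hunk drop a standalone C10_CUDA_KERNEL_LAUNCH_CHECK(), consistent with FBGEMM_LAUNCH_KERNEL performing the launch-error check itself. In plain CUDA terms, that check boils down to the following (a sketch; the c10 macro throws a C++ exception with source location rather than printing):

#include <cstdio>
#include <cuda_runtime.h>

// cudaGetLastError() returns and clears the sticky error from the most
// recent runtime call, catching invalid launch configurations immediately.
inline void check_last_launch(const char* file, int line) {
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA launch failed at %s:%d: %s\n",
                 file, line, cudaGetErrorString(err));
  }
}

#define CHECK_LAST_LAUNCH_SKETCH() check_last_launch(__FILE__, __LINE__)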
