@@ -1167,10 +1167,15 @@ at::Tensor nope_qkv_varseq_prefill(
1167
1167
is_precalculated_qparam =
1168
1168
static_cast <bool *>(kv_quant_scale_precomputed.value ().data_ptr ());
1169
1169
}
1170
- rope_xpos_qkv_varseq_prefill_kernel_fp8<
1171
- PositionEmbeddingMode::NOPE,
1172
- CacheLogicalDtype::FP8,
1173
- 1 ><<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1170
+ FBGEMM_LAUNCH_KERNEL (
1171
+ (rope_xpos_qkv_varseq_prefill_kernel_fp8<
1172
+ PositionEmbeddingMode::NOPE,
1173
+ CacheLogicalDtype::FP8,
1174
+ 1 >),
1175
+ blocks,
1176
+ threads,
1177
+ 0 ,
1178
+ at::cuda::getCurrentCUDAStream (),
1174
1179
XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1175
1180
XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1176
1181
XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
@@ -1205,7 +1210,7 @@ at::Tensor nope_qkv_varseq_prefill(
1205
1210
k_norm,
1206
1211
amax_ptr,
1207
1212
is_precalculated_qparam);
1208
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1213
+
1209
1214
} else {
1210
1215
CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1211
1216
1 ,
@@ -1367,10 +1372,15 @@ at::Tensor nope_qkv_decoding(
1367
1372
if (amax_qkv.has_value ()) {
1368
1373
amax_ptr = static_cast <float *>(amax_qkv.value ().data_ptr ());
1369
1374
}
1370
- rope_xpos_qkv_varseq_prefill_kernel_fp8<
1371
- PositionEmbeddingMode::NOPE,
1372
- CacheLogicalDtype::FP8,
1373
- 1 ><<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1375
+ FBGEMM_LAUNCH_KERNEL (
1376
+ (rope_xpos_qkv_varseq_prefill_kernel_fp8<
1377
+ PositionEmbeddingMode::NOPE,
1378
+ CacheLogicalDtype::FP8,
1379
+ 1 >),
1380
+ blocks,
1381
+ threads,
1382
+ 0 ,
1383
+ at::cuda::getCurrentCUDAStream (),
1374
1384
XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1375
1385
XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1376
1386
XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
@@ -1406,8 +1416,6 @@ at::Tensor nope_qkv_decoding(
1406
1416
amax_ptr,
1407
1417
nullptr );
1408
1418
1409
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1410
-
1411
1419
} else {
1412
1420
CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1413
1421
1 ,
@@ -1596,10 +1604,15 @@ at::Tensor rope_qkv_varseq_prefill(
1596
1604
is_precalculated_qparam =
1597
1605
static_cast <bool *>(kv_quant_scale_precomputed.value ().data_ptr ());
1598
1606
}
1599
- rope_xpos_qkv_varseq_prefill_kernel_fp8<
1600
- PositionEmbeddingMode::ROPE,
1601
- CacheLogicalDtype::FP8,
1602
- 1 ><<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1607
+ FBGEMM_LAUNCH_KERNEL (
1608
+ (rope_xpos_qkv_varseq_prefill_kernel_fp8<
1609
+ PositionEmbeddingMode::ROPE,
1610
+ CacheLogicalDtype::FP8,
1611
+ 1 >),
1612
+ blocks,
1613
+ threads,
1614
+ 0 ,
1615
+ at::cuda::getCurrentCUDAStream (),
1603
1616
XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1604
1617
XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1605
1618
XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
@@ -1634,7 +1647,6 @@ at::Tensor rope_qkv_varseq_prefill(
1634
1647
k_norm,
1635
1648
amax_ptr,
1636
1649
is_precalculated_qparam);
1637
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1638
1650
1639
1651
} else {
1640
1652
CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
@@ -1970,10 +1982,15 @@ at::Tensor rope_qkv_decoding(
1970
1982
if (amax_qkv.has_value ()) {
1971
1983
amax_ptr = static_cast <float *>(amax_qkv.value ().data_ptr ());
1972
1984
}
1973
- rope_xpos_qkv_varseq_prefill_kernel_fp8<
1974
- PositionEmbeddingMode::ROPE,
1975
- CacheLogicalDtype::FP8,
1976
- 1 ><<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
1985
+ FBGEMM_LAUNCH_KERNEL (
1986
+ (rope_xpos_qkv_varseq_prefill_kernel_fp8<
1987
+ PositionEmbeddingMode::ROPE,
1988
+ CacheLogicalDtype::FP8,
1989
+ 1 >),
1990
+ blocks,
1991
+ threads,
1992
+ 0 ,
1993
+ at::cuda::getCurrentCUDAStream (),
1977
1994
XQ.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1978
1995
XK.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
1979
1996
XV.packed_accessor32 <at::BFloat16, 3 , at::RestrictPtrTraits>(),
@@ -2009,7 +2026,6 @@ at::Tensor rope_qkv_decoding(
2009
2026
amax_ptr,
2010
2027
nullptr );
2011
2028
2012
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
2013
2029
} else {
2014
2030
CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
2015
2031
1 ,
0 commit comments