
Commit 4c1ac5f

strgrb and Zhang Kaihong authored
Support cuda<12.8 built for trtllm_allreduce_fusion. (#1508)
## 📌 Description

I want to use trtllm_allreduce_fusion with CUDA < 12.8 on Hopper GPUs in sglang, so this PR wraps the FP4 code in a CUDA version check so that it still compiles.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: Zhang Kaihong <[email protected]>
1 parent 6fb5105 commit 4c1ac5f
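The change is essentially a compile-time gate: `<cuda.h>` is included so that `CUDA_VERSION` is defined, and every FP4-specific piece (the `<cuda_fp4.h>` include, the e2m1 conversion helpers, and the FP4 quantization branches) is wrapped in `#if CUDA_VERSION >= 120800`, the threshold the diff uses for the 12.8 cutoff. A minimal sketch of the pattern, assuming a standalone translation unit compiled with nvcc (not the actual FlashInfer headers):

```cpp
// Minimal sketch of the version guard this PR adds (illustrative; mirrors the diff, not the exact files).
#include <cuda.h>       // defines CUDA_VERSION
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#if CUDA_VERSION >= 120800
#include <cuda_fp4.h>   // FP4 (e2m1) header that pre-12.8 toolkits do not ship
#endif

#include <cstdio>

int main() {
#if CUDA_VERSION >= 120800
  std::printf("CUDA_VERSION=%d: FP4 path compiled in\n", CUDA_VERSION);
#else
  std::printf("CUDA_VERSION=%d: FP4 path compiled out; FP16/BF16/FP8 paths remain\n", CUDA_VERSION);
#endif
  return 0;
}
```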

File tree

include/flashinfer/comm/trtllm_allreduce_fusion.cuh
include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh

2 files changed: +27 -3 lines


include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Lines changed: 13 additions & 3 deletions
```diff
@@ -1,7 +1,11 @@
 #include <cooperative_groups.h>
+#include <cuda.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+
+#if CUDA_VERSION >= 120800
 #include <cuda_fp4.h>
+#endif
 
 #include <cuda/std/optional>
 #include <tuple>
@@ -532,6 +536,7 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c
   return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
 }
 
+#if CUDA_VERSION >= 120800
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 // NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
@@ -672,6 +677,8 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
 #endif
 }
 
+#endif
+
 }  // namespace utils
 
 template <typename T, uint32_t VEC_SIZE>
@@ -943,14 +950,17 @@ class FusedOp {
       }
     }
 
+#if CUDA_VERSION >= 120800
     if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
      // NOTE(Yingyi): might update later
      auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2>(
          std::nullopt /* batchIdx */, token_id, m_access_id_in_token, std::nullopt /* numRows */,
          m_params.hidden_dim, reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
      reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id] =
          utils::cvt_warp_fp16_to_fp4<T, VEC_SIZE>(val, m_scale_factor, sf_out);
-    } else if constexpr (GetQuantType<Pattern> == QuantType::kFP8) {
+    } else
+#endif
+    if constexpr (GetQuantType<Pattern> == QuantType::kFP8) {
      using PackedQuantizedType = std::conditional_t<std::is_same_v<T, float>, float, float2>;
      PackedQuantizedType ret;
 #pragma unroll
@@ -1431,7 +1441,7 @@ cudaError_t allreduce_fusion_op(AllReduceFusionParams<T> const& params, bool lau
       DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP8Quant, NRanks); \
       break; \
     case AllReduceFusionPattern::kARResidualRMSNormFP4Quant: \
-      if constexpr (!std::is_same_v<T, float>) { \
+      if constexpr (!std::is_same_v<T, float> && CUDA_VERSION >= 120800) { \
        DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP4Quant, NRanks); \
      } else { \
        FLASHINFER_CHECK(false, "FP4Quant pattern cannot work with DType=float!"); \
@@ -1441,7 +1451,7 @@ cudaError_t allreduce_fusion_op(AllReduceFusionParams<T> const& params, bool lau
       DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, NRanks); \
       break; \
     case AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant: \
-      if constexpr (!std::is_same_v<T, float>) { \
+      if constexpr (!std::is_same_v<T, float> && CUDA_VERSION >= 120800) { \
        DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, NRanks); \
      } else { \
        FLASHINFER_CHECK(false, "OutFP4Quant pattern cannot work with DType=float!"); \
```
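One detail worth noting in `FusedOp`: the added `} else` / `#endif` splice keeps the FP8 branch well-formed in both configurations; with the guard satisfied it reads as an `else if constexpr` chain, and with the guard off it collapses to a standalone `if constexpr`. A toy sketch of that expansion, using hypothetical names (`MY_CUDA_VERSION`, `quantize`) rather than the FlashInfer code:

```cpp
// Toy illustration of the preprocessor/else-if splice used in FusedOp (hypothetical names).
#include <cstdio>

#define MY_CUDA_VERSION 12080  // stand-in for CUDA_VERSION from <cuda.h>

template <int kQuantBits>  // 4 = "FP4", 8 = "FP8"
void quantize() {
#if MY_CUDA_VERSION >= 120800
  if constexpr (kQuantBits == 4) {
    std::puts("FP4 branch (only exists when the toolkit guard is satisfied)");
  } else
#endif
  if constexpr (kQuantBits == 8) {
    std::puts("FP8 branch (compiled in either configuration)");
  }
}

int main() {
  quantize<8>();  // with the guard off, this is just `if constexpr (kQuantBits == 8)`
  quantize<4>();  // with the guard off, the FP4 branch simply does not exist
  return 0;
}
```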

include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh

Lines changed: 14 additions & 0 deletions
```diff
@@ -1,7 +1,11 @@
 #include <cooperative_groups.h>
+#include <cuda.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+
+#if CUDA_VERSION >= 120800
 #include <cuda_fp4.h>
+#endif
 
 #include <cuda/std/optional>
 #include <tuple>
@@ -519,6 +523,7 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c
   return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
 }
 
+#if CUDA_VERSION >= 120800
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 // NOTE:bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
@@ -658,6 +663,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
   return 0;
 #endif
 }
+#endif
 }  // namespace utils
 
 template <typename T>
@@ -828,6 +834,7 @@ __device__ __forceinline__ void fused_op(vec_t<T, VEC_SIZE> const& val, int acce
   if constexpr (NormOut) {
     norm_val.store(reinterpret_cast<T*>(params.norm_out) + access_id * VEC_SIZE);
   }
+#if CUDA_VERSION >= 120800
   if constexpr (QuantOut) {
     constexpr int SF_VEC_SIZE = 16;
     auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2>(
@@ -836,6 +843,7 @@ __device__ __forceinline__ void fused_op(vec_t<T, VEC_SIZE> const& val, int acce
     reinterpret_cast<uint32_t*>(params.quant_out)[access_id] =
        utils::cvt_warp_fp16_to_fp4<T, VEC_SIZE>(norm_val, params.scale_factor, sf_out);
   }
+#endif
 }
 
 template <typename T>
@@ -1486,6 +1494,12 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams<T>
   auto status = DISPATCH_MOEFINALIZEREDUCTION(
       params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS,
       QUANT, [&]() -> cudaError_t {
+        if constexpr (CUDA_VERSION < 120800 && QUANT) {
+          FLASHINFER_CHECK(false,
+                           "cuda version should be greater equal than 12.8 with "
+                           "trtllm_moe_allreduce_fusion quant");
+          return cudaErrorNotSupported;
+        }
         FLASHINFER_CUDA_CALL(
             (moefinalize_allreduce_fusion_kernel_launcher<T, N_RANKS, RES, RMS, QUANT>(
                 (params), (launch_with_pdl))));
```
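In the MoE path the unsupported combination is rejected inside the dispatch lambda instead of being compiled out: when quantized output is requested below the version threshold, the branch fires `FLASHINFER_CHECK(false, ...)` and returns `cudaErrorNotSupported` before any launch. A reduced sketch of that shape, using a hypothetical `launch<kQuant>` helper and the same `CUDA_VERSION` comparison the diff uses (not the FlashInfer dispatch macros):

```cpp
// Reduced sketch of rejecting an unsupported template branch at dispatch time (hypothetical names).
#include <cuda.h>          // provides CUDA_VERSION
#include <cuda_runtime.h>  // cudaError_t, cudaGetErrorString
#include <cstdio>

template <bool kQuant>
cudaError_t launch() {
  if constexpr (CUDA_VERSION < 120800 && kQuant) {
    // FP4 quantized output is gated on the newer toolkit; bail out before launching anything.
    return cudaErrorNotSupported;
  }
  // ... the real kernel launch would go here ...
  return cudaSuccess;
}

int main() {
  std::printf("quant path:    %s\n", cudaGetErrorString(launch<true>()));
  std::printf("no-quant path: %s\n", cudaGetErrorString(launch<false>()));
  return 0;
}
```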
