
Commit fe29ed6

bugfix: guard fp8 e8m0 and e2m1 compile (#1287)
## πŸ“Œ Description

Fixes #1282

## πŸ” Related Issues

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent de55a8f commit fe29ed6
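The whole fix is one compile-time guard applied in five places: code that touches `__nv_fp8_e8m0` or `__nv_fp4_e2m1` is only compiled when the nvcc version is at least 12.8, since those types first appear in the CUDA 12.8 toolkit headers. The check packs major and minor into a single integer, so for nvcc 12.8 it evaluates to 12 * 10000 + 8 * 100 = 120800, exactly the threshold. Below is a minimal sketch of the pattern, not a verbatim excerpt from the diff; the `cuda_fp8.h` include and the `encode_scale` helper are illustrative assumptions.

```cpp
#include <cstdint>

// Sketch of the guard used throughout this commit.
// nvcc 12.8 -> 12 * 10000 + 8 * 100 = 120800, so the check passes;
// nvcc 12.6 -> 120600, so the #error fires with a readable message instead of
// an "unknown type name '__nv_fp8_e8m0'" failure deep inside a kernel.
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#include <cuda_fp8.h>  // assumed: header that declares __nv_fp8_e8m0 on CUDA 12.8+

__device__ uint8_t encode_scale(float sf) {
  // Same conversion the guarded kernels perform: store the scale as an
  // E8M0 exponent byte, rounding the exponent up and saturating to finite.
  __nv_fp8_e8m0 tmp;
  tmp.__x = __nv_cvt_float_to_e8m0(sf, __NV_SATFINITE, cudaRoundPosInf);
  return tmp.__x;
}
#else
#error "FP8 E8M0 support requires CUDA 12.8 or newer."
#endif
```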

5 files changed: +39 -0 lines changed

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 13 additions & 0 deletions

@@ -419,12 +419,16 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
   float outputScale;
   // Write the SF to global memory (STG.8).
   if constexpr (UE8M0_SF) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
     __nv_fp8_e8m0 tmp;
     // Scale the max value to the range of E2m1.
     vecMax *= reciprocal_approximate_ftz(6.0f);
     tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf);
     fp8SFVal = tmp.__x;
     outputScale = exp2f_rcp(fp8SFVal);
+#else
+#error "FP8 E8M0 support requires CUDA 12.8 or newer."
+#endif
   } else {
     // Get the SF (max value of the vector / max value of e2m1).
     // maximum value of e2m1 = 6.0.
@@ -511,16 +515,21 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
   uint8_t fp8SFVal;
   // Write the SF to global memory (STG.8).
   if constexpr (UE8M0_SF) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
     __nv_fp8_e8m0 tmp;
     tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
     SFValue = static_cast<float>(tmp);
     fp8SFVal = tmp.__x;
+#else
+#error "FP8 E8M0 support requires CUDA 12.8 or newer."
+#endif
   } else {
     // Here SFValue is always positive, so E4M3 is the same as UE4M3.
     __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
     fp8SFVal = tmp.__x;
     SFValue = static_cast<float>(tmp);
   }
+
   // Get the output scale.
   // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
   float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f;
@@ -551,6 +560,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 }

 // Quantizes the provided PackedVec into the uint64_t output
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 template <class Type, int SF_VEC_SIZE>
 __device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type>& vec, uint8_t* SFout) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
@@ -612,6 +622,9 @@ __device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type>& vec, uint8_t* SFout)
   return 0;
 #endif
 }
+#else
+#error "FP8 E8M0 support requires CUDA 12.8 or newer."
+#endif

 inline __host__ __device__ int64_t get_sf_out_offset_128x4(std::optional<int> batchIdx, int mIdx,
                                                            int kIdx, std::optional<int> numRows,
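In both guarded branches above, the scale factor is stored as UE8M0: a single exponent byte that the kernel later turns back into a power-of-two scale (e.g. `outputScale = exp2f_rcp(fp8SFVal)` in the fp16 path). The following rough host-side model of that round trip assumes the standard E8M0 layout of 8 exponent bits with bias 127 and only approximates the `cudaRoundPosInf` / `__NV_SATFINITE` behaviour of `__nv_cvt_float_to_e8m0`; the function names are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Rough host-side model of the UE8M0 scale-factor round trip
// (assumed E8M0 layout: 8 exponent bits, no mantissa, bias 127).
uint8_t encode_e8m0_round_up(float scale) {
  // Round the exponent toward +inf (mimicking cudaRoundPosInf) so the decoded
  // scale never undershoots the block maximum, then saturate to the finite
  // range (mimicking __NV_SATFINITE).
  int biased = static_cast<int>(std::ceil(std::log2(scale))) + 127;
  return static_cast<uint8_t>(std::min(std::max(biased, 0), 254));
}

float decode_e8m0(uint8_t byte) { return std::exp2f(static_cast<int>(byte) - 127); }

int main() {
  float sf = 0.37f;                        // a per-block scale factor
  uint8_t enc = encode_e8m0_round_up(sf);  // what the kernel stores as fp8SFVal
  float decoded = decode_e8m0(enc);        // 0.5 here: next power of two >= sf
  printf("fp8SFVal=%d decoded=%f reciprocal=%f\n", enc, decoded, 1.0f / decoded);
  return 0;
}
```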

csrc/pytorch_extension_utils.h

Lines changed: 16 additions & 0 deletions

@@ -145,23 +145,39 @@ FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)

 // Should not be used together with _DISPATCH_SF_CASE_FP8_E8M0
 #ifdef FLASHINFER_ENABLE_FP4_E2M1
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 #define _DISPATCH_CASE_FP4_E2M1(c_type, ...) \
   case at::ScalarType::Byte: {               \
     using c_type = __nv_fp4_e2m1;            \
     return __VA_ARGS__();                    \
   }
 #else
+#define _DISPATCH_CASE_FP4_E2M1(c_type, ...)                                \
+  case at::ScalarType::Byte: {                                              \
+    static_assert(false, "FP4 E2M1 support requires CUDA 12.8 or newer.");  \
+    break;                                                                  \
+  }
+#endif
+#else
 #define _DISPATCH_CASE_FP4_E2M1(c_type, ...)
 #endif

 // Should not be used together with _DISPATCH_CASE_FP4_E2M1
 #ifdef FLASHINFER_ENABLE_FP8_E8M0
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 #define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...) \
   case at::ScalarType::Byte: {                  \
     using c_type = __nv_fp8_e8m0;               \
     return __VA_ARGS__();                       \
   }
 #else
+#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)                             \
+  case at::ScalarType::Byte: {                                              \
+    static_assert(false, "FP8 E8M0 support requires CUDA 12.8 or newer.");  \
+    break;                                                                  \
+  }
+#endif
+#else
 #define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)
 #endif
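For orientation, these `_DISPATCH_*` case macros are designed to be pasted into a `switch` over `at::ScalarType`, with the callable arriving through `__VA_ARGS__` so that its body is expanded after the `using c_type = ...` alias and can refer to `c_type`. The sketch below shows that consumption pattern under stated assumptions; the `DISPATCH_SF_DTYPE` wrapper, `launch_kernel`, and `sf_tensor` are hypothetical names, not the actual FlashInfer dispatcher.

```cpp
// Hypothetical wrapper macro illustrating how the case macro is consumed.
// The lambda text is expanded inside the Byte case, after the c_type alias.
#define DISPATCH_SF_DTYPE(scalar_type, c_type, ...)            \
  [&]() -> bool {                                              \
    switch (scalar_type) {                                     \
      _DISPATCH_SF_CASE_FP8_E8M0(c_type, __VA_ARGS__)          \
      default:                                                 \
        TORCH_CHECK(false, "unsupported scale-factor dtype");  \
        return false;                                          \
    }                                                          \
  }()

// Usage sketch: on CUDA >= 12.8 the Byte case binds c_type to __nv_fp8_e8m0;
// on older toolkits the fallback macro's static_assert stops the build with a
// clear message instead of a missing-type error.
// DISPATCH_SF_DTYPE(sf_tensor.scalar_type(), c_type,
//                   [&] { launch_kernel<c_type>(/*...*/); return true; });
```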

include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Lines changed: 4 additions & 0 deletions

@@ -623,10 +623,14 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
   uint8_t fp8SFVal;
   // Write the SF to global memory (STG.8).
   if constexpr (UE8M0_SF) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
     __nv_fp8_e8m0 tmp;
     tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
     SFValue = static_cast<float>(tmp);
     fp8SFVal = tmp.__x;
+#else
+#error "FP8 E8M0 support requires CUDA 12.8 or newer."
+#endif
   } else {
     // Here SFValue is always positive, so E4M3 is the same as UE4M3.
     __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);

include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh

Lines changed: 4 additions & 0 deletions

@@ -610,10 +610,14 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
   uint8_t fp8SFVal;
   // Write the SF to global memory (STG.8).
   if constexpr (UE8M0_SF) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
     __nv_fp8_e8m0 tmp;
     tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
     SFValue = static_cast<float>(tmp);
     fp8SFVal = tmp.__x;
+#else
+#error "FP8 E8M0 support requires CUDA 12.8 or newer."
+#endif
   } else {
     // Here SFValue is always positive, so E4M3 is the same as UE4M3.
     __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);

include/flashinfer/cutlass_utils.cuh

Lines changed: 2 additions & 0 deletions

@@ -71,6 +71,7 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
   using type = cutlass::float_e5m2_t;
 };

+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 template <>
 struct cutlass_dtype<__nv_fp8_e8m0> {
   using type = cutlass::float_ue8m0_t;
@@ -82,6 +83,7 @@ struct cutlass_dtype<__nv_fp4_e2m1> {
   using type = cutlass::float_e2m1_t;
 };
 #endif
+#endif

 template <typename T>
 using cutlass_dtype_t = typename cutlass_dtype<T>::type;
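The `cutlass_dtype` trait maps CUDA scalar types to their CUTLASS equivalents, so with this guard the `__nv_fp8_e8m0` and `__nv_fp4_e2m1` specializations are only visible on CUDA 12.8+. A minimal usage sketch follows; it assumes the trait lives in the `flashinfer` namespace (as the header path suggests) and that the include path and alias names are as shown, which are illustrative.

```cpp
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#include "flashinfer/cutlass_utils.cuh"  // assumed include path for the trait

// When the guard passes, the trait resolves the MX element/scale types:
//   cutlass_dtype_t<__nv_fp4_e2m1> -> cutlass::float_e2m1_t   (quantized elements)
//   cutlass_dtype_t<__nv_fp8_e8m0> -> cutlass::float_ue8m0_t  (block scale factors)
using CutlassElemType = flashinfer::cutlass_dtype_t<__nv_fp4_e2m1>;
using CutlassSfType   = flashinfer::cutlass_dtype_t<__nv_fp8_e8m0>;
#endif
```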
