
Commit beebc60

Authored by aleozlx, ttyio, yongwww, sunghyunp-nvdia, and nv-yunzheq
feat: initial support for SM103, SM110, SM120, SM121 (#1608)
## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: Vincent Huang <[email protected]>
Co-authored-by: Yong Wu <[email protected]>
Co-authored-by: Sunghyun Park <[email protected]>
Co-authored-by: Yunzhe Qiu <[email protected]>
Co-authored-by: Brian Ryu <[email protected]>
Co-authored-by: Ka-Hyun Nam <[email protected]>
Co-authored-by: yzh119 <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent 2cd065b commit beebc60


59 files changed: +1719 additions, −914 deletions

.github/workflows/release_wheel_aarch64.yml

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ on:
        required: true

 env:
-  TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"
+  FLASHINFER_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

 jobs:
   build:
@@ -77,7 +77,7 @@ jobs:
           -e FLASHINFER_CI_CUDA_VERSION=${{ matrix.cuda }} \
           -e FLASHINFER_CI_TORCH_VERSION=${{ matrix.torch }} \
           -e FLASHINFER_CI_PYTHON_VERSION=${{ matrix.python }} \
-          -e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
+          -e FLASHINFER_CUDA_ARCH_LIST="$FLASHINFER_CUDA_ARCH_LIST" \
           -e MAX_JOBS=128 \
           --user $(id -u):$(id -g) \
           $BUILDER_IMAGE \

.github/workflows/release_wheel_sglang_x86_64.yml

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ on:
        required: true

 env:
-  TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"
+  FLASHINFER_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

 jobs:
   build:
@@ -59,7 +59,7 @@ jobs:
           -e FLASHINFER_CI_TORCH_VERSION=${{ matrix.torch }} \
           -e FLASHINFER_CI_PYTHON_VERSION=3.10 \
           -e FLASHINFER_HEAD_DIMS="64,128,256" \
-          -e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
+          -e FLASHINFER_CUDA_ARCH_LIST="$FLASHINFER_CUDA_ARCH_LIST" \
           -e MAX_JOBS=128 \
           --user $CI_UID:$CI_GID \
           $BUILDER_IMAGE \

.github/workflows/release_wheel_x86_64.yml

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ on:
        # required: true

 env:
-  TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"
+  FLASHINFER_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

 jobs:
   build:
@@ -82,7 +82,7 @@ jobs:
           -e FLASHINFER_CI_CUDA_VERSION=${{ matrix.cuda }} \
           -e FLASHINFER_CI_TORCH_VERSION=${{ matrix.torch }} \
           -e FLASHINFER_CI_PYTHON_VERSION=3.10 \
-          -e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
+          -e FLASHINFER_CUDA_ARCH_LIST="$FLASHINFER_CUDA_ARCH_LIST" \
           -e MAX_JOBS=128 \
           --user $CI_UID:$CI_GID \
           $BUILDER_IMAGE \
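All three release workflows rename the build-time environment variable from `TORCH_CUDA_ARCH_LIST` to `FLASHINFER_CUDA_ARCH_LIST`; the value keeps the same space-separated `major.minor` format, optionally suffixed with `a` or `+PTX`. As a hedged illustration only (this is not FlashInfer's actual build logic), a self-contained C++ sketch of parsing that format into numeric SM values:

```cpp
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

// Parse an arch list such as "7.5 8.0 8.9 9.0+PTX" into SM numbers (75, 80, 89, 90).
// Suffixes like "a" or "+PTX" are stripped. This mirrors the value format shown in
// the workflows only; FlashInfer's build may handle the variable differently.
std::vector<int> parse_arch_list(const std::string& list) {
  std::vector<int> sms;
  std::istringstream iss(list);
  std::string tok;
  while (iss >> tok) {
    std::string digits;  // keep leading digits and dot: "9.0a" -> "9.0"
    for (char c : tok) {
      if ((c >= '0' && c <= '9') || c == '.') digits.push_back(c);
      else break;
    }
    const double cc = std::stod(digits);                  // e.g. 8.9
    sms.push_back(static_cast<int>(cc * 10 + 0.5));       // e.g. 89
  }
  return sms;
}

int main() {
  for (int sm : parse_arch_list("7.5 8.0 8.9 9.0+PTX")) std::printf("%d\n", sm);
  return 0;
}
```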

README.md

Lines changed: 5 additions & 1 deletion
@@ -68,7 +68,7 @@ To pre-compile essential kernels ahead-of-time (AOT), run the following command:

 ```bash
 # Set target CUDA architectures
-export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a"
+export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a"
 # Build AOT kernels. Will produce AOT kernels in aot-ops/
 python -m flashinfer.aot
 # Build AOT wheel
@@ -124,6 +124,10 @@ Starting from FlashInfer v0.2, users can customize their own attention variants

 FlashInfer also provides C++ API and TVM bindings, please refer to [documentation](https://docs.flashinfer.ai/) for more details.

+## GPU Support
+
+FlashInfer currently provides support for NVIDIA SM architectures 80 and higher and beta support for 103, 110, 120, and 121.
+
 ## Adoption

 We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects, including but not limited to:
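The new GPU Support section in the README states that SM 80 and newer are supported, with SM 103, 110, 120, and 121 in beta. As an illustration only (not FlashInfer code), a minimal CUDA runtime sketch for checking which bucket the local device falls into:

```cpp
// check_sm.cu -- illustrative only, not part of the FlashInfer codebase.
#include <cstdio>

#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::printf("no CUDA device found\n");
    return 1;
  }
  const int sm = prop.major * 10 + prop.minor;  // e.g. 9.0 -> 90, 12.0 -> 120
  const bool beta = (sm == 103 || sm == 110 || sm == 120 || sm == 121);
  if (sm >= 80) {
    std::printf("SM %d: %s\n", sm, beta ? "beta support" : "supported");
  } else {
    std::printf("SM %d: unsupported (FlashInfer requires SM 80 or newer)\n", sm);
  }
  return 0;
}
```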

csrc/group_gemm_mxfp4_groupwise_sm100.cu

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ void CutlassGroupGemmMXFP4GroupwiseScaledSM100(at::Tensor int_workspace_buffer,
                                                int64_t k, int64_t mma_sm, int64_t tile_m,
                                                int64_t tile_n, int64_t tile_k, bool swap_ab) {
   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
-  auto stream = at::cuda::getCurrentCUDAStream();
+  auto stream = at::cuda::getCurrentCUDAStream(A.device().index());
   int num_groups = m_indptr.size(0) - 1;
   DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(
       A.scalar_type(), B.scalar_type(), SFA.scalar_type(), SFB.scalar_type(), D.scalar_type(),
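The one-line change queries the current stream for the device that owns tensor `A` instead of the ambient current device, so callers on multi-GPU systems get a stream that matches the guarded device. A hedged sketch of the same guard-plus-stream pattern in a generic PyTorch C++ extension (the function and tensor names are placeholders, not FlashInfer's API):

```cpp
// Illustrative pattern only; run_on_tensor_device is a placeholder, not a FlashInfer entry point.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_tensor_device(const at::Tensor& input) {
  // Make input's device current for the duration of this scope.
  const c10::cuda::OptionalCUDAGuard device_guard(input.device());
  // Ask for the current stream of that device index explicitly, rather than
  // whichever device happened to be current before the guard was entered.
  auto stream = at::cuda::getCurrentCUDAStream(input.device().index());
  // ... enqueue kernels on `stream` ...
  (void)stream;
}
```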

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ struct CutlassGemmConfig {

   int getTileConfigAsInt() const {
     if (sm_version == 120) return (int)tile_config_sm120;
+    if (sm_version == 110) return (int)tile_config_sm100;
     if (sm_version >= 100) return (int)tile_config_sm100;
     if (sm_version == 90) return (int)tile_config_sm90;
     if (sm_version < 90) return (int)tile_config_sm80;

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 59 additions & 0 deletions
@@ -415,6 +415,62 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
 #endif
 }

+std::vector<CutlassGemmConfig> get_candidate_configs_sm110(
+    CutlassGemmConfig::CandidateConfigTypeParam const config) {
+#ifdef FAST_BUILD
+  // Fast build disables all configs except this
+  return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
+                            MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                            ClusterShape::ClusterShape_1x1x1}};
+#else
+  std::vector<CutlassGemmConfig> candidate_configs;
+  for (int cluster_m = 1; cluster_m <= 2; cluster_m++) {
+    bool Is2SM = cluster_m == 2;
+    for (int cluster_n = 1; cluster_n <= 2; cluster_n++) {
+      std::vector base = {// M=128
+                          CutlassTileConfigSM100::CtaShape128x128x128B,
+                          CutlassTileConfigSM100::CtaShape128x256x128B};
+
+      if (Is2SM) {
+        if (cluster_n == 1) {
+          base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
+          base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
+        }
+
+        std::vector twosm = {// M=256
+                             CutlassTileConfigSM100::CtaShape256x128x128B,
+                             CutlassTileConfigSM100::CtaShape256x256x128B};
+        std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
+      } else {
+        if (cluster_n == 1) {
+          base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
+          if ((config & CutlassGemmConfig::FP8_ONLY) != 0) {
+            base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
+          }
+        }
+
+        std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
+                          CutlassTileConfigSM100::CtaShape64x128x128B,
+                          CutlassTileConfigSM100::CtaShape64x256x128B,
+                          CutlassTileConfigSM100::CtaShape128x64x128B};
+        std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
+      }
+
+      constexpr std::array cluster_shapes = {
+          std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
+          std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
+      auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
+      for (auto tile : base) {
+        CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                                 cluster};
+        candidate_configs.push_back(config);
+      }
+    }
+  }
+  return candidate_configs;
+#endif
+}
+
 std::vector<CutlassGemmConfig> get_candidate_configs_sm120(
     CutlassGemmConfig::CandidateConfigTypeParam const config) {
 #ifdef FAST_BUILD
@@ -478,6 +534,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
   if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) {
     return get_candidate_configs_sm90(config_type_param);
   }
+  if (sm == 110 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
+    return get_candidate_configs_sm110(config_type_param);
+  }
   if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
     return get_candidate_configs_sm100(config_type_param);
   }
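The new `get_candidate_configs_sm110` builds its candidate list as a cross product of cluster shapes (1x1, 1x2, 2x1, 2x2) and the CTA tile shapes allowed for each, with extra tiles added only when `cluster_n == 1`. A simplified, self-contained mirror of that enumeration (plain strings instead of the CUTLASS enums; the FP8-only 128x16 tile is noted in a comment rather than enumerated) shows the shape of the list it produces:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Simplified mirror of the SM110 heuristic loop: enumerate cluster shapes and the
// CTA tiles allowed for each one. Names are illustrative, not the CUTLASS types.
int main() {
  for (int cluster_m = 1; cluster_m <= 2; ++cluster_m) {
    const bool two_sm = (cluster_m == 2);
    for (int cluster_n = 1; cluster_n <= 2; ++cluster_n) {
      std::vector<std::string> tiles = {"128x128x128B", "128x256x128B"};
      if (two_sm) {
        if (cluster_n == 1) {
          tiles.push_back("128x64x128B");
          tiles.push_back("256x64x128B");
        }
        tiles.push_back("256x128x128B");
        tiles.push_back("256x256x128B");
      } else {
        if (cluster_n == 1) tiles.push_back("128x32x128B");  // plus 128x16 for FP8-only builds
        tiles.insert(tiles.end(), {"64x64x128B", "64x128x128B", "64x256x128B", "128x64x128B"});
      }
      for (const auto& t : tiles)
        std::printf("cluster %dx%dx1  tile %s\n", cluster_m, cluster_n, t.c_str());
    }
  }
  return 0;
}
```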

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 6 additions & 1 deletion
@@ -726,7 +726,12 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
   // We allow both tma warp specialized and SM80 configurations to coexist because for some
   // cases with small numbers of tokens SM80 is faster. We check here to see which is selected
   if (inputs.gemm_config.sm_version >= 90) {
-    TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version == sm_,
+    bool is_same_sm = inputs.gemm_config.sm_version == sm_;
+    // gemm_config.sm_version indicates the kernel pipeline, which is always 100 for 100, 103,
+    // 110 below logging helps confirming the cutlass pipeline matches the device major version
+    bool is_sm110 = inputs.gemm_config.sm_version == 100 && sm_ == 110;
+    bool is_sm103 = inputs.gemm_config.sm_version == 100 && sm_ == 103;
+    TLLM_CHECK_WITH_INFO(is_same_sm || is_sm110 || is_sm103,
                          "Using SM %d configuration for SM %d device",
                          inputs.gemm_config.sm_version, sm_);
     TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr,
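The relaxed check encodes the fact that a config whose `sm_version` is 100 names the SM100 kernel pipeline, which SM103 and SM110 devices also run. A tiny standalone restatement of that predicate (illustrative only, not a TensorRT-LLM or FlashInfer API):

```cpp
#include <cassert>

// Illustrative predicate: a gemm config whose sm_version names the SM100 kernel
// pipeline is also accepted on SM103 and SM110 devices, since they run the same
// CUTLASS pipeline; all other combinations must match exactly.
bool config_matches_device(int config_sm_version, int device_sm) {
  if (config_sm_version == device_sm) return true;  // exact match, e.g. 90 on SM90
  if (config_sm_version == 100 && (device_sm == 103 || device_sm == 110)) return true;
  return false;
}

int main() {
  assert(config_matches_device(100, 100));
  assert(config_matches_device(100, 103));
  assert(config_matches_device(100, 110));
  assert(!config_matches_device(90, 100));  // Hopper config on a Blackwell device is rejected
  return 0;
}
```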

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h

Lines changed: 20 additions & 1 deletion
@@ -327,7 +327,26 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(
     } else {
       TLLM_THROW("Unsupported SM90 configuration requested");
     }
-  } else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) {
+  } else if (gemm_config.sm_version == 110) {
+    if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<
+                      T, WeightType, EpilogueTag, FUSION>()) {
+      switch (gemm_config.tile_config_sm100) {
+        SHAPE_CASE(100, 64, 64, 128)
+        SHAPE_CASE(100, 64, 128, 128)
+        SHAPE_CASE(100, 64, 256, 128)
+
+        SHAPE_CASE(100, 128, 16, 128)
+        SHAPE_CASE(100, 128, 32, 128)
+        SHAPE_CASE(100, 128, 64, 128)
+        SHAPE_CASE(100, 128, 128, 128)
+        SHAPE_CASE(100, 128, 256, 128)
+
+        DEFAULT_CASE(100)
+      }
+    } else {
+      TLLM_THROW("Unsupported SM110 configuration requested");
+    }
+  } else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 110) {
     if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<
                       T, WeightType, EpilogueTag, FUSION>()) {
       switch (gemm_config.tile_config_sm100) {

csrc/pytorch_extension_utils.h

Lines changed: 4 additions & 18 deletions
@@ -146,40 +146,26 @@ FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)
 #endif

 // Should not be used together with _DISPATCH_SF_CASE_FP8_E8M0
-#ifdef FLASHINFER_ENABLE_FP4_E2M1
-#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
+#if defined(FLASHINFER_ENABLE_FP4_E2M1) && \
+    (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 #define _DISPATCH_CASE_FP4_E2M1(c_type, ...) \
   case at::ScalarType::Byte: {               \
     using c_type = __nv_fp4_e2m1;            \
     return __VA_ARGS__();                    \
   }
 #else
-#define _DISPATCH_CASE_FP4_E2M1(c_type, ...)                               \
-  case at::ScalarType::Byte: {                                             \
-    static_assert(false, "FP4 E2M1 support requires CUDA 12.8 or newer."); \
-    break;                                                                 \
-  }
-#endif
-#else
 #define _DISPATCH_CASE_FP4_E2M1(c_type, ...)
 #endif

 // Should not be used together with _DISPATCH_CASE_FP4_E2M1
-#ifdef FLASHINFER_ENABLE_FP8_E8M0
-#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
+#if defined(FLASHINFER_ENABLE_FP8_E8M0) && \
+    (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
 #define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...) \
   case at::ScalarType::Byte: {                  \
     using c_type = __nv_fp8_e8m0;               \
     return __VA_ARGS__();                       \
   }
 #else
-#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)                            \
-  case at::ScalarType::Byte: {                                             \
-    static_assert(false, "FP8 E8M0 support requires CUDA 12.8 or newer."); \
-    break;                                                                 \
-  }
-#endif
-#else
 #define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)
 #endif
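After the cleanup, each case macro either expands to a single `case at::ScalarType::Byte:` that binds `c_type` and invokes the caller's lambda, or expands to nothing when FP4/FP8-E8M0 support is unavailable. A hedged sketch of how such a case macro is typically consumed inside a switch-based dispatcher (`dispatch_fp4` and its surrounding setup are hypothetical, not FlashInfer's actual dispatch macros):

```cpp
// Hypothetical usage sketch. Assumes "pytorch_extension_utils.h" is on the include
// path and the file is compiled by nvcc >= 12.8 with -DFLASHINFER_ENABLE_FP4_E2M1;
// otherwise _DISPATCH_CASE_FP4_E2M1 expands to nothing and only `default:` remains.
#include <ATen/ATen.h>

#include "pytorch_extension_utils.h"

template <typename Func>
bool dispatch_fp4(at::ScalarType dtype, Func&& func) {
  switch (dtype) {
    // Expands to: case at::ScalarType::Byte: { using c_type = __nv_fp4_e2m1; return <lambda>(); }
    _DISPATCH_CASE_FP4_E2M1(c_type, [&] { return func(c_type{}); })
    default:
      return false;  // dtype not supported in this build
  }
}
```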
