
Commit 90abf04

feat: cutlass fp8 gemm bringup for SM120 & SM121 (#1610)
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR depends on #1608 (mainly the cutlass fp8 gemm support for sm120/121) and will be rebased after #1608 lands.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent bd487ee commit 90abf04

11 files changed (+1156 −33 lines)

csrc/gemm_groupwise_sm120.cu

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/cutlass_utils.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

#define DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(input_dtype, output_dtype, c_type_in, c_type_out, ...) \
  [&]() -> bool {                                                                 \
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, c_type_out, [&] {  \
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(input_dtype, c_type_in,         \
                                                 [&] { return __VA_ARGS__(); }); \
    });                                                                           \
  }()

#define DISPATCH_SCALE_GRANULARITY(scale_granularity_m, scale_granularity_n, scale_granularity_k, \
                                   SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \
                                   ...)                                                           \
  [&]() -> bool {                                                                                 \
    /* SM120 Cooperative schedule uses 128x128x128 tile shape */                                  \
    /* TODO (yongwww): PingPong schedule (64x128x128) will need additional dispatch logic */      \
    constexpr int SCALE_GRANULARITY_K = 128;                                                      \
    if (scale_granularity_k != 128) {                                                             \
      TORCH_CHECK(                                                                                \
          false,                                                                                  \
          "SM120 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK must equal " \
          "tile shape K dimension (128 for both Cooperative and PingPong schedules).");           \
      return false;                                                                               \
    }                                                                                             \
    /* Support (1,128,128) and (128,128,128) as per SM100's approach */                           \
    if (scale_granularity_m == 1 && scale_granularity_n == 128) {                                 \
      constexpr int SCALE_GRANULARITY_M = 1;                                                      \
      constexpr int SCALE_GRANULARITY_N = 128;                                                    \
      return __VA_ARGS__();                                                                       \
    } else if (scale_granularity_m == 128 && scale_granularity_n == 128) {                        \
      constexpr int SCALE_GRANULARITY_M = 128;                                                    \
      constexpr int SCALE_GRANULARITY_N = 128;                                                    \
      return __VA_ARGS__();                                                                       \
    }                                                                                             \
    TORCH_CHECK(false, "SM120: Unsupported scale granularity combination (", scale_granularity_m, \
                ",", scale_granularity_n, ",", scale_granularity_k, ")");                         \
    return false;                                                                                 \
  }()

#define DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, ...) \
  [&]() -> bool {                                                    \
    if (scale_major_mode == "K") {                                   \
      constexpr bool SCALE_MAJOR_K = true;                           \
      return __VA_ARGS__();                                          \
    } else if (scale_major_mode == "MN") {                           \
      constexpr bool SCALE_MAJOR_K = false;                          \
      return __VA_ARGS__();                                          \
    }                                                                \
    TORCH_CHECK(false, "Unsupported Scale Major Mode");              \
    return false;                                                    \
  }()

namespace flashinfer {
namespace gemm {

template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK,
          typename DTypeIn, typename DTypeOut>
cudaError_t CutlassGroupwiseScaledGEMMSM120(void* float_buffer, size_t float_buffer_size_in_bytes,
                                            DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr,
                                            float* SFB_ptr, DTypeOut* C_ptr, int m, int n, int k,
                                            int l, cudaStream_t stream);

}  // namespace gemm
}  // namespace flashinfer

void CutlassGemmGroupwiseScaledSM120(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B,
                                     at::Tensor SFA, at::Tensor SFB, at::Tensor C,
                                     int64_t scale_granularity_m, int64_t scale_granularity_n,
                                     int64_t scale_granularity_k, std::string scale_major_mode) {
  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  // Ensure scales are contiguous
  // Note: We keep the original shape and let the kernel's layout handle interpretation
  at::Tensor SFA_contig = SFA.is_contiguous() ? SFA : SFA.contiguous();
  at::Tensor SFB_contig = SFB.is_contiguous() ? SFB : SFB.contiguous();

  DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] {
    return DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(
        A.scalar_type(), C.scalar_type(), c_type_in, c_type_out, [&] {
          return DISPATCH_SCALE_GRANULARITY(
              scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M,
              SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] {
                using cutlass_t_in = cutlass_dtype_t<c_type_in>;
                using cutlass_t_out = cutlass_dtype_t<c_type_out>;

                // Handle both 2D and 3D tensors (BMM)
                int m, n, k, l;
                if (A.dim() == 2) {
                  // 2D case: simple matrix multiplication
                  m = A.size(0);
                  k = A.size(1);
                  n = B.size(0);
                  l = 1;  // no batch dimension
                } else if (A.dim() == 3) {
                  // 3D case: batch matrix multiplication
                  l = A.size(0);  // batch size
                  m = A.size(1);  // per-batch m dimension
                  k = A.size(2);  // per-batch k dimension
                  n = B.size(2);  // per-batch n dimension (B is [batch, k, n] column-major)
                } else {
                  return false;  // Unsupported tensor dimension
                }

                auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM120<
                    SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K>(
                    static_cast<void*>(float_workspace_buffer.data_ptr()),
                    float_workspace_buffer.element_size() * float_workspace_buffer.numel(),
                    static_cast<cutlass_t_in*>(A.data_ptr()),
                    static_cast<cutlass_t_in*>(B.data_ptr()),
                    static_cast<float*>(SFA_contig.data_ptr()),
                    static_cast<float*>(SFB_contig.data_ptr()),
                    static_cast<cutlass_t_out*>(C.data_ptr()), m, n, k, l,
                    stream);  // C is the output (D)
                return status == cudaSuccess;
              });
        });
  });
}
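Note on the dispatch pattern above: each `DISPATCH_*` macro is an immediately-invoked lambda that matches a runtime value, rebinds it as a `constexpr` of the same spelling, and textually pastes the continuation into that branch, so runtime (dtype, granularity, major-mode) choices become compile-time template arguments. A minimal self-contained sketch of the idiom (the `run` kernel stub and `DISPATCH_GRANULARITY_M` name are illustrative, not part of this PR):

```cpp
#include <iostream>

// Toy stand-in for a templated kernel entry point.
template <int ScaleGranularityM>
bool run(int m) {
  std::cout << "compiled for ScaleGranularityM=" << ScaleGranularityM
            << ", runtime m=" << m << "\n";
  return true;
}

// Same shape as DISPATCH_SCALE_GRANULARITY: map a runtime value onto a
// constexpr, then invoke the continuation, which is pasted into the branch
// and therefore sees GM as a compile-time constant.
#define DISPATCH_GRANULARITY_M(runtime_gm, GM, ...) \
  [&]() -> bool {                                   \
    if (runtime_gm == 1) {                          \
      constexpr int GM = 1;                         \
      return __VA_ARGS__();                         \
    } else if (runtime_gm == 128) {                 \
      constexpr int GM = 128;                       \
      return __VA_ARGS__();                         \
    }                                               \
    return false; /* unsupported granularity */     \
  }()

int main() {
  int runtime_gm = 128;  // e.g. an argument arriving from the Python side
  bool ok = DISPATCH_GRANULARITY_M(runtime_gm, GM, [&] { return run<GM>(/*m=*/4096); });
  return ok ? 0 : 1;
}
```

Returning `bool` through every level is what lets the nested dispatches above short-circuit cleanly when a combination is unsupported.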
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/gemm/gemm_groupwise_sm120.cuh>

using namespace flashinfer;
using namespace flashinfer::gemm;

namespace flashinfer {
namespace gemm {

// Following SM100's approach: support (1,128,128) and (128,128,128)
{% for scale_granularity_m in [1, 128] %}
{% for scale_granularity_n in [128] %}
{% for scale_granularity_k in [128] %}
template cudaError_t CutlassGroupwiseScaledGEMMSM120<
    {{ scale_granularity_m }}, {{ scale_granularity_n }}, {{ scale_granularity_k }},
    {{ scale_major_k }}, {{ dtype_in }}, {{ dtype_out }}>(
    void* float_buffer, size_t float_buffer_size_in_bytes,
    {{ dtype_in }}* A_ptr, {{ dtype_in }}* B_ptr, float* SFA_ptr, float* SFB_ptr,
    {{ dtype_out }}* D_ptr, int m, int n, int k, int l, cudaStream_t stream);
{% endfor %}
{% endfor %}
{% endfor %}

};  // namespace gemm
};  // namespace flashinfer
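For reference, one iteration of the Jinja loops above renders to an explicit instantiation like the following, assuming the generator supplies `scale_major_k = true`, `dtype_in = cutlass::float_e4m3_t`, and `dtype_out = cutlass::half_t` (those three values come from the code-generation side, not from this template):

```cpp
template cudaError_t CutlassGroupwiseScaledGEMMSM120<
    1, 128, 128, true, cutlass::float_e4m3_t, cutlass::half_t>(
    void* float_buffer, size_t float_buffer_size_in_bytes,
    cutlass::float_e4m3_t* A_ptr, cutlass::float_e4m3_t* B_ptr, float* SFA_ptr, float* SFB_ptr,
    cutlass::half_t* D_ptr, int m, int n, int k, int l, cudaStream_t stream);
```

Since the loops only cover `scale_granularity_m` in {1, 128} with N and K granularity fixed at 128, each rendered translation unit contains exactly two such instantiations per (dtype, major-mode) combination.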

csrc/gemm_sm120_pybind.cu

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pytorch_extension_utils.h"

void CutlassGemmGroupwiseScaledSM120(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B,
                                     at::Tensor SFA, at::Tensor SFB, at::Tensor C,
                                     int64_t scale_granularity_m, int64_t scale_granularity_n,
                                     int64_t scale_granularity_k, std::string scale_major_mode);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("gemm_fp8_nt_groupwise", CutlassGemmGroupwiseScaledSM120);
}
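`TORCH_LIBRARY_FRAGMENT` lets several translation units contribute operators to the same extension library, and passing a plain function pointer to `m.def` makes PyTorch infer the operator schema from the C++ signature. A minimal sketch of the same registration idiom (the `toy_ext` namespace and `toy_scale` op are hypothetical, for illustration only):

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Toy op: the schema "toy_scale(Tensor x, float s) -> Tensor" is inferred
// from this signature, just as it is for CutlassGemmGroupwiseScaledSM120 above.
at::Tensor toy_scale(const at::Tensor& x, double s) { return x * s; }

TORCH_LIBRARY_FRAGMENT(toy_ext, m) {
  m.def("toy_scale", toy_scale);  // callable from Python as torch.ops.toy_ext.toy_scale
}
```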
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/cutlass_utils.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

#define DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(input_dtype, output_dtype, c_type_in, c_type_out, ...) \
  [&]() -> bool {                                                                 \
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, c_type_out, [&] {  \
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(input_dtype, c_type_in,         \
                                                 [&] { return __VA_ARGS__(); }); \
    });                                                                           \
  }()

#define DISPATCH_SCALE_GRANULARITY(scale_granularity_m, scale_granularity_n, scale_granularity_k, \
                                   SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \
                                   ...)                                                           \
  [&]() -> bool {                                                                                 \
    constexpr int SCALE_GRANULARITY_K = 128;                                                      \
    if (scale_granularity_k != 128) {                                                             \
      TORCH_CHECK(                                                                                \
          false,                                                                                  \
          "SM120 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK must equal " \
          "tile shape K dimension (128 for both Cooperative and PingPong schedules).");           \
      return false;                                                                               \
    }                                                                                             \
    /* Match SM100's approach: support only (1,128,128) and (128,128,128) */                      \
    if (scale_granularity_m == 1 && scale_granularity_n == 128) {                                 \
      constexpr int SCALE_GRANULARITY_M = 1;                                                      \
      constexpr int SCALE_GRANULARITY_N = 128;                                                    \
      return __VA_ARGS__();                                                                       \
    } else if (scale_granularity_m == 128 && scale_granularity_n == 128) {                        \
      constexpr int SCALE_GRANULARITY_M = 128;                                                    \
      constexpr int SCALE_GRANULARITY_N = 128;                                                    \
      return __VA_ARGS__();                                                                       \
    }                                                                                             \
    TORCH_CHECK(false, "SM120: Unsupported scale granularity combination (", scale_granularity_m, \
                ",", scale_granularity_n, ",", scale_granularity_k, ")");                         \
    return false;                                                                                 \
  }()

#define DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, ...) \
  [&]() -> bool {                                                    \
    if (scale_major_mode == "K") {                                   \
      constexpr bool SCALE_MAJOR_K = true;                           \
      return __VA_ARGS__();                                          \
    } else if (scale_major_mode == "MN") {                           \
      constexpr bool SCALE_MAJOR_K = false;                          \
      return __VA_ARGS__();                                          \
    }                                                                \
    TORCH_CHECK(false, "Unsupported Scale Major Mode");              \
    return false;                                                    \
  }()

namespace flashinfer {
namespace group_gemm {

template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK,
          typename DTypeIn, typename DTypeOut>
cudaError_t CutlassFP8GroupwiseScaledGroupGEMMSM120(
    void* int_buffer, size_t int_buffer_size_in_bytes, void* float_buffer,
    size_t float_buffer_size_in_bytes, DTypeIn* A, DTypeIn* B, float* SFA, float* SFB, DTypeOut* D,
    int* m_indptr, int max_m, int n, int k, int num_groups, cudaStream_t stream);

}  // namespace group_gemm
}  // namespace flashinfer

void CutlassGroupGemmFP8GroupwiseScaledSM120(
    at::Tensor int_workspace_buffer, at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B,
    at::Tensor SFA, at::Tensor SFB, at::Tensor D, at::Tensor m_indptr, int64_t n, int64_t k,
    int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k,
    std::string scale_major_mode) {
  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  int num_groups = m_indptr.size(0) - 1;

  // Ensure scales are contiguous
  // Note: We keep the original shape and let the kernel's layout handle interpretation
  at::Tensor SFA_contig = SFA.is_contiguous() ? SFA : SFA.contiguous();
  at::Tensor SFB_contig = SFB.is_contiguous() ? SFB : SFB.contiguous();

  // Get max_m from SFA shape
  int max_m = SFA.size(SFA.dim() > 1 ? 1 : 0);

  DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(A.scalar_type(), D.scalar_type(), c_type_in, c_type_out, [&] {
    return DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] {
      return DISPATCH_SCALE_GRANULARITY(
          scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M,
          SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] {
            using cutlass_t_in = cutlass_dtype_t<c_type_in>;
            using cutlass_t_out = cutlass_dtype_t<c_type_out>;
            auto status = flashinfer::group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM120<
                SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K,
                cutlass_t_in, cutlass_t_out>(
                static_cast<int*>(int_workspace_buffer.data_ptr()),
                int_workspace_buffer.element_size() * int_workspace_buffer.size(0),
                static_cast<float*>(float_workspace_buffer.data_ptr()),
                float_workspace_buffer.element_size() * float_workspace_buffer.size(0),
                static_cast<cutlass_t_in*>(A.data_ptr()), static_cast<cutlass_t_in*>(B.data_ptr()),
                static_cast<float*>(SFA_contig.data_ptr()),
                static_cast<float*>(SFB_contig.data_ptr()),
                static_cast<cutlass_t_out*>(D.data_ptr()), static_cast<int*>(m_indptr.data_ptr()),
                max_m, n, k, num_groups, stream);
            return status == cudaSuccess;
          });
    });
  });
}
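On `m_indptr` in the grouped path: the host code derives `num_groups = m_indptr.size(0) - 1`, which suggests the usual CSR-style convention where entries `i` and `i+1` bound group `i`'s rows along the packed M dimension. A small self-contained sketch under that (hedged) reading, with made-up sizes:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical indptr for 3 groups stacked along M: rows [0,3), [3,10), [10,16).
  std::vector<int> m_indptr = {0, 3, 10, 16};
  int num_groups = static_cast<int>(m_indptr.size()) - 1;  // matches the host code above
  for (int g = 0; g < num_groups; ++g) {
    int m_g = m_indptr[g + 1] - m_indptr[g];  // per-group M extent
    std::printf("group %d: rows [%d, %d) -> m = %d\n", g, m_indptr[g], m_indptr[g + 1], m_g);
  }
  return 0;
}
```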
