Skip to content

Commit a6e9c9f

Browse files
neurusLAnrui Liu
andauthored
TVM: support TVM binding for GroupedGemm (#1725)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> The PR add support for GroupedGemm tvm_binding from FlashInfer side. - ```flashinfer/tvm_binding/grouped_gemm_fp8.cu``` contains implementation of dispatching templates to ```group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM100```, supporting JIT compilation - ```flashinfer/tvm_binding/grouped_gemm_fp8_jit_tvm_binding.cu``` contains declaration of above function - ```flashinfer/flashinfer/jit/gemm``` contains interface exposed to tvm ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Anrui Liu <[email protected]>
1 parent b6cfc2c commit a6e9c9f

File tree

7 files changed

+342
-3
lines changed

7 files changed

+342
-3
lines changed

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
[submodule "3rdparty/cutlass"]
22
path = 3rdparty/cutlass
33
url = https://github.com/NVIDIA/cutlass.git
4-
[submodule "3rdparty/composable_kernels"]
5-
path = 3rdparty/composable_kernels
6-
url = https://github.com/ROCm/composable_kernel.git
74
[submodule "3rdparty/spdlog"]
85
path = 3rdparty/spdlog
96
url = https://github.com/gabime/spdlog.git

flashinfer/jit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@
7878
from .core import current_compilation_context as current_compilation_context
7979
from .cubin_loader import setup_cubin_loader
8080

81+
from .gemm import gen_grouped_gemm_fp8_tvm_binding as gen_grouped_gemm_fp8_tvm_binding
82+
from .gemm import get_grouped_gemm_fp8_uri as get_grouped_gemm_fp8_uri
83+
8184

8285
@functools.cache
8386
def get_cudnn_fmha_gen_module():

flashinfer/jit/gemm/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .tvm import gen_grouped_gemm_fp8_tvm_binding as gen_grouped_gemm_fp8_tvm_binding
18+
from .tvm import get_grouped_gemm_fp8_uri as get_grouped_gemm_fp8_uri

flashinfer/jit/gemm/tvm.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
from typing import Tuple
19+
20+
import torch
21+
22+
from .. import env as jit_env
23+
from ..utils import write_if_different
24+
25+
26+
def gen_grouped_gemm_fp8_tvm_binding(
27+
uri: str,
28+
dtype_a: torch.dtype,
29+
dtype_b: torch.dtype,
30+
dtype_out: torch.dtype,
31+
scale_granularity_m: int,
32+
scale_granularity_n: int,
33+
scale_granularity_k: int,
34+
scale_major_mode: str, # "K" or "MN"
35+
mma_sm: int,
36+
) -> Tuple[str, list]:
37+
"""Generate TVM binding for FP8 grouped GEMM.
38+
39+
Parameters
40+
----------
41+
uri : str
42+
Unique identifier for this kernel configuration
43+
dtype_a : torch.dtype
44+
Data type of matrix A
45+
dtype_b : torch.dtype
46+
Data type of matrix B
47+
dtype_out : torch.dtype
48+
Data type of output matrix
49+
scale_granularity_m : int
50+
Scaling granularity in M dimension
51+
scale_granularity_n : int
52+
Scaling granularity in N dimension
53+
scale_granularity_k : int
54+
Scaling granularity in K dimension
55+
scale_major_mode : str
56+
Scale storage mode ("K" or "MN")
57+
mma_sm : int
58+
MMA scheduling mode (1 or 2)
59+
60+
Returns
61+
-------
62+
Tuple[str, list]
63+
URI and list of generated source file paths
64+
"""
65+
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
66+
os.makedirs(gen_directory, exist_ok=True)
67+
68+
source_paths = []
69+
70+
# Copy the base implementation file unchanged
71+
src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / "grouped_gemm_fp8.cu"
72+
dest_path = gen_directory / "grouped_gemm_fp8.cu"
73+
source_paths.append(dest_path)
74+
with open(src_path, "r") as f:
75+
source = f.read()
76+
write_if_different(dest_path, source)
77+
78+
# Read the base TVM binding file and create specialized version
79+
tvm_binding_src = (
80+
jit_env.FLASHINFER_TVM_BINDING_DIR / "grouped_gemm_fp8_jit_tvm_binding.cu"
81+
)
82+
with open(tvm_binding_src, "r") as f:
83+
base_content = f.read()
84+
85+
# Convert scale_major_mode to integer
86+
scale_major_mode_val = 0 if scale_major_mode == "K" else 1
87+
88+
# Create specialized version by modifying the function export
89+
# Replace the direct export with a specialized wrapper
90+
specialized_content = base_content.replace(
91+
"TVM_FFI_DLL_EXPORT_TYPED_FUNC(grouped_gemm_fp8_run, GroupedGemmFp8Run);",
92+
f"""// Specialized wrapper for this configuration
93+
int GroupedGemmFp8RunSpecialized(
94+
DLTensor* int_workspace_buffer,
95+
DLTensor* float_workspace_buffer,
96+
DLTensor* A,
97+
DLTensor* B,
98+
DLTensor* SFA,
99+
DLTensor* SFB,
100+
DLTensor* D,
101+
DLTensor* m_indptr,
102+
int64_t n, int64_t k,
103+
TVMStreamHandle cuda_stream
104+
) {{
105+
return GroupedGemmFp8Run(
106+
int_workspace_buffer,
107+
float_workspace_buffer,
108+
A, B, SFA, SFB, D, m_indptr,
109+
n, k,
110+
{scale_granularity_m}, // scale_granularity_m
111+
{scale_granularity_n}, // scale_granularity_n
112+
{scale_granularity_k}, // scale_granularity_k
113+
{scale_major_mode_val}, // scale_major_mode
114+
{mma_sm}, // mma_sm
115+
cuda_stream
116+
);
117+
}}
118+
119+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(grouped_gemm_fp8_run, GroupedGemmFp8RunSpecialized);""",
120+
)
121+
122+
binding_dest_path = gen_directory / "grouped_gemm_fp8_jit_tvm_binding.cu"
123+
source_paths.append(binding_dest_path)
124+
write_if_different(binding_dest_path, specialized_content)
125+
126+
return uri, source_paths
127+
128+
129+
def get_grouped_gemm_fp8_uri(
130+
dtype_a: torch.dtype,
131+
dtype_b: torch.dtype,
132+
dtype_out: torch.dtype,
133+
scale_granularity_m: int,
134+
scale_granularity_n: int,
135+
scale_granularity_k: int,
136+
scale_major_mode: str,
137+
mma_sm: int,
138+
) -> str:
139+
"""Generate URI for FP8 grouped GEMM configuration."""
140+
dtype_a_str = str(dtype_a).split(".")[-1]
141+
dtype_b_str = str(dtype_b).split(".")[-1]
142+
dtype_out_str = str(dtype_out).split(".")[-1]
143+
144+
return (
145+
f"group_gemm_fp8_{dtype_a_str}_{dtype_b_str}_{dtype_out_str}_"
146+
f"sg_{scale_granularity_m}_{scale_granularity_n}_{scale_granularity_k}_"
147+
f"sm_{scale_major_mode}_mma_{mma_sm}"
148+
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ packages = [
5353
"flashinfer.jit.attention",
5454
"flashinfer.jit.cutlass_gemm",
5555
"flashinfer.testing",
56+
"flashinfer.jit.gemm",
5657
"flashinfer.triton",
5758
"flashinfer.tuning_configs",
5859
"flashinfer.profiler",

tvm_binding/grouped_gemm_fp8.cu

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include <dlpack/dlpack.h>
2+
3+
#include <flashinfer/cutlass_utils.cuh>
4+
#include <flashinfer/gemm/group_gemm_fp8_groupwise_sm100.cuh>
5+
6+
#include "tvm_binding_utils.h"
7+
8+
__global__ void simple_print_kernel(void* data, int dtype_code) {
9+
if (threadIdx.x == 0 && blockIdx.x == 0) {
10+
if (dtype_code == kDLBfloat) {
11+
// bfloat16
12+
uint16_t* bf16_data = static_cast<uint16_t*>(data);
13+
uint32_t full = ((uint32_t)bf16_data[0]) << 16;
14+
float val = *reinterpret_cast<float*>(&full);
15+
printf("GPU: D[0] = %.6f\n", val);
16+
} else {
17+
// float32
18+
float* f32_data = static_cast<float*>(data);
19+
printf("GPU: D[0] = %.6f\n", f32_data[0]);
20+
}
21+
}
22+
}
23+
24+
// following MACROS duplicates from flashinfer/csrc/group_gemm_fp8_groupwise_sm100.cu
25+
#define DISPATCH_TVM_DTYPE_TO_CTYPE(tvm_dtype_in, tvm_dtype_out, c_type_in, c_type_out, ...) \
26+
[&]() -> bool { \
27+
if (tvm_dtype_in.code == kDLFloat8_e4m3fn && tvm_dtype_in.bits == 8) { \
28+
using c_type_in = cutlass::float_e4m3_t; \
29+
if (tvm_dtype_out.code == kDLFloat && tvm_dtype_out.bits == 16) { \
30+
using c_type_out = cutlass::half_t; \
31+
return __VA_ARGS__(); \
32+
} \
33+
if (tvm_dtype_out.code == kDLBfloat && tvm_dtype_out.bits == 16) { \
34+
using c_type_out = cutlass::bfloat16_t; \
35+
return __VA_ARGS__(); \
36+
} \
37+
} \
38+
CHECK(false) << "Unsupported TVM dtype combination: input(" << tvm_dtype_in.code << "," \
39+
<< tvm_dtype_in.bits << ") output(" << tvm_dtype_out.code << "," \
40+
<< tvm_dtype_out.bits << ")"; \
41+
return false; \
42+
}()
43+
44+
#define DISPATCH_MMA_SM(mma_sm, MMA_SM, ...) \
45+
[&]() -> bool { \
46+
if (mma_sm == 1) { \
47+
constexpr int MMA_SM = 1; \
48+
return __VA_ARGS__(); \
49+
} else if (mma_sm == 2) { \
50+
constexpr int MMA_SM = 2; \
51+
return __VA_ARGS__(); \
52+
} \
53+
CHECK(false) << "Unsupported MMA SM: " << mma_sm; \
54+
return false; \
55+
}()
56+
57+
#define DISPATCH_SCALE_GRANULARITY(scale_granularity_m, scale_granularity_n, scale_granularity_k, \
58+
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \
59+
...) \
60+
[&]() -> bool { \
61+
if (scale_granularity_m == 1 && scale_granularity_n == 128 && scale_granularity_k == 128) { \
62+
constexpr int SCALE_GRANULARITY_M = 1; \
63+
constexpr int SCALE_GRANULARITY_N = 128; \
64+
constexpr int SCALE_GRANULARITY_K = 128; \
65+
return __VA_ARGS__(); \
66+
} else if (scale_granularity_m == 128 && scale_granularity_n == 128 && \
67+
scale_granularity_k == 128) { \
68+
constexpr int SCALE_GRANULARITY_M = 128; \
69+
constexpr int SCALE_GRANULARITY_N = 128; \
70+
constexpr int SCALE_GRANULARITY_K = 128; \
71+
return __VA_ARGS__(); \
72+
} \
73+
CHECK(false) << "Unsupported scale granularity: (" << scale_granularity_m << "," \
74+
<< scale_granularity_n << "," << scale_granularity_k << ")"; \
75+
return false; \
76+
}()
77+
78+
#define DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, ...) \
79+
[&]() -> bool { \
80+
if (scale_major_mode == 0) { \
81+
constexpr bool SCALE_MAJOR_K = true; \
82+
return __VA_ARGS__(); \
83+
} else if (scale_major_mode == 1) { \
84+
constexpr bool SCALE_MAJOR_K = false; \
85+
return __VA_ARGS__(); \
86+
} \
87+
CHECK(false) << "Unsupported Scale Major Mode: " << scale_major_mode; \
88+
return false; \
89+
}()
90+
91+
namespace flashinfer {
92+
namespace group_gemm {
93+
94+
template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK,
95+
int MmaSM, typename DTypeIn, typename DTypeOut>
96+
cudaError_t CutlassFP8GroupwiseScaledGroupGEMMSM100(
97+
void* int_buffer, size_t int_buffer_size_in_bytes, void* float_buffer,
98+
size_t float_buffer_size_in_bytes, DTypeIn* A, DTypeIn* B, float* SFA, float* SFB, DTypeOut* D,
99+
int* m_indptr, int max_m, int n, int k, int num_groups, cudaStream_t stream);
100+
101+
}
102+
} // namespace flashinfer
103+
104+
// FP8 Group GEMM implementation with CUTLASS for SM100A (Blackwell)
105+
void GroupedGemmFp8Run(DLTensor* int_workspace_buffer, DLTensor* float_workspace_buffer,
106+
DLTensor* A, DLTensor* B, DLTensor* SFA, DLTensor* SFB, DLTensor* D,
107+
DLTensor* m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m,
108+
int64_t scale_granularity_n, int64_t scale_granularity_k,
109+
int64_t scale_major_mode, int64_t mma_sm, TVMStreamHandle cuda_stream) {
110+
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
111+
112+
size_t float_workspace_size =
113+
float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes();
114+
size_t int_workspace_size =
115+
int_workspace_buffer->shape[0] * DataType(int_workspace_buffer->dtype).bytes();
116+
117+
int64_t num_groups = m_indptr->shape[0] - 1;
118+
int64_t max_m = SFA->shape[1];
119+
120+
try {
121+
DISPATCH_TVM_DTYPE_TO_CTYPE(A->dtype, D->dtype, c_type_in, c_type_out, [&] {
122+
return DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] {
123+
return DISPATCH_MMA_SM(mma_sm, MMA_SM, [&] {
124+
return DISPATCH_SCALE_GRANULARITY(
125+
scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M,
126+
SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] {
127+
using cutlass_t_in = flashinfer::cutlass_dtype_t<c_type_in>;
128+
using cutlass_t_out = flashinfer::cutlass_dtype_t<c_type_out>;
129+
130+
auto status = flashinfer::group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM100<
131+
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K,
132+
MMA_SM>(
133+
static_cast<int32_t*>(int_workspace_buffer->data) +
134+
int_workspace_buffer->byte_offset / sizeof(int32_t),
135+
int_workspace_buffer->shape[0] * sizeof(int32_t),
136+
static_cast<float*>(float_workspace_buffer->data) +
137+
float_workspace_buffer->byte_offset / sizeof(float),
138+
float_workspace_buffer->shape[0] * sizeof(float),
139+
static_cast<cutlass_t_in*>(A->data) + A->byte_offset / sizeof(cutlass_t_in),
140+
static_cast<cutlass_t_in*>(B->data) + B->byte_offset / sizeof(cutlass_t_in),
141+
static_cast<float*>(SFA->data) + SFA->byte_offset / sizeof(float),
142+
static_cast<float*>(SFB->data) + SFB->byte_offset / sizeof(float),
143+
static_cast<cutlass_t_out*>(D->data) + D->byte_offset / sizeof(cutlass_t_out),
144+
static_cast<int32_t*>(m_indptr->data) + m_indptr->byte_offset / sizeof(int32_t),
145+
max_m, n, k, num_groups, stream);
146+
147+
// Check for CUDA errors immediately after kernel call
148+
cudaError_t cuda_error = cudaGetLastError();
149+
if (cuda_error != cudaSuccess) {
150+
return false;
151+
}
152+
LOG(INFO) << "Kernel execution completed successfully";
153+
return status == cudaSuccess;
154+
});
155+
});
156+
});
157+
});
158+
} catch (const std::exception& e) {
159+
LOG(INFO) << "Exception caught:" << e.what();
160+
}
161+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include "tvm_binding_utils.h"
2+
3+
// Function declarations (implementations in grouped_gemm.cu)
4+
IntTuple GroupedGemmGetWorkspaceSize(int64_t batch_size, int64_t max_m, int64_t max_n,
5+
int64_t max_k);
6+
7+
void GroupedGemmFp8Run(DLTensor* int_workspace_buffer, DLTensor* float_workspace_buffer,
8+
DLTensor* A, DLTensor* B, DLTensor* SFA, DLTensor* SFB, DLTensor* D,
9+
DLTensor* m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m,
10+
int64_t scale_granularity_n, int64_t scale_granularity_k,
11+
int64_t scale_major_mode, int64_t mma_sm, TVMStreamHandle cuda_stream);

0 commit comments

Comments
 (0)