Commit 9d328b2

feat: cutlass fp4 gemm bringup for SM120 & SM121 (#1609)
## 📌 Description

This PR depends on #1608, which provides the CUTLASS FP4 GEMM support for SM120/SM121; it will be rebased once that PR is merged.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 0e3403b commit 9d328b2

File tree

9 files changed: +809 −19 lines changed


csrc/fp4_gemm_cutlass_sm120.cu

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/cuda/EmptyTensor.h>
#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

#include "flashinfer/gemm/cutlass_gemm_configs.h"
// Use SM120-specific dispatch template (includes fp4_gemm_cutlass.h)
#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h"
#include "pytorch_extension_utils.h"

using flashinfer::gemm::ClusterShape;
using flashinfer::gemm::CutlassFp4GemmRunner;
using flashinfer::gemm::CutlassGemmConfig;
using flashinfer::gemm::CutlassTileConfigSM120;
using flashinfer::gemm::EpilogueScheduleType;
using flashinfer::gemm::FP4GemmType;
using flashinfer::gemm::MainloopScheduleType;

namespace torch_ext {

namespace {

CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
  auto getCutlassFp4GemmConfigs = []() {
    CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassFp4GemmConfigs();
  TORCH_CHECK(tactic >= 0 && tactic < globalConfigs.size(), "tactic must be between 0 and ",
              globalConfigs.size());
  return globalConfigs[tactic];
}

template <typename T>
void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
             at::Tensor const& mat1Scale, at::Tensor const& mat2Scale,
             at::Tensor const& globalScale, int64_t m, int64_t n, int64_t k, int64_t batch_count,
             CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
  CutlassFp4GemmRunner<T, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;

  int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count);
  int64_t const provided_workspace_size =
      workspace_buffer.numel() * workspace_buffer.element_size();

  auto runKernel = [&](void* workspace) {
    gemmRunner.gemm(out.data_ptr(), mat1.const_data_ptr(), mat2.const_data_ptr(),
                    mat1Scale.const_data_ptr(), mat2Scale.const_data_ptr(),
                    globalScale.data_ptr<float>(), m, n, k, batch_count, gemmConfig,
                    reinterpret_cast<char*>(workspace), required_workspace_size,
                    at::cuda::getCurrentCUDAStream(mat1.get_device()));
  };

  if (provided_workspace_size < required_workspace_size) {
    at::Tensor new_workspace = at::detail::empty_cuda(
        {required_workspace_size}, at::ScalarType::Char, mat1.device(), std::nullopt);

    runKernel(new_workspace.data_ptr());
  } else {
    runKernel(workspace_buffer.data_ptr());
  }
}

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;  // uint8_t
constexpr auto SF_DTYPE = at::ScalarType::Byte;       // uint8_t

at::Tensor fp4_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
                        at::Tensor const& mat2Scale, at::Tensor const& globalScale, at::Tensor out,
                        at::Tensor workspace_buffer, int64_t tactic) {
  // Validate inputs
  TORCH_CHECK(mat1.dtype() == FLOAT4_E2M1X2, "mat1 must be FLOAT4_E2M1X2 (uint8)");
  TORCH_CHECK(mat2.dtype() == FLOAT4_E2M1X2, "mat2 must be FLOAT4_E2M1X2 (uint8)");
  TORCH_CHECK(mat1Scale.dtype() == SF_DTYPE, "mat1Scale must be SF_DTYPE (uint8)");
  TORCH_CHECK(mat2Scale.dtype() == SF_DTYPE, "mat2Scale must be SF_DTYPE (uint8)");
  TORCH_CHECK(globalScale.dtype() == at::ScalarType::Float, "globalScale must be float");
  TORCH_CHECK(mat1.is_cuda(), "mat1 must be on CUDA device");
  TORCH_CHECK(mat2.is_cuda(), "mat2 must be on CUDA device");
  TORCH_CHECK(mat1Scale.is_cuda(), "mat1Scale must be on CUDA device");
  TORCH_CHECK(mat2Scale.is_cuda(), "mat2Scale must be on CUDA device");
  TORCH_CHECK(globalScale.is_cuda(), "globalScale must be on CUDA device");
  TORCH_CHECK(out.is_cuda(), "out must be on CUDA device");
  TORCH_CHECK(workspace_buffer.is_cuda(), "workspace_buffer must be on CUDA device");

  // Check device consistency
  TORCH_CHECK(mat1.device() == mat2.device() && mat1.device() == mat1Scale.device() &&
                  mat1.device() == mat2Scale.device() && mat1.device() == globalScale.device() &&
                  mat1.device() == out.device() && mat1.device() == workspace_buffer.device(),
              "All tensors must be on the same device");

  // Get dimensions
  int64_t b = 1;
  int64_t m, k_packed, n;

  if (mat1.dim() == 2) {
    m = mat1.size(0);
    k_packed = mat1.size(1);
  } else if (mat1.dim() == 3) {
    b = mat1.size(0);
    m = mat1.size(1);
    k_packed = mat1.size(2);
  } else {
    TORCH_CHECK(false, "mat1 must be 2D or 3D tensor");
  }

  if (mat2.dim() == 2) {
    n = mat2.size(0);
    TORCH_CHECK(mat2.size(1) == k_packed, "mat2.size(1) must match mat1.size(-1)");
  } else if (mat2.dim() == 3) {
    TORCH_CHECK(mat2.size(0) == b, "Batch dimensions must match");
    n = mat2.size(1);
    TORCH_CHECK(mat2.size(2) == k_packed, "mat2.size(2) must match mat1.size(-1)");
  } else {
    TORCH_CHECK(false, "mat2 must be 2D or 3D tensor");
  }

  // k_packed stores 2 FP4 values per byte
  int64_t k = k_packed * 2;

  TORCH_CHECK(globalScale.numel() == 1, "globalScale must be a scalar tensor");

  // Configure the kernel
  CutlassGemmConfig config =
      (tactic >= 0) ? getFp4GemmConfig(m, n, k, tactic)
                    : CutlassGemmConfig(CutlassTileConfigSM120::CtaShape128x128x128B,
                                        MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                                        ClusterShape::ClusterShape_1x1x1);

  // Validate output dimensions
  std::vector<int64_t> out_shape =
      (b > 1) ? std::vector<int64_t>{b, m, n} : std::vector<int64_t>{m, n};
  TORCH_CHECK(out.dim() == out_shape.size(), "out must have ", out_shape.size(), " dimensions");
  for (size_t i = 0; i < out_shape.size(); ++i) {
    TORCH_CHECK(out.sizes()[i] == out_shape[i], "out.size(", i, "): expected ", out_shape[i],
                ", got ", out.sizes()[i]);
  }

  c10::ScalarType out_dtype = out.scalar_type();

  switch (out_dtype) {
    case at::ScalarType::Half:
      runGemm<half>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
                    workspace_buffer);
      break;
    case at::ScalarType::BFloat16:
      runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
                             workspace_buffer);
      break;
    default:
      TORCH_CHECK(false, "out_dtype must be one of fp16/bf16.");
  }
  return out;
}

}  // namespace

at::Tensor fp4_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
                    at::Tensor const& mat2Scale, at::Tensor const& globalScale, at::Tensor out,
                    at::Tensor workspace_buffer, int64_t tactic) {
  return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic);
}

int64_t fp4_gemm_tactic_num() {
  static const int64_t totalTactics =
      CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4>{}.getConfigs().size();
  return totalTactics;
}

}  // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("fp4_gemm", &torch_ext::fp4_gemm);
  m.def("fp4_gemm_tactic_num", &torch_ext::fp4_gemm_tactic_num);
}
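
The binding above takes packed FP4 operands as uint8 tensors (two e2m1 values per byte), uint8 block scale factors, a float32 global scale, and a pre-allocated fp16/bf16 output; tactic = -1 selects the default CtaShape128x128x128B config, while 0 <= tactic < fp4_gemm_tactic_num() picks a specific CUTLASS config. A minimal Python sketch of driving the op follows; the module handle `ops` and the scale-factor shapes are assumptions for illustration, not part of this commit.

import torch

def run_fp4_gemm_sm120(ops, m=128, n=256, k=512, device="cuda"):
    """Hedged sketch: `ops` is a hypothetical handle to the built extension."""
    k_packed = k // 2   # two FP4 (e2m1) values are packed into each uint8 byte
    sf_block = 16       # assumed NVFP4 scale-factor block size along k

    # Packed FP4 operands: mat1 is (m, k/2), mat2 is (n, k/2), both uint8.
    mat1 = torch.randint(0, 256, (m, k_packed), dtype=torch.uint8, device=device)
    mat2 = torch.randint(0, 256, (n, k_packed), dtype=torch.uint8, device=device)

    # Block scale factors are raw uint8 buffers; their real layout is produced by
    # the fp4 quantization kernels, so these shapes are placeholders only.
    mat1_scale = torch.zeros(m * (k // sf_block), dtype=torch.uint8, device=device)
    mat2_scale = torch.zeros(n * (k // sf_block), dtype=torch.uint8, device=device)
    global_scale = torch.ones(1, dtype=torch.float32, device=device)

    out = torch.empty(m, n, dtype=torch.bfloat16, device=device)
    workspace = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device)

    # tactic = -1 -> default config; otherwise 0 <= tactic < fp4_gemm_tactic_num().
    assert ops.fp4_gemm_tactic_num() > 0
    ops.fp4_gemm(mat1, mat2, mat1_scale, mat2_scale, global_scale, out, workspace, -1)
    return out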

csrc/fp4_gemm_cutlass_sm120.jinja

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Use SM120-specific dispatch template
#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h"

namespace flashinfer {
namespace gemm {
// SM120/121 only supports 1x1x1 cluster shape
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM)

}  // namespace gemm
}  // namespace flashinfer
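
This template is rendered once per output dtype and CTA tile shape to emit an explicit launcher instantiation with the fixed 1x1x1 cluster. A hedged sketch of rendering it with jinja2 is below; the tile list and output file naming are illustrative and not FlashInfer's actual JIT generator.

import itertools
import pathlib
import jinja2

# Render the SM120 instantiation template for a few illustrative configurations.
template = jinja2.Template(
    pathlib.Path("csrc/fp4_gemm_cutlass_sm120.jinja").read_text()
)

dtypes = ["half", "__nv_bfloat16"]
tiles = [(128, 128, 128), (256, 128, 128)]  # hypothetical CTA (m, n, k) shapes

for dtype, (cta_m, cta_n, cta_k) in itertools.product(dtypes, tiles):
    source = template.render(type=dtype, cta_m=cta_m, cta_n=cta_n, cta_k=cta_k)
    out_file = f"fp4_gemm_sm120_{dtype.strip('_')}_{cta_m}x{cta_n}x{cta_k}.cu"
    pathlib.Path(out_file).write_text(source)  # one .cu per instantiation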

flashinfer/fp4_quantization.py

Lines changed: 23 additions & 10 deletions
@@ -25,9 +25,11 @@
 from .jit import env as jit_env
 from .jit import (
     gen_jit_spec,
+    sm121a_nvcc_flags,
+    sm120a_nvcc_flags,
     sm110a_nvcc_flags,
-    sm100a_nvcc_flags,
     sm103a_nvcc_flags,
+    sm100a_nvcc_flags,
     sm90a_nvcc_flags,
 )
 from .jit.cpp_ext import is_cuda_version_at_least
@@ -86,6 +88,14 @@ def gen_fp4_quantization_sm110_module() -> JitSpec:
     return gen_fp4_quantization_module(sm110a_nvcc_flags, "110")
 
 
+def gen_fp4_quantization_sm120_module() -> JitSpec:
+    return gen_fp4_quantization_module(sm120a_nvcc_flags, "120")
+
+
+def gen_fp4_quantization_sm121_module() -> JitSpec:
+    return gen_fp4_quantization_module(sm121a_nvcc_flags, "121")
+
+
 def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec:
     return gen_jit_spec(
         f"fp4_quantization_{device_arch}",
@@ -119,17 +129,20 @@ def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec:
 
 @functools.cache
 def get_fp4_quantization_module(backend: str = "100"):
-    if backend == "110":
-        module = gen_fp4_quantization_sm110_module().build_and_load()
-    elif backend == "100":
-        module = gen_fp4_quantization_sm100_module().build_and_load()
-    elif backend == "103":
-        module = gen_fp4_quantization_sm103_module().build_and_load()
-    elif backend == "90":
-        module = gen_fp4_quantization_sm90_module().build_and_load()
-    else:
+    backend_modules = {
+        "121": gen_fp4_quantization_sm121_module,
+        "120": gen_fp4_quantization_sm120_module,
+        "110": gen_fp4_quantization_sm110_module,
+        "103": gen_fp4_quantization_sm103_module,
+        "100": gen_fp4_quantization_sm100_module,
+        "90": gen_fp4_quantization_sm90_module,
+    }
+
+    if backend not in backend_modules:
         raise ValueError(f"Invalid backend: {backend}")
 
+    module = backend_modules[backend]().build_and_load()
+
     @register_custom_op(
         "flashinfer::fp4_quantize_sm100",
         mutates_args=(""),
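
The if/elif chain is replaced by a table of backend keys, so adding SM120/SM121 only requires two new dictionary entries. A hedged usage sketch follows; the import path is inferred from the file location and the capability probing is illustrative.

import torch
from flashinfer.fp4_quantization import get_fp4_quantization_module

# Map the device's compute capability to a backend key, e.g. (12, 0) -> "120".
major, minor = torch.cuda.get_device_capability()
backend = f"{major}{minor}"

# Cached by @functools.cache; unknown keys raise ValueError("Invalid backend: ...").
module = get_fp4_quantization_module(backend)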
