Commit 53a31d5

ttyioyzh119 and Zihao Ye authored
feature: add cutlass as bmm_fp8 backend. (#1397)
## πŸ“Œ Description

Add cutlass as bmm_fp8 backend.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Co-authored-by: Zihao Ye <[email protected]>
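For reviewers who want to try the new path, here is a minimal usage sketch. It is not part of this diff: it assumes an SM100-class GPU and a FlashInfer build that includes this commit, and it uses unit scales as placeholders for real per-tensor quantization scales.

```python
import torch
import flashinfer

# A is (b, m, k) row-major, B is (b, k, n) column-major (a transposed view of a
# contiguous (b, n, k) tensor). Both must be float8_e4m3fn; the cutlass backend
# rejects float8_e5m2 inputs.
A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = (
    torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16)
    .to(torch.float8_e4m3fn)
    .transpose(-2, -1)
)
A_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)  # placeholder scale
B_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)  # placeholder scale

out = flashinfer.bmm_fp8(A, B, A_scale, B_scale, torch.bfloat16, backend="cutlass")
print(out.shape)  # torch.Size([16, 48, 80])
```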
1 parent ef11d2a commit 53a31d5

File tree

7 files changed, +854 -8 lines changed


csrc/fp8_gemm_cutlass.cu

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/cuda/EmptyTensor.h>
#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

#include "flashinfer/gemm/cutlass_gemm_configs.h"
#include "flashinfer/gemm/fp8_gemm_cutlass.h"
#include "flashinfer/gemm/fp8_gemm_cutlass_template.h"
#include "pytorch_extension_utils.h"

using flashinfer::gemm::ClusterShape;
using flashinfer::gemm::CutlassFp8GemmRunner;
using flashinfer::gemm::CutlassFp8GemmRunnerInterface;
using flashinfer::gemm::CutlassGemmConfig;
using flashinfer::gemm::CutlassTileConfigSM100;
using flashinfer::gemm::EpilogueScheduleType;
using flashinfer::gemm::MainloopScheduleType;

namespace flashinfer {
namespace gemm {
template class CutlassFp8GemmRunner<__nv_bfloat16>;
template class CutlassFp8GemmRunner<half>;
}  // namespace gemm
}  // namespace flashinfer

namespace torch_ext {

namespace {

CutlassGemmConfig getFp8GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
  auto getCutlassFp8GemmConfigs = []() {
    CutlassFp8GemmRunner<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassFp8GemmConfigs();
  TORCH_CHECK(tactic >= 0 && tactic < globalConfigs.size(), "tactic must be between 0 and ",
              globalConfigs.size());
  return globalConfigs[tactic];
}

template <typename T>
void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
             at::Tensor const& scale, int64_t m, int64_t n, int64_t k, int64_t b,
             CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
  CutlassFp8GemmRunner<T> gemmRunner;

  int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
  int64_t const provided_workspace_size =
      workspace_buffer.numel() * workspace_buffer.element_size();

  auto runKernel = [&](void* workspace) {
    gemmRunner.gemm(reinterpret_cast<__nv_fp8_e4m3 const*>(mat1.const_data_ptr()),
                    reinterpret_cast<__nv_fp8_e4m3 const*>(mat2.const_data_ptr()),
                    reinterpret_cast<float const*>(scale.const_data_ptr()), out.data_ptr(), m, n, k,
                    b, gemmConfig, reinterpret_cast<char*>(workspace), required_workspace_size,
                    at::cuda::getCurrentCUDAStream(mat1.get_device()));
  };

  if (provided_workspace_size < required_workspace_size) {
    at::Tensor new_workspace = at::detail::empty_cuda(
        {required_workspace_size}, at::ScalarType::Char, mat1.device(), std::nullopt);

    runKernel(new_workspace.data_ptr());
  } else {
    runKernel(workspace_buffer.data_ptr());
  }
}

at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale,
                        at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) {
  CHECK_INPUT(mat1);
  CHECK_INPUT(mat2);

  int mat2_k_scale = 1;

  int64_t m, n, k, b;
  if (mat1.dim() == 2) {
    TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
    TORCH_CHECK(mat1.sizes()[1] == mat2.sizes()[1] * mat2_k_scale,
                "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x",
                mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
    m = mat1.sizes()[0];
    n = mat2.sizes()[0];
    k = mat2.sizes()[1];
    b = 1;
  } else if (mat1.dim() == 3) {
    TORCH_CHECK(mat2.dim() == 3, "mat2 must be a batch of matrices");
    TORCH_CHECK(mat1.sizes()[0] == mat2.sizes()[0], "mat1 and mat2 must have the same batch size (",
                mat1.sizes()[0], " and ", mat2.sizes()[0], ")");
    TORCH_CHECK(mat1.sizes()[2] == mat2.sizes()[2] * mat2_k_scale,
                "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[1], "x",
                mat1.sizes()[2], " and ", mat2.sizes()[1], "x", mat2.sizes()[2], ")");
    m = mat1.sizes()[1];
    n = mat2.sizes()[1];
    k = mat2.sizes()[2];
    b = mat1.sizes()[0];
  } else {
    C10_THROW_ERROR(NotImplementedError, "mat1 must be a matrix or a batch of matrices");
  }

  // No heuristic for now, we rely on the autotuner to select the best tactic.
  if (tactic == -1) {
    tactic = 0;
  }
  auto config = getFp8GemmConfig(m, n, k, tactic);

  // Validate out dimensions
  std::vector<int64_t> out_shape =
      mat1.dim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
  TORCH_CHECK(out.dim() == out_shape.size(), "out must have ", out_shape.size(),
              " dimensions, but got ", out.dim());
  for (int i = 0; i < out_shape.size(); ++i) {
    TORCH_CHECK(out.sizes()[i] == out_shape[i], "out shape mismatch at dimension ", i,
                ": expected ", out_shape[i], ", got ", out.sizes()[i]);
  }

  switch (out.scalar_type()) {
    case at::ScalarType::Half:
      runGemm<half>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
      break;
    case at::ScalarType::BFloat16:
      runGemm<__nv_bfloat16>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
      break;
    default:
      TORCH_CHECK(false, "out_dtype must be one of fp16/bf16.");
  }
  return out;
}

}  // namespace

at::Tensor fp8_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale,
                    at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) {
  return fp8_bmm_impl(mat1, mat2, scale, out, workspace_buffer, tactic);
}

int64_t fp8_gemm_tactic_num() {
  auto getCutlassConfigs = []() {
    CutlassFp8GemmRunner<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static int64_t totalTactics = getCutlassConfigs().size();
  return totalTactics;
}

}  // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("fp8_gemm", &torch_ext::fp8_gemm);
  m.def("fp8_gemm_tactic_num", &torch_ext::fp8_gemm_tactic_num);
}

csrc/fp8_gemm_cutlass.jinja

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flashinfer/gemm/fp8_gemm_template_sm100.h"

namespace flashinfer {
namespace gemm {
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
}  // namespace gemm
}  // namespace flashinfer

flashinfer/gemm.py

Lines changed: 132 additions & 2 deletions
@@ -219,6 +219,53 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec:
     )
 
 
+def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec:
+    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_fp8"
+    os.makedirs(gen_directory, exist_ok=True)
+    source_paths = [
+        jit_env.FLASHINFER_CSRC_DIR / "fp8_gemm_cutlass.cu",
+    ]
+
+    with open(jit_env.FLASHINFER_CSRC_DIR / "fp8_gemm_cutlass.jinja") as f:
+        kernel_inst_templ = jinja2.Template(f.read())
+        dtype_list = ["__nv_bfloat16", "half"]
+        cta_m_n_k_list = [
+            (64, 64, 128),
+            (64, 128, 128),
+            (64, 256, 128),
+            (128, 64, 128),
+            (128, 128, 128),
+            (128, 256, 128),
+        ]
+        for cta_m, cta_n, cta_k in cta_m_n_k_list:
+            for dtype in dtype_list:
+                dest_path = (
+                    gen_directory
+                    / f"fp8_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu"
+                )
+                source_paths.append(dest_path)
+                source = kernel_inst_templ.render(
+                    type=dtype,
+                    cta_m=cta_m,
+                    cta_n=cta_n,
+                    cta_k=cta_k,
+                )
+                write_if_different(dest_path, source)
+
+    return gen_jit_spec(
+        "fp8_gemm_cutlass",
+        source_paths,
+        extra_cuda_cflags=sm100a_nvcc_flags
+        + [
+            "-DENABLE_BF16",
+        ],
+        extra_cflags=[
+            "-DFAST_BUILD",
+        ],
+        extra_ldflags=["-lcuda"],
+    )
+
+
 def gen_gemm_sm100_module() -> JitSpec:
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100"
     os.makedirs(gen_directory, exist_ok=True)
@@ -337,6 +384,82 @@ def get_trtllm_gemm_module():
     return op
 
 
+@functools.cache
+def get_gemm_sm100_module_cutlass_fp8():
+    module = gen_gemm_sm100_module_cutlass_fp8().build_and_load()
+
+    class CutlassFp8GemmRunner(TunableRunner):
+        def __init__(self):
+            self._fp8_gemm_runner = module.fp8_gemm
+
+        def get_valid_tactics(
+            self,
+            inputs: List[torch.Tensor],
+            profile: OptimizationProfile,
+        ) -> List[int]:
+            return list(range(module.fp8_gemm_tactic_num()))
+
+        def forward(
+            self,
+            inputs: List[torch.Tensor],
+            *,
+            tactic: int = -1,
+            do_preparation: bool = False,
+        ):
+            a, b, alpha, out, workspace_buffer = inputs
+            module.fp8_gemm.default(a, b, alpha, out, workspace_buffer, tactic)
+            return out
+
+    @register_custom_op(
+        "flashinfer::cutlass_fp8_gemm",
+        mutates_args=(""),
+    )
+    def cutlass_fp8_gemm(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        alpha: torch.Tensor,
+        out: torch.Tensor,
+        workspace_buffer: torch.Tensor,
+    ):
+        tuner = AutoTuner.get()
+
+        a_tensor_index = 0
+        out_tensor_index = 3
+
+        tuning_config = TuningConfig(
+            dynamic_tensor_specs=(
+                DynamicTensorSpec(
+                    a_tensor_index,
+                    -2,
+                    get_last_power_of_2_num_tokens_buckets,
+                    last_positive_power_of_2,
+                ),
+            ),
+            constraint_specs=(
+                ConstraintSpec(
+                    out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
+                ),
+            ),
+        )
+
+        fp8_runner = CutlassFp8GemmRunner()
+
+        inputs = [a, b, alpha, out, workspace_buffer]
+        _, tactic = tuner.choose_one(
+            "cutlass_fp8_gemm",
+            [fp8_runner],
+            tuning_config,
+            inputs,
+        )
+
+        fp8_runner(inputs=inputs, tactic=tactic)
+
+    # Register the module
+    return SimpleNamespace(
+        cutlass_fp8_gemm=cutlass_fp8_gemm,
+    )
+
+
 @functools.cache
 def get_gemm_sm100_module_cutlass_fp4():
     module = gen_gemm_sm100_module_cutlass_fp4().build_and_load()
@@ -1520,7 +1643,7 @@ def bmm_fp8(
     B_scale: torch.Tensor,
     dtype: torch.dtype,
     out: Optional[torch.Tensor] = None,
-    backend: Literal["cudnn", "cublas"] = "cublas",
+    backend: Literal["cudnn", "cublas", "cutlass"] = "cublas",
 ) -> torch.Tensor:
     r"""BMM FP8
 
@@ -1544,7 +1667,7 @@ def bmm_fp8(
     out: Optional[torch.Tensor]
         Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.
 
-    backend: Literal["cudnn", "cublas"]
+    backend: Literal["cudnn", "cublas", "cutlass"]
         The backend to use for the operation. Defaults to ``"cublas"``.
 
     Returns
@@ -1592,6 +1715,13 @@ def bmm_fp8(
         return _cudnn_gemm_fp8(workspace_buffer, A, B, A_scale, B_scale, out, dtype)
     elif backend == "cublas":
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
+    elif backend == "cutlass":
+        if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
+            raise ValueError("e5m2 is not supported for cutlass backend")
+
+        get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm(
+            A, B.transpose(-2, -1), A_scale * B_scale, out, workspace_buffer
+        )
     return out
 
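Since the cutlass branch folds `A_scale * B_scale` into a single alpha and hands the kernel a row-major view of `B`, its output should agree with the existing cuBLAS backend up to fp8 rounding. A hedged cross-check sketch (the `to_float8` helper and the tolerances are illustrative, not part of this commit):

```python
import torch
import flashinfer

def to_float8(x, dtype=torch.float8_e4m3fn):
    # Simple per-tensor quantization: scale into the fp8 range and return the
    # fp8 tensor plus the inverse scale that bmm_fp8 expects.
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().amax().clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()

A_fp8, A_scale = to_float8(torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16))
B_fp8, B_scale = to_float8(torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16))
B_col = B_fp8.transpose(-2, -1)  # (b, k, n) column-major view, as bmm_fp8 expects

out_cublas = flashinfer.bmm_fp8(A_fp8, B_col, A_scale, B_scale, torch.bfloat16, backend="cublas")
out_cutlass = flashinfer.bmm_fp8(A_fp8, B_col, A_scale, B_scale, torch.bfloat16, backend="cutlass")
torch.testing.assert_close(out_cublas.float(), out_cutlass.float(), rtol=1e-1, atol=1e-1)
```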