
Commit 82f505a

[BENCHMARK] Reuse CUTLASS's gemm configuration file (#4720)
# Description

This PR introduces a new mechanism for fetching GEMM configurations. Instead of hardcoding the `(shape → config)` mapping, the `config-tool.py` script now parses a configuration file and generates the `gemm_config` structure dynamically.

The configuration file consists of a list of GEMM kernel invocations with the corresponding `GemmConfig`. These will be extracted and used to invoke the kernel with the appropriate configuration.

## Note

Currently, the configuration file is located in `benchmarks/cutlass_kernel/gemm`. In the future, this should be updated to fetch the file directly from the CUTLASS repository: https://github.com/intel/cutlass-sycl

This change will be made once the CUTLASS repo includes a unified file containing the optimal configurations for all shapes used in the Triton benchmark.

---------

Signed-off-by: Lukas Sommer <[email protected]>
Signed-off-by: Jefferson Le Quellec <[email protected]>
Co-authored-by: Lukas Sommer <[email protected]>
1 parent ac1b911 commit 82f505a
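
To make the new flow concrete, here is a minimal Python sketch of what `config-tool.py` (added below) does for a single configuration line; the sample line is taken verbatim from `gemm/input_gemm.in`, and the printed entry is the corresponding row of the generated `gemm_config` table.

```python
import re

# Same regex that config-tool.py uses to parse a kernel invocation line.
pattern = re.compile(
    r'^(?P<name>\S+).*?--l=(?P<l>\d+)\s+--m=(?P<m>\d+)\s+--k=(?P<k>\d+)\s+--n=(?P<n>\d+)')

# One line from benchmarks/cutlass_kernel/gemm/input_gemm.in.
line = 'PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192'

match = pattern.match(line)
l, m, k, n = (int(match.group(g)) for g in 'lmkn')

# The generated header keys entries as (l, m, n, k); this prints the row
# config-tool.py would emit for this line:
# { { 1, 512, 8192, 8192 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_1> },
print(f'{{ {{ {l}, {m}, {n}, {k} }}, &gemm_run<{match.group("name")}> }},')
```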

7 files changed: +140 −97 lines changed

benchmarks/cmake/FindCUTLASSLibrary.cmake

Lines changed: 2 additions & 0 deletions
```diff
@@ -28,6 +28,8 @@ if (NOT CUTLASSLibrary_FOUND)
   set(CUTLASSLibrary_INCLUDE_DIR "${CUTLASSLibrary_SOURCE_DIR}/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
   set(CUTLASSLibrary_INCLUDE_TOOL_DIR "${CUTLASSLibrary_SOURCE_DIR}/tools/util/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
   set(CUTLASSLibrary_INCLUDE_APPLICATION_DIR "${CUTLASSLibrary_SOURCE_DIR}/applications" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
+  set(CUTLASSLibrary_INCLUDE_BENCHMARK_DIR "${CUTLASSLibrary_SOURCE_DIR}/benchmarks" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
+  set(CUTLASSLibrary_BENCHMARK_CONFIG_DIR "${CUTLASSLibrary_SOURCE_DIR}/benchmarks/device/pvc/input_files" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
 
   find_package_handle_standard_args(
     CUTLASSLibrary
```

benchmarks/cutlass_kernel/CMakeLists.txt

Lines changed: 30 additions & 1 deletion
```diff
@@ -7,16 +7,45 @@ set(CUTLASS_KERNEL_FLAGS ${CUTLASS_KERNEL_FLAGS}
   -Xs "-options \"-igc_opts 'VISAOptions=-perfmodel,VectorAliasBBThreshold=1000,ExtraOCLOptions=-cl-intel-256-GRF-per-thread'\" -options -ze-opt-large-register-file"
 )
 
+# Path to the configuration tool
+set(CONFIG_TOOL ${CMAKE_CURRENT_SOURCE_DIR}/config-tool.py)
+
+# Input and output files
+# The name of this file must be kept in sync with the best known CUTLASS config.
+# TODO: Re-enable gemm config input to come from `CUTLASSLibrary_BENCHMARK_CONFIG_DIR`
+# set(GEMM_CONFIG_INPUT ${CUTLASSLibrary_BENCHMARK_CONFIG_DIR}/input_gemm.in)
+set(GEMM_CONFIG_INPUT ${CMAKE_CURRENT_SOURCE_DIR}/gemm/input_gemm.in)
+set(GEMM_CONFIG_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/gemm_table.hpp)
+set(GEMM_CONFIG_NAME "gemm_config")
+
+# Use a custom command to generate a C++ header with the configuration table
+# from the CUTLASS benchmark configuration.
+add_custom_command(
+  OUTPUT ${GEMM_CONFIG_OUTPUT}
+  COMMAND ${CMAKE_COMMAND} -E echo "Generating GEMM config header..."
+  COMMAND ${Python3_EXECUTABLE} ${CONFIG_TOOL} ${GEMM_CONFIG_INPUT} -o ${GEMM_CONFIG_OUTPUT} --name ${GEMM_CONFIG_NAME}
+  DEPENDS ${GEMM_CONFIG_INPUT} ${CONFIG_TOOL}
+  COMMENT "Generate GEMM configuration"
+  VERBATIM
+)
+
+# Create a target that other targets can depend on
+add_custom_target(generate_gemm_config DEPENDS ${GEMM_CONFIG_OUTPUT})
+
 Python3_add_library(cutlass_kernel MODULE WITH_SOABI python_main.cpp)
 
 target_compile_options(cutlass_kernel PRIVATE "-fsycl" "-fsycl-targets=intel_gpu_pvc,intel_gpu_bmg_g21" "-fpreview-breaking-changes")
 target_compile_options(cutlass_kernel PRIVATE "-DCUTLASS_ENABLE_SYCL")
 target_compile_options(cutlass_kernel PRIVATE "-DSYCL_INTEL_TARGET")
+target_compile_definitions(cutlass_kernel PRIVATE GEMM_CONFIG_HEADER=\"${GEMM_CONFIG_OUTPUT}\")
+target_compile_definitions(cutlass_kernel PRIVATE GEMM_CONFIG_NAME=\"${GEMM_CONFIG_NAME}\")
 
 target_link_options(cutlass_kernel PRIVATE ${CUTLASS_KERNEL_FLAGS})
 target_link_libraries(cutlass_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
 
-target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}" "${CUTLASSLibrary_INCLUDE_APPLICATION_DIR}")
+target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}" "${CUTLASSLibrary_INCLUDE_APPLICATION_DIR}" "${CUTLASSLibrary_INCLUDE_BENCHMARK_DIR}")
+
+add_dependencies(cutlass_kernel generate_gemm_config)
 
 add_subdirectory(gemm)
 add_subdirectory(attention)
```
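
For reference, the custom command above reduces to a single tool invocation at build time; the Python sketch below mimics it, with hypothetical concrete paths standing in for `${GEMM_CONFIG_INPUT}`, `${GEMM_CONFIG_OUTPUT}`, and `${GEMM_CONFIG_NAME}`:

```python
import subprocess

# Hypothetical stand-ins for the CMake variables; the real values are
# configured in benchmarks/cutlass_kernel/CMakeLists.txt above. Run from
# the benchmarks/cutlass_kernel directory.
subprocess.run(
    ['python3', 'config-tool.py', 'gemm/input_gemm.in',
     '-o', 'build/gemm_table.hpp', '--name', 'gemm_config'],
    check=True)
```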

benchmarks/cutlass_kernel/attention/attention.hpp

Lines changed: 5 additions & 4 deletions
```diff
@@ -204,10 +204,11 @@ using FARunPtr = int (*)(const at::Tensor &Q, const at::Tensor &K,
                          int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
                          float sm_scale);
 
-auto attention(const at::Tensor &Q, const at::Tensor &K, const at::Tensor &V,
-               at::Tensor &O, int Batch, int NumHeadsQ, int NumHeadsKV,
-               int SeqLengthQO, int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
-               bool Causal, float sm_scale) -> int {
+auto attention_kernel(const at::Tensor &Q, const at::Tensor &K,
+                      const at::Tensor &V, at::Tensor &O, int Batch,
+                      int NumHeadsQ, int NumHeadsKV, int SeqLengthQO,
+                      int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
+                      bool Causal, float sm_scale) -> int {
   constexpr int PipelineStages = 2;
   FARunPtr f = nullptr;
```

benchmarks/cutlass_kernel/config-tool.py

Lines changed: 54 additions & 0 deletions

```diff
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+
+import argparse
+import re
+import sys
+
+
+def build_config_map(file_paths):
+    config_map = {}
+    pattern = re.compile(r'^(?P<name>\S+).*?--l=(?P<l>\d+)\s+--m=(?P<m>\d+)\s+--k=(?P<k>\d+)\s+--n=(?P<n>\d+)')
+
+    for path in file_paths:
+        try:
+            with open(path, 'r', encoding='utf-8') as f:
+                for line in f:
+                    match = pattern.match(line.strip())
+                    if match:
+                        name = match.group('name')
+                        l = int(match.group('l'))
+                        m = int(match.group('m'))
+                        k = int(match.group('k'))
+                        n = int(match.group('n'))
+                        config_map[(l, m, n, k)] = name
+        except IOError as e:
+            print(f'Error reading {path}: {e}', file=sys.stderr)
+
+    return config_map
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Parse GEMM benchmark files and generate C++ table.')
+    parser.add_argument('-o', '--output', required=True, help='Output file path')
+    parser.add_argument('--name', required=True, help='Name identifier for logging or grouping')
+    parser.add_argument('inputs', nargs='+', help='Input file(s) with GEMM benchmark data')
+
+    args = parser.parse_args()
+
+    config_map = build_config_map(args.inputs)
+
+    try:
+        with open(args.output, 'w', encoding='utf-8') as outfile:
+            outfile.write('// This file was auto-generated, do not edit!\n\n')
+            outfile.write(
+                f'static constexpr std::array<std::pair<Dim, GemmRunPtr>, {len(config_map)}> {args.name} = {{{{\n')
+            for (l, m, n, k), name in config_map.items():
+                outfile.write(f'{{ {{ {l}, {m}, {n}, {k} }}, &gemm_run<{name}> }},\n')
+            outfile.write('}};\n')
+    except IOError as e:
+        print(f'Error writing output file: {e}')
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    main()
```
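
A quick way to exercise the tool end to end — a sketch that writes one configuration line to a temporary input file, runs `config-tool.py` on it, and prints the generated header; the expected output in the trailing comment follows from the write calls above (run from `benchmarks/cutlass_kernel`):

```python
import pathlib
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    cfg = pathlib.Path(tmp, 'input_gemm.in')
    out = pathlib.Path(tmp, 'gemm_table.hpp')
    cfg.write_text('PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 '
                   '--l=1 --m=4096 --k=4096 --n=4096\n')
    subprocess.run(['python3', 'config-tool.py', str(cfg),
                    '-o', str(out), '--name', 'gemm_config'], check=True)
    print(out.read_text())

# Expected output:
# // This file was auto-generated, do not edit!
#
# static constexpr std::array<std::pair<Dim, GemmRunPtr>, 1> gemm_config = {{
# { { 1, 4096, 4096, 4096 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_1> },
# }};
```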

benchmarks/cutlass_kernel/gemm/gemm.hpp

Lines changed: 26 additions & 90 deletions
```diff
@@ -9,69 +9,27 @@
 #include <exception>
 #include <iostream>
 
+#define CUTLASS_CREATE_GEMM_BENCHMARK(x)
+#define CUTLASS_BENCHMARK(x)
+#include "gemm/benchmarks_sycl.hpp"
+#include "gemm/gemm_configuration_sycl.hpp"
+
 ////////////////////////////////////////////////////////////////////////////////
 // PRIVATE FUNCTION
 ////////////////////////////////////////////////////////////////////////////////
 
-template <typename TileShape>
+template <typename GemmConfig>
 static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
                      const int M, const int N, const int K, const int L)
     -> int {
   RECORD_FUNCTION("cutlass gemm", {});
 
-  using ElementAccumulator = float;
   using ElementComputeEpilogue = float;
-  using ElementInputA = cutlass::bfloat16_t;
-  using ElementInputB = cutlass::bfloat16_t;
-  using ElementOutput = float;
-
-  using LayoutA = typename cutlass::layout::RowMajor;
-  using LayoutB = typename cutlass::layout::RowMajor;
-  using LayoutC = typename cutlass::layout::RowMajor;
-  using LayoutD = typename cutlass::layout::RowMajor;
-
-  constexpr int AlignmentA = sizeof(ElementInputA);
-  constexpr int AlignmentB = sizeof(ElementInputB);
-  constexpr int AlignmentC = sizeof(ElementAccumulator);
-  constexpr int AlignmentD = sizeof(ElementOutput);
-
-  /// MAIN LOOP ///
-
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, ElementInputA,
-          LayoutA, AlignmentA, ElementInputB, LayoutB, AlignmentB,
-          ElementAccumulator, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::gemm::collective::StageCountAuto,
-          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
-
-  /// EPILOGUE LOOP ///
-
-  using EpilogueOp = typename cutlass::epilogue::fusion::LinCombEltAct<
-      cutlass::epilogue::thread::ReLu, ElementOutput, ElementComputeEpilogue,
-      ElementAccumulator, ElementAccumulator,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementComputeEpilogue, ElementAccumulator, ElementAccumulator,
-          LayoutC, AlignmentC, ElementOutput, LayoutD, AlignmentD,
-          cutlass::epilogue::collective::EpilogueScheduleAuto,
-          EpilogueOp>::CollectiveOp;
-
-  /// GEMM ///
-
-  using GemmKernel = typename cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>;
 
   /// GEMM INVOCATION ///
 
   try {
-    using Gemm =
-        typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+    using Gemm = GemmConfig::Gemm;
     typename Gemm::Arguments arguments;
 
     /// Buffer Initialization
@@ -107,15 +65,17 @@ static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
                 "Query result for SM count per device: " << hw_info.sm_count);
     }
 
-    arguments = {cutlass::gemm::GemmUniversalMode::kGemm,
-                 problem_size,
-                 {_A, stride_A, _B, stride_B},
-                 {{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
-                  nullptr,
-                  stride_C,
-                  _C,
-                  stride_D},
-                 hw_info};
+    arguments = GemmConfig::defaultArguments();
+    arguments.mode = cutlass::gemm::GemmUniversalMode::kGemm;
+    arguments.problem_shape = problem_size;
+    arguments.mainloop = {_A, stride_A, _B, stride_B};
+    arguments.epilogue = {
+        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+        nullptr,
+        stride_C,
+        _C,
+        stride_D};
+    arguments.hw_info = hw_info;
 
     Gemm gemm_op;
@@ -148,43 +108,19 @@ using GemmRunPtr = int (*)(const at::Tensor &A, const at::Tensor &B,
                            at::Tensor &C, const int M, const int N, const int K,
                            const int L);
 
-/// Each entry associates a specific problem dimension to their corresponding
-/// tile shape. For more details, see:
-/// https://github.com/codeplaysoftware/cutlass-sycl/tree/sycl-develop/benchmarks
-
-// clang-format off
-static constexpr std::array<std::pair<Dim, GemmRunPtr>, 18> tile_map = {{
-    { { 1, 1024, 8192, 28672 }, &gemm_run<cute::Shape<cute::_128, cute::_512, cute::_32>> },
-    { { 32, 4096, 128, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_128, cute::_32>> },
-    { { 4096, 8, 16384, 128 }, &gemm_run<cute::Shape<cute::_128, cute::_256, cute::_16>> },
-    { { 4096, 8, 128, 16384 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-    { { 1, 1, 4096, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 6144, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-    { { 1, 1, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-    { { 1, 1, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-    { { 1, 1, 4096, 14336 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 8, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-    { { 1, 8, 4096, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-    { { 1, 8, 6144, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-    { { 1, 8, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-    { { 1, 8, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-    { { 1, 8, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-    { { 1, 8, 4096, 14336 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-}};
-// clang-format on
-
-auto gemm(const at::Tensor &A, const at::Tensor &B, at::Tensor &C, const int M,
-          const int N, const int K, const int L) -> int {
+// Includes the table mapping problem shape to best config from the header
+// generated by the configuration tool from the CUTLASS config file.
+#include GEMM_CONFIG_HEADER
+
+auto gemm_kernel(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
+                 const int M, const int N, const int K, const int L) -> int {
   const Dim test_case{L, M, N, K};
 
-  for (auto const &kv : tile_map) {
+  for (auto const &kv : gemm_config) {
     if (test_case == kv.first) {
       return kv.second(A, B, C, M, N, K, L);
     }
   }
 
-  return gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>>(A, B, C, M, N,
-                                                                  K, L);
+  return gemm_run<PvcGemmBF16BF16FP32_RRR_1>(A, B, C, M, N, K, L);
 }
```
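
The lookup in `gemm_kernel` is an exact match on `Dim{L, M, N, K}`, falling back to `PvcGemmBF16BF16FP32_RRR_1` when no entry matches. A minimal Python model of that dispatch — the two entries are real rows from the generated table, but the dict form is purely illustrative:

```python
# Model of gemm_kernel's dispatch: exact match on (L, M, N, K), with
# PvcGemmBF16BF16FP32_RRR_1 as the default when no entry matches.
gemm_config = {
    (1, 4096, 4096, 4096): 'PvcGemmBF16BF16FP32_RRR_1',
    (1, 1024, 8192, 28672): 'PvcGemmBF16BF16FP32_RRR_2',  # --l=1 --m=1024 --k=28672 --n=8192
}

def gemm_kernel(L, M, N, K):
    return gemm_config.get((L, M, N, K), 'PvcGemmBF16BF16FP32_RRR_1')

assert gemm_kernel(1, 1024, 8192, 28672) == 'PvcGemmBF16BF16FP32_RRR_2'
assert gemm_kernel(2, 2, 2, 2) == 'PvcGemmBF16BF16FP32_RRR_1'  # fallback
```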
benchmarks/cutlass_kernel/gemm/input_gemm.in

Lines changed: 21 additions & 0 deletions

```diff
@@ -0,0 +1,21 @@
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
+PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
+PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
```

benchmarks/cutlass_kernel/python_main.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,6 +24,6 @@
 ////////////////////////////////////////////////////////////////////////////////
 
 PYBIND11_MODULE(cutlass_kernel, m) {
-  m.def("gemm", &gemm, "gemm (CUTLASS)");
-  m.def("attention", &attention, "attention (CUTLASS)");
+  m.def("gemm", &gemm_kernel, "gemm (CUTLASS)");
+  m.def("attention", &attention_kernel, "attention (CUTLASS)");
 }
```
