 #include <exception>
 #include <iostream>

+#define CUTLASS_CREATE_GEMM_BENCHMARK(x)
+#define CUTLASS_BENCHMARK(x)
+#include "gemm/benchmarks_sycl.hpp"
+#include "gemm/gemm_configuration_sycl.hpp"
+
 ////////////////////////////////////////////////////////////////////////////////
 // PRIVATE FUNCTION
 ////////////////////////////////////////////////////////////////////////////////

-template <typename TileShape>
+template <typename GemmConfig>
 static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
                      const int M, const int N, const int K, const int L)
     -> int {
   RECORD_FUNCTION("cutlass gemm", {});

-  using ElementAccumulator = float;
   using ElementComputeEpilogue = float;
-  using ElementInputA = cutlass::bfloat16_t;
-  using ElementInputB = cutlass::bfloat16_t;
-  using ElementOutput = float;
-
-  using LayoutA = typename cutlass::layout::RowMajor;
-  using LayoutB = typename cutlass::layout::RowMajor;
-  using LayoutC = typename cutlass::layout::RowMajor;
-  using LayoutD = typename cutlass::layout::RowMajor;
-
-  constexpr int AlignmentA = sizeof(ElementInputA);
-  constexpr int AlignmentB = sizeof(ElementInputB);
-  constexpr int AlignmentC = sizeof(ElementAccumulator);
-  constexpr int AlignmentD = sizeof(ElementOutput);
-
-  /// MAIN LOOP ///
-
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, ElementInputA,
-          LayoutA, AlignmentA, ElementInputB, LayoutB, AlignmentB,
-          ElementAccumulator, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::gemm::collective::StageCountAuto,
-          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
-
-  /// EPILOGUE LOOP ///
-
-  using EpilogueOp = typename cutlass::epilogue::fusion::LinCombEltAct<
-      cutlass::epilogue::thread::ReLu, ElementOutput, ElementComputeEpilogue,
-      ElementAccumulator, ElementAccumulator,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementComputeEpilogue, ElementAccumulator, ElementAccumulator,
-          LayoutC, AlignmentC, ElementOutput, LayoutD, AlignmentD,
-          cutlass::epilogue::collective::EpilogueScheduleAuto,
-          EpilogueOp>::CollectiveOp;
-
-  /// GEMM ///
-
-  using GemmKernel = typename cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>;

   /// GEMM INVOCATION ///

   try {
-    using Gemm =
-        typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+    using Gemm = GemmConfig::Gemm;
     typename Gemm::Arguments arguments;

     /// Buffer Initialization

@@ -107,15 +65,17 @@ static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
           "Query result for SM count per device: " << hw_info.sm_count);
     }

-    arguments = {cutlass::gemm::GemmUniversalMode::kGemm,
-                 problem_size,
-                 {_A, stride_A, _B, stride_B},
-                 {{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
-                  nullptr,
-                  stride_C,
-                  _C,
-                  stride_D},
-                 hw_info};
+    arguments = GemmConfig::defaultArguments();
+    arguments.mode = cutlass::gemm::GemmUniversalMode::kGemm;
+    arguments.problem_shape = problem_size;
+    arguments.mainloop = {_A, stride_A, _B, stride_B};
+    arguments.epilogue = {
+        {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
+        nullptr,
+        stride_C,
+        _C,
+        stride_D};
+    arguments.hw_info = hw_info;

     Gemm gemm_op;

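For reference, the refactored gemm_run relies on exactly two members of its GemmConfig parameter: a nested Gemm adapter type and a defaultArguments() factory whose result is then overridden field by field. Below is a minimal sketch of that interface, assuming the real definitions come from the cutlass-sycl benchmark headers included above; the struct and template parameter names here are illustrative, not the actual header contents.

// Illustrative only: a hypothetical config struct matching the interface that
// gemm_run<GemmConfig> uses. Real configurations are provided by
// gemm/benchmarks_sycl.hpp and gemm/gemm_configuration_sycl.hpp.
template <typename GemmKernel>
struct ExampleGemmConfig {
  // Device-level adapter that gemm_run instantiates as GemmConfig::Gemm.
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Arguments pre-populated with this configuration's defaults; gemm_run then
  // overwrites mode, problem_shape, mainloop, epilogue, and hw_info.
  static typename Gemm::Arguments defaultArguments() {
    return typename Gemm::Arguments{};
  }
};

Keeping this interface small is what allows the hand-rolled CollectiveBuilder block above to be deleted: tile shape and schedule choices now live entirely in the named configurations.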
@@ -148,43 +108,19 @@ using GemmRunPtr = int (*)(const at::Tensor &A, const at::Tensor &B,
                            at::Tensor &C, const int M, const int N, const int K,
                            const int L);

-/// Each entry associates a specific problem dimension to their corresponding
-/// tile shape. For more details, see:
-/// https://github.com/codeplaysoftware/cutlass-sycl/tree/sycl-develop/benchmarks
-
-// clang-format off
-static constexpr std::array<std::pair<Dim, GemmRunPtr>, 18> tile_map = {{
-    { { 1, 1024, 8192, 28672 }, &gemm_run<cute::Shape<cute::_128, cute::_512, cute::_32>> },
-    { { 32, 4096, 128, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_128, cute::_32>> },
-    { { 4096, 8, 16384, 128 }, &gemm_run<cute::Shape<cute::_128, cute::_256, cute::_16>> },
-    { { 4096, 8, 128, 16384 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-    { { 1, 1, 4096, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 6144, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 1, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-    { { 1, 1, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-    { { 1, 1, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-    { { 1, 1, 4096, 14336 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-    { { 1, 8, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-    { { 1, 8, 4096, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-    { { 1, 8, 6144, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-    { { 1, 8, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-    { { 1, 8, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-    { { 1, 8, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-    { { 1, 8, 4096, 14336 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-}};
-// clang-format on
-
-auto gemm(const at::Tensor &A, const at::Tensor &B, at::Tensor &C, const int M,
-          const int N, const int K, const int L) -> int {
+// Includes the table mapping problem shape to best config from the header
+// generated by the configuration tool from the CUTLASS config file.
+#include GEMM_CONFIG_HEADER
+
+auto gemm_kernel(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
+                 const int M, const int N, const int K, const int L) -> int {
   const Dim test_case{L, M, N, K};

-  for (auto const &kv : tile_map) {
+  for (auto const &kv : gemm_config) {
     if (test_case == kv.first) {
       return kv.second(A, B, C, M, N, K, L);
     }
   }

-  return gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>>(A, B, C, M, N,
-                                                                  K, L);
+  return gemm_run<PvcGemmBF16BF16FP32_RRR_1>(A, B, C, M, N, K, L);
 }
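The generated header itself is not part of this diff, but the lookup loop above and the tile_map it replaces suggest its shape: a gemm_config array of (Dim, GemmRunPtr) pairs whose entries now name benchmark configurations rather than raw tile shapes. A hypothetical sketch follows; the entry count, problem shapes, and every configuration name other than PvcGemmBF16BF16FP32_RRR_1 are illustrative only.

// Sketch of a possible generated GEMM_CONFIG_HEADER. The real file is emitted
// by the configuration tool and its contents depend on the CUTLASS config file.
static constexpr std::array<std::pair<Dim, GemmRunPtr>, 3> gemm_config = {{
    { { 1, 1, 4096, 4096 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_1> },
    { { 1, 8, 14336, 4096 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_5> },   // hypothetical config name
    { { 1, 1, 128256, 4096 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_7> },  // hypothetical config name
}};

The dispatch in gemm_kernel is unchanged: match on the (L, M, N, K) problem shape, otherwise fall back to the single default configuration.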