
Commit 923e9cc

Add xetla splitk gemm
1 parent b6cdccd commit 923e9cc

5 files changed: +170 -7 lines changed


benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 21 additions & 7 deletions
@@ -3,6 +3,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -131,9 +132,9 @@ def forward(ctx, a, b, c, acc_dtype=None):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -148,23 +149,36 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
-                                                              quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              kernel_name='_kernel')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='_kernel')
+    elif provider == 'xetla':
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
+        name = f'gemm_splitk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='split_k_gemm_run')
     else:
        raise NotImplementedError(f'Unsupported provider {provider}')
 
     tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3)
     gbps = lambda mean: 2 * (M * K + K * N) + 4.0 * (M * N) * (1e-9) / (mean * 1e-3)
 
-    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
 
 
 if __name__ == '__main__':
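For context, the new 'xetla' provider pre-allocates the output, accumulator, and counter buffers, resolves the shape-specific binding by name, and times it under kernel_name='split_k_gemm_run'. A minimal standalone sketch of that call path follows; it assumes the xetla_kernel extension built by this commit is importable and uses one of the registered shapes. Note that the binding returns the acc scratch tensor, while the GEMM result itself is written into c per the arguments set up in split_k_gemm.h.

import torch
import xetla_kernel

M, K, N = 3072, 4096, 3072  # one of the shapes registered in python_main.cpp

a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
c = torch.empty((M, N), device='xpu', dtype=torch.float32)    # output buffer
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)  # split-k accumulator (scratch)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)    # split-k counter buffer (scratch)

func = getattr(xetla_kernel, f'gemm_splitk_shape_{M}_{K}_{N}')
func(a, b, c, acc, cnt)  # returns acc; the GEMM output lands in c
reference = torch.matmul(a, b).to(torch.float32)
max_err = (c - reference).abs().max()
print(f'max abs error vs torch.matmul: {max_err.item():.4e}')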

benchmarks/xetla_kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ endif()
 add_subdirectory(softmax)
 add_subdirectory(gemm)
 add_subdirectory(stream_k_gemm)
+add_subdirectory(split_k_gemm)
 add_subdirectory(flash_attention)
 
 install(TARGETS xetla_kernel LIBRARY DESTINATION .)

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 34 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include "flash_attention/fmha_forward_v5.h"
 #include "gemm/gemm.h"
 #include "softmax/softmax.h"
+#include "split_k_gemm/split_k_gemm.h"
 #include "stream_k_gemm/stream_k_gemm.h"
 #include <CL/sycl.hpp>
 #include <c10/core/ScalarType.h>
@@ -95,6 +96,29 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
   return acc;
 }
 
+template <int m, int k, int n,
+          kslicing_impl_t kslicing_type = kslicing_impl_t::none>
+at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
+                             const at::Tensor &c, const at::Tensor &acc,
+                             const at::Tensor &cnt) {
+  CHECK_INPUT(a);
+  CHECK_INPUT(b);
+  CHECK_INPUT(c);
+  CHECK_INPUT(acc);
+#ifdef USE_IPEX
+  RECORD_FUNCTION("xetla split_k_gemm", {});
+#endif
+
+  auto queue = get_current_sycl_queue();
+  auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
+      a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(),
+      queue);
+#ifdef USE_IPEX
+  xpu::profiler_record("xetla kernel", evt);
+#endif
+  return acc;
+}
+
 #define CALL_IMPL_ATTENTION_FWD_FUNC(P)                                       \
   fmha::fmha_forward_impl<P, T, use_mask, IsCausal, use_dropout>(             \
       queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(),       \
@@ -283,6 +307,16 @@ PYBIND11_MODULE(xetla_kernel, m) {
   // gemm stream k
   m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
         "bf16_gemm_streamk (XeTLA)");
+  // gemm split k
+  m.def("gemm_splitk_shape_512_32768_8192",
+        &bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::none>,
+        "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_1024_28672_8192",
+        &bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::none>,
+        "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_3072_4096_3072",
+        &bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::none>,
+        "bf16_gemm_splitk (XeTLA)");
   // flash_attn
   m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
         "flash attn fwd (XeTLA)");
benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+target_include_directories(xetla_kernel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
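Because each gemm_splitk_shape_* binding registered in python_main.cpp above instantiates bf16_split_k_gemm with M/K/N fixed at compile time, and split_k_gemm_run only receives raw data pointers, the runtime tensors must match the registered shape exactly. A small illustrative guard one could apply on the Python side before dispatch (check_splitk_shape is a hypothetical helper, not part of this commit):

import torch
import xetla_kernel


def check_splitk_shape(a: torch.Tensor, b: torch.Tensor, M: int, K: int, N: int):
    # Hypothetical guard: the binding bakes M/K/N into the kernel template and
    # never inspects a.shape or b.shape itself, so validate before dispatch.
    name = f'gemm_splitk_shape_{M}_{K}_{N}'
    if not hasattr(xetla_kernel, name):
        raise ValueError(f'{name} is not registered in xetla_kernel')
    if tuple(a.shape) != (M, K) or tuple(b.shape) != (K, N):
        raise ValueError(f'expected a of shape ({M}, {K}) and b of shape ({K}, {N}), '
                         f'got {tuple(a.shape)} and {tuple(b.shape)}')
    return getattr(xetla_kernel, name)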
benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+/*******************************************************************************
+ * Copyright (c) 2023-2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+#ifndef TRITONBENCHMARK_SPLIT_K_GEMM_H
+#define TRITONBENCHMARK_SPLIT_K_GEMM_H
+
+#include "xetla.hpp"
+#include <sycl.hpp>
+
+enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 };
+
+template <int m, int k, int n,
+          kslicing_impl_t kslicing_type = kslicing_impl_t::none>
+sycl::event split_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
+                             void *_Cnt, sycl::queue &queue) {
+
+  // GEMM_UNIVERSAL input size
+  size_t matrix_m = m;
+  size_t matrix_n = n;
+  size_t matrix_k = k;
+
+  size_t size_a = matrix_m * matrix_k;
+  size_t size_b = matrix_k * matrix_n;
+  size_t size_c = matrix_m * matrix_n;
+
+  using data_type_a = sycl::ext::oneapi::bfloat16;
+  using data_type_b = sycl::ext::oneapi::bfloat16;
+  using data_type_c = float;
+  using data_type_acc = float;
+
+  data_type_a *A = static_cast<data_type_a *>(_A);
+  data_type_b *B = static_cast<data_type_b *>(_B);
+  data_type_c *C = static_cast<data_type_c *>(_C);
+
+  // Define the shape of workgroup
+  // It's tunable parameters based on different input shape and hardware for
+  // better performance
+  constexpr uint32_t wg_tile_m =
+      (kslicing_type != kslicing_impl_t::local) ? 256 : 64;
+  constexpr uint32_t wg_tile_n =
+      (kslicing_type != kslicing_impl_t::local) ? 256 : 128;
+
+  // specify the range k_w/k_s by setting the corresponding ratio
+  // splitk using global memory
+  constexpr uint32_t num_global_splitk =
+      (kslicing_type == kslicing_impl_t::global) ? 2 : 1;
+  // splitk using local memory
+  constexpr uint32_t num_local_splitk =
+      (kslicing_type == kslicing_impl_t::local) ? 2 : 1;
+
+  // Mirco-kernel configuration
+  using tune_option =
+      dict_t<elem_v_t<tune_key::param_optimizer_type,
+                      tune_key_value::param_optimizer_decision_tree>,
+             elem_t_t<tune_key::data_type_acc, data_type_acc>,
+             elem_v_t<tune_key::dispatch_policy,
+                      tune_key_value::dispatch_policy_kslicing>,
+             elem_v_t<tune_key::global_kslicing_ratio, num_global_splitk>,
+             elem_v_t<tune_key::local_kslicing_ratio, num_local_splitk>,
+             elem_t_t<tune_key::wg_tile_shape, shape<wg_tile_n, wg_tile_m>>>;
+  using gemm_op_t = gpu::xetla::kernel::default_gemm_t<
+      data_type_a,            // input datatype for A
+      mem_layout::row_major,  // memory layout for A
+      8,                      // leading dimension alignment for A, in unit of element
+      data_type_b,            // input datatype for B
+      mem_layout::row_major,  // memory layout for B
+      8,                      // leading dimension alignment for B, in unit of element
+      data_type_c,            // output datatype for C
+      mem_layout::row_major,  // memory layout for C
+      8,                      // leading dimension alignment for C, in unit of element
+      data_type_acc,          // accumulator data type for intermediate resutls
+      gpu_arch::Xe,           // GPU arch
+      tune_option>;
+
+  // allocate temp buffers for global split
+  size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n);
+  size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n);
+
+  data_type_acc *Acc = static_cast<data_type_acc *>(_Acc);
+  uint32_t *Cnt = static_cast<uint32_t *>(_Cnt);
+
+  // set up gemm_universal arguments
+  typename gemm_op_t::arguments_t gemm_arg(matrix_m, matrix_k, matrix_n, A,
+                                           matrix_k, B, matrix_n, C, matrix_n,
+                                           Acc, Cnt);
+
+  cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg);
+
+  auto gpu_event = queue.submit([&](sycl::handler &cgh) {
+    // GPU kernel
+    cgh.parallel_for(nd_range, [=](sycl::nd_item<3> item) KERNEL_MAIN {
+      // allocate slm and nbarrier resource
+      slm_barrier_init<gemm_op_t>();
+      gemm_op_t gemm_op;
+      gemm_op(item, gemm_arg);
+    });
+  });
+  return gpu_event;
+}
+
+#endif // TRITONBENCHMARK_SPLIT_K_GEMM_H
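For reference, the kslicing_type template parameter above selects the workgroup tile shape and the global/local k-slicing ratios; the bindings added in python_main.cpp instantiate only the kslicing_impl_t::none variant. A small Python sketch mirroring that constexpr selection logic (values read directly from split_k_gemm_run; the helper itself is illustrative, not part of this commit):

def splitk_config(kslicing: str) -> dict:
    # Mirrors the constexpr selection in split_k_gemm_run: workgroup tile shape
    # plus the global/local k-slicing ratios for each kslicing_impl_t value.
    assert kslicing in ('none', 'global', 'local')
    return {
        'wg_tile_m': 64 if kslicing == 'local' else 256,
        'wg_tile_n': 128 if kslicing == 'local' else 256,
        'num_global_splitk': 2 if kslicing == 'global' else 1,
        'num_local_splitk': 2 if kslicing == 'local' else 1,
    }


# The shapes bound in this commit all use kslicing_impl_t::none:
# splitk_config('none') -> {'wg_tile_m': 256, 'wg_tile_n': 256,
#                           'num_global_splitk': 1, 'num_local_splitk': 1}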
