28 changes: 21 additions & 7 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -3,6 +3,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -131,9 +132,9 @@ def forward(ctx, a, b, c, acc_dtype=None):
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`
-line_vals=['triton'],
+line_vals=['triton', 'xetla'],
# label name for the lines
-line_names=['Triton'],
+line_names=['Triton', 'XeTLA'],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
@@ -148,23 +149,36 @@ def benchmark(M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
-_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
-                                                       quantiles=quantiles)
+_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
+                                                          quantiles=quantiles)
elif provider == 'triton':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                       kernel_name='_kernel')
+_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                          quantiles=quantiles, kernel_name='_kernel')
elif provider == 'xetla':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_splitk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='split_k_gemm_run')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3)
gbps = lambda mean: 2 * (M * K + K * N) + 4.0 * (M * N) * (1e-9) / (mean * 1e-3)

-return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == '__main__':
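For reference, the derived metrics returned by `benchmark` can be summarized in a small helper. This is a minimal sketch, not part of the PR: the helper name is hypothetical, and it assumes bf16 A/B (2 bytes per element) and an fp32 C (4 bytes per element), with the 1e-9 scaling applied to the whole byte count.

# Hypothetical helper, for illustration only: derived metrics for one GEMM,
# assuming bf16 A/B (2 bytes each) and fp32 C (4 bytes); mean_ms is the mean
# kernel time in milliseconds.
def splitk_gemm_metrics(M: int, N: int, K: int, mean_ms: float):
    seconds = mean_ms * 1e-3
    tflops = 2 * M * N * K * 1e-12 / seconds            # 2*M*N*K flops per GEMM
    gbytes = (2 * (M * K + K * N) + 4 * M * N) * 1e-9   # bytes moved, in GB
    return gbytes / seconds, tflops                      # GB/s, TFLOPS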
1 change: 1 addition & 0 deletions benchmarks/xetla_kernel/CMakeLists.txt
@@ -45,6 +45,7 @@ endif()
add_subdirectory(softmax)
add_subdirectory(gemm)
add_subdirectory(stream_k_gemm)
add_subdirectory(split_k_gemm)
add_subdirectory(flash_attention)

install(TARGETS xetla_kernel LIBRARY DESTINATION .)
34 changes: 34 additions & 0 deletions benchmarks/xetla_kernel/python_main.cpp
@@ -2,6 +2,7 @@
#include "flash_attention/fmha_forward_v5.h"
#include "gemm/gemm.h"
#include "softmax/softmax.h"
#include "split_k_gemm/split_k_gemm.h"
#include "stream_k_gemm/stream_k_gemm.h"
#include <CL/sycl.hpp>
#include <c10/core/ScalarType.h>
@@ -95,6 +96,29 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
return acc;
}

template <int m, int k, int n,
kslicing_impl_t kslicing_type = kslicing_impl_t::none>
at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
const at::Tensor &c, const at::Tensor &acc,
const at::Tensor &cnt) {
CHECK_INPUT(a);
CHECK_INPUT(b);
CHECK_INPUT(c);
CHECK_INPUT(acc);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla split_k_gemm", {});
#endif

auto queue = get_current_sycl_queue();
auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(),
queue);
#ifdef USE_IPEX
xpu::profiler_record("xetla kernel", evt);
#endif
return acc;
}

#define CALL_IMPL_ATTENTION_FWD_FUNC(P) \
fmha::fmha_forward_impl<P, T, use_mask, IsCausal, use_dropout>( \
queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), \
@@ -283,6 +307,16 @@ PYBIND11_MODULE(xetla_kernel, m) {
// gemm stream k
m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm,
"bf16_gemm_streamk (XeTLA)");
// gemm split k
m.def("gemm_splitk_shape_512_32768_8192",
&bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::global>,
"bf16_gemm_splitk (XeTLA)");
m.def("gemm_splitk_shape_1024_28672_8192",
&bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::global>,
"bf16_gemm_splitk (XeTLA)");
m.def("gemm_splitk_shape_3072_4096_3072",
&bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
"bf16_gemm_splitk (XeTLA)");
// flash_attn
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
"flash attn fwd (XeTLA)");
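A hedged usage sketch of the new split-k bindings from Python, mirroring the buffer setup in the benchmark diff above. The bf16 creation of `a` and `b` is an assumption (that code sits outside this diff); the shape-specific entry point is resolved by name exactly as the benchmark does, and, per the binding, the call returns `acc`.

import torch
import xetla_kernel

# One of the three registered shapes: m=512, k=32768, n=8192.
M, K, N = 512, 32768, 8192
a = torch.randn((M, K), device='xpu', dtype=torch.bfloat16)   # assumed setup
b = torch.randn((K, N), device='xpu', dtype=torch.bfloat16)   # assumed setup
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)

func = getattr(xetla_kernel, f'gemm_splitk_shape_{M}_{K}_{N}')
out = func(a, b, c, acc, cnt)   # returns acc, as in bf16_split_k_gemm above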
1 change: 1 addition & 0 deletions benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt
@@ -0,0 +1 @@
target_include_directories(xetla_kernel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
113 changes: 113 additions & 0 deletions benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h
@@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (c) 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef TRITONBENCHMARK_SPLIT_K_GEMM_H
#define TRITONBENCHMARK_SPLIT_K_GEMM_H

#include "xetla.hpp"
#include <sycl.hpp>

enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 };

template <int m, int k, int n,
kslicing_impl_t kslicing_type = kslicing_impl_t::none>
sycl::event split_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc,
void *_Cnt, sycl::queue &queue) {

// GEMM_UNIVERSAL input size
size_t matrix_m = m;
size_t matrix_n = n;
size_t matrix_k = k;

size_t size_a = matrix_m * matrix_k;
size_t size_b = matrix_k * matrix_n;
size_t size_c = matrix_m * matrix_n;

using data_type_a = sycl::ext::oneapi::bfloat16;
using data_type_b = sycl::ext::oneapi::bfloat16;
using data_type_c = float;
using data_type_acc = float;

data_type_a *A = static_cast<data_type_a *>(_A);
data_type_b *B = static_cast<data_type_b *>(_B);
data_type_c *C = static_cast<data_type_c *>(_C);

// Define the workgroup tile shape.
// These are tunable parameters; the best values depend on the input shape
// and the target hardware.
constexpr uint32_t wg_tile_m =
(kslicing_type != kslicing_impl_t::local) ? 256 : 64;
constexpr uint32_t wg_tile_n =
(kslicing_type != kslicing_impl_t::local) ? 256 : 128;

// specify the range k_w/k_s by setting the corresponding ratio
// splitk using global memory
constexpr uint32_t num_global_splitk =
(kslicing_type == kslicing_impl_t::global) ? 2 : 1;
// splitk using local memory
constexpr uint32_t num_local_splitk =
(kslicing_type == kslicing_impl_t::local) ? 2 : 1;

// Micro-kernel configuration
using tune_option =
dict_t<elem_v_t<tune_key::param_optimizer_type,
tune_key_value::param_optimizer_decision_tree>,
elem_t_t<tune_key::data_type_acc, data_type_acc>,
elem_v_t<tune_key::dispatch_policy,
tune_key_value::dispatch_policy_kslicing>,
elem_v_t<tune_key::global_kslicing_ratio, num_global_splitk>,
elem_v_t<tune_key::local_kslicing_ratio, num_local_splitk>,
elem_t_t<tune_key::wg_tile_shape, shape<wg_tile_n, wg_tile_m>>>;
using gemm_op_t = gpu::xetla::kernel::default_gemm_t<
data_type_a, // input datatype for A
mem_layout::row_major, // memory layout for A
8, // leading dimension alignment for A, in unit of element
data_type_b, // input datatype for B
mem_layout::row_major, // memory layout for B
8, // leading dimension alignment for B, in unit of element
data_type_c, // output datatype for C
mem_layout::row_major, // memory layout for C
8, // leading dimension alignment for C, in unit of element
data_type_acc, // accumulator data type for intermediate results
gpu_arch::Xe, // GPU arch
tune_option>;

// allocate temp buffers for global split
size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n);
size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n);

data_type_acc *Acc = static_cast<data_type_acc *>(_Acc);
uint32_t *Cnt = static_cast<uint32_t *>(_Cnt);

// set up gemm_universal arguments
typename gemm_op_t::arguments_t gemm_arg(matrix_m, matrix_k, matrix_n, A,
matrix_k, B, matrix_n, C, matrix_n,
Acc, Cnt);

cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg);

auto gpu_event = queue.submit([&](sycl::handler &cgh) {
// GPU kernel
cgh.parallel_for(nd_range, [=](sycl::nd_item<3> item) KERNEL_MAIN {
// allocate slm and nbarrier resource
slm_barrier_init<gemm_op_t>();
gemm_op_t gemm_op;
gemm_op(item, gemm_arg);
});
});
return gpu_event;
}

#endif // TRITONBENCHMARK_SPLIT_K_GEMM_H
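For intuition, a minimal reference sketch of the global k-slicing configured above (`num_global_splitk = 2`): the K dimension is split into chunks whose partial products are accumulated into an fp32 buffer, which is the role of the `Acc` argument; the `Cnt` buffer the kernel uses for workgroup coordination is not modeled here.

import torch

# Reference-only sketch of global split-k: each K chunk contributes a partial
# product that is accumulated into a shared fp32 buffer.
def splitk_reference(a: torch.Tensor, b: torch.Tensor, num_global_splitk: int = 2):
    K = a.shape[1]
    chunk = K // num_global_splitk
    acc = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)
    for s in range(num_global_splitk):
        hi = K if s == num_global_splitk - 1 else (s + 1) * chunk
        acc += a[:, s * chunk:hi].float() @ b[s * chunk:hi, :].float()
    return acc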
2 changes: 1 addition & 1 deletion benchmarks/xetla_kernel/xetla-library.conf
@@ -1 +1 @@
-ae46a690bac364a93437e248418636c2a8423134
+b9e489ca6a776694a898044a3f2ae023a98db03d