From 47018e7ac609db70c8c640de733fb45495b2a32c Mon Sep 17 00:00:00 2001 From: ESI-SYD Date: Tue, 22 Oct 2024 07:38:24 +0000 Subject: [PATCH] Add xetla splitk gemm --- .../gemm_splitk_benchmark.py | 28 +++-- benchmarks/xetla_kernel/CMakeLists.txt | 1 + benchmarks/xetla_kernel/python_main.cpp | 34 ++++++ .../xetla_kernel/split_k_gemm/CMakeLists.txt | 1 + .../xetla_kernel/split_k_gemm/split_k_gemm.h | 113 ++++++++++++++++++ benchmarks/xetla_kernel/xetla-library.conf | 2 +- 6 files changed, 171 insertions(+), 8 deletions(-) create mode 100644 benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt create mode 100644 benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 4aa1910591..b6443bf947 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -3,6 +3,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit +import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 @@ -131,9 +132,9 @@ def forward(ctx, a, b, c, acc_dtype=None): line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` - line_vals=['triton'], + line_vals=['triton', 'xetla'], # label name for the lines - line_names=['Triton'], + line_names=['Triton', 'XeTLA'], # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel=['GB/s', 'TFlops'], # label name for the y-axis @@ -148,23 +149,36 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10, - quantiles=quantiles) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), 
n_warmup=10, n_repeat=10, + quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device='xpu', dtype=torch.float32) triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='_kernel') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name='_kernel') + elif provider == 'xetla': + c = torch.empty((M, N), device='xpu', dtype=torch.float32) + acc = torch.empty((M, N), device='xpu', dtype=torch.float32) + cnt = torch.empty((M, N), device='xpu', dtype=torch.int32) + + name = f'gemm_splitk_shape_{M}_{K}_{N}' + func = getattr(xetla_kernel, name) + xetla_fn = lambda: func(a, b, c, acc, cnt) + torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + + # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, + quantiles=quantiles, kernel_name='split_k_gemm_run') else: raise NotImplementedError(f'Unsupported provider {provider}') tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3) gbps = lambda mean: 2 * (M * K + K * N) + 4.0 * (M * N) * (1e-9) / (mean * 1e-3) - return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv if __name__ == '__main__': diff --git a/benchmarks/xetla_kernel/CMakeLists.txt b/benchmarks/xetla_kernel/CMakeLists.txt index 439849f5c8..73ab97e5f8 100644 --- a/benchmarks/xetla_kernel/CMakeLists.txt +++ b/benchmarks/xetla_kernel/CMakeLists.txt @@ -45,6 +45,7 @@ 
endif() add_subdirectory(softmax) add_subdirectory(gemm) add_subdirectory(stream_k_gemm) +add_subdirectory(split_k_gemm) add_subdirectory(flash_attention) install(TARGETS xetla_kernel LIBRARY DESTINATION .) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 4a366b3826..80dc03ef51 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -2,6 +2,7 @@ #include "flash_attention/fmha_forward_v5.h" #include "gemm/gemm.h" #include "softmax/softmax.h" +#include "split_k_gemm/split_k_gemm.h" #include "stream_k_gemm/stream_k_gemm.h" #include #include @@ -95,6 +96,29 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b, return acc; } +template +at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b, + const at::Tensor &c, const at::Tensor &acc, + const at::Tensor &cnt) { + CHECK_INPUT(a); + CHECK_INPUT(b); + CHECK_INPUT(c); + CHECK_INPUT(acc); +#ifdef USE_IPEX + RECORD_FUNCTION("xetla split_k_gemm", {}); +#endif + + auto queue = get_current_sycl_queue(); + auto evt = split_k_gemm_run( + a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(), + queue); +#ifdef USE_IPEX + xpu::profiler_record("xetla kernel", evt); +#endif + return acc; +} + #define CALL_IMPL_ATTENTION_FWD_FUNC(P) \ fmha::fmha_forward_impl( \ queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), \ @@ -283,6 +307,16 @@ PYBIND11_MODULE(xetla_kernel, m) { // gemm stream k m.def("gemm_streamk_shape_3072_4096_3072", &bf16_stream_k_gemm, "bf16_gemm_streamk (XeTLA)"); + // gemm split k + m.def("gemm_splitk_shape_512_32768_8192", + &bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::global>, + "bf16_gemm_splitk (XeTLA)"); + m.def("gemm_splitk_shape_1024_28672_8192", + &bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::global>, + "bf16_gemm_splitk (XeTLA)"); + m.def("gemm_splitk_shape_3072_4096_3072", + &bf16_split_k_gemm<3072, 4096, 3072, 
kslicing_impl_t::global>, + "bf16_gemm_splitk (XeTLA)"); // flash_attn m.def("flash_attn_causal_false", &flash_attn, "flash attn fwd (XeTLA)"); diff --git a/benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt b/benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt new file mode 100644 index 0000000000..9916e9d828 --- /dev/null +++ b/benchmarks/xetla_kernel/split_k_gemm/CMakeLists.txt @@ -0,0 +1 @@ +target_include_directories(xetla_kernel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h b/benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h new file mode 100644 index 0000000000..a9260118cb --- /dev/null +++ b/benchmarks/xetla_kernel/split_k_gemm/split_k_gemm.h @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2023-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ +#ifndef TRITONBENCHMARK_SPLIT_K_GEMM_H +#define TRITONBENCHMARK_SPLIT_K_GEMM_H + +#include "xetla.hpp" +#include + +enum class kslicing_impl_t : uint8_t { none = 0, global = 1, local = 2 }; + +template +sycl::event split_k_gemm_run(void *_A, void *_B, void *_C, void *_Acc, + void *_Cnt, sycl::queue &queue) { + + // GEMM_UNIVERSAL input size + size_t matrix_m = m; + size_t matrix_n = n; + size_t matrix_k = k; + + size_t size_a = matrix_m * matrix_k; + size_t size_b = matrix_k * matrix_n; + size_t size_c = matrix_m * matrix_n; + + using data_type_a = sycl::ext::oneapi::bfloat16; + using data_type_b = sycl::ext::oneapi::bfloat16; + using data_type_c = float; + using data_type_acc = float; + + data_type_a *A = static_cast(_A); + data_type_b *B = static_cast(_B); + data_type_c *C = static_cast(_C); + + // Define the shape of workgroup + // It's tunable parameters based on different input shape and hardware for + // better performance + constexpr uint32_t wg_tile_m = + (kslicing_type != kslicing_impl_t::local) ? 256 : 64; + constexpr uint32_t wg_tile_n = + (kslicing_type != kslicing_impl_t::local) ? 256 : 128; + + // specify the range k_w/k_s by setting the corresponding ratio + // splitk using global memory + constexpr uint32_t num_global_splitk = + (kslicing_type == kslicing_impl_t::global) ? 2 : 1; + // splitk using local memory + constexpr uint32_t num_local_splitk = + (kslicing_type == kslicing_impl_t::local) ? 
2 : 1; + + // Micro-kernel configuration + using tune_option = + dict_t, + elem_t_t, + elem_v_t, + elem_v_t, + elem_v_t, + elem_t_t>>; + using gemm_op_t = gpu::xetla::kernel::default_gemm_t< + data_type_a, // input datatype for A + mem_layout::row_major, // memory layout for A + 8, // leading dimension alignment for A, in unit of element + data_type_b, // input datatype for B + mem_layout::row_major, // memory layout for B + 8, // leading dimension alignment for B, in unit of element + data_type_c, // output datatype for C + mem_layout::row_major, // memory layout for C + 8, // leading dimension alignment for C, in unit of element + data_type_acc, // accumulator data type for intermediate results + gpu_arch::Xe, // GPU arch + tune_option>; + + // allocate temp buffers for global split + size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); + size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); + + data_type_acc *Acc = static_cast(_Acc); + uint32_t *Cnt = static_cast(_Cnt); + + // set up gemm_universal arguments + typename gemm_op_t::arguments_t gemm_arg(matrix_m, matrix_k, matrix_n, A, + matrix_k, B, matrix_n, C, matrix_n, + Acc, Cnt); + + cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); + + auto gpu_event = queue.submit([&](sycl::handler &cgh) { + // GPU kernel + cgh.parallel_for(nd_range, [=](sycl::nd_item<3> item) KERNEL_MAIN { + // allocate slm and nbarrier resource + slm_barrier_init(); + gemm_op_t gemm_op; + gemm_op(item, gemm_arg); + }); + }); + return gpu_event; +} + +#endif // TRITONBENCHMARK_SPLIT_K_GEMM_H diff --git a/benchmarks/xetla_kernel/xetla-library.conf b/benchmarks/xetla_kernel/xetla-library.conf index 81489c2c7c..2cc1e9f5b3 100644 --- a/benchmarks/xetla_kernel/xetla-library.conf +++ b/benchmarks/xetla_kernel/xetla-library.conf @@ -1 +1 @@ -ae46a690bac364a93437e248418636c2a8423134 +b9e489ca6a776694a898044a3f2ae023a98db03d