Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions benchmarks/xetla_kernel/python_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifdef USE_IPEX
#include <ipex.h>
#else
#include <ATen/record_function.h>
#include <c10/xpu/XPUStream.h>
#endif

Expand Down Expand Up @@ -43,9 +44,7 @@ at::Tensor softmax(const at::Tensor &input, const at::Tensor &output,
const int64_t dim) {
CHECK_INPUT(input);
CHECK_INPUT(output);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla softmax", {});
#endif

auto queue = get_current_sycl_queue();
auto evt = softmax_forward<T>(input.data_ptr(), output.data_ptr(), queue);
Expand All @@ -63,9 +62,7 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b,
CHECK_INPUT(b);
CHECK_INPUT(c);
CHECK_INPUT(acc);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla gemm", {});
#endif

auto queue = get_current_sycl_queue();
auto evt = gemm_run<T>(a.data_ptr(), b.data_ptr(), c.data_ptr(),
Expand All @@ -83,9 +80,7 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
CHECK_INPUT(b);
CHECK_INPUT(c);
CHECK_INPUT(acc);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla stream_k_gemm", {});
#endif

auto queue = get_current_sycl_queue();
auto evt = stream_k_gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(),
Expand All @@ -105,9 +100,7 @@ at::Tensor bf16_split_k_gemm(const at::Tensor &a, const at::Tensor &b,
CHECK_INPUT(b);
CHECK_INPUT(c);
CHECK_INPUT(acc);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla split_k_gemm", {});
#endif

auto queue = get_current_sycl_queue();
auto evt = split_k_gemm_run<m, k, n, kslicing_type>(
Expand Down Expand Up @@ -143,9 +136,7 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
CHECK_INPUT(bias);
CHECK_INPUT(m);
CHECK_INPUT(l);
#ifdef USE_IPEX
RECORD_FUNCTION("xetla fa", {});
#endif

auto queue = get_current_sycl_queue();

Expand Down Expand Up @@ -212,9 +203,7 @@ void flash_attn_bwd(const at::Tensor &grad_out, const at::Tensor &q,
CHECK_INPUT(grad_value);
CHECK_INPUT(grad_bias);

#ifdef USE_IPEX
RECORD_FUNCTION("xetla fa", {});
#endif

auto queue = get_current_sycl_queue();

Expand Down