
Commit 77d819c

Align usage of RECORD_FUNCTION for xetla (#2371)
Closes #2373

Motivation: per the note in `record_function.h` (`NOTE: passing the inputs incurs an additional overhead`), for a fairer comparison we should not pass the inputs to `RECORD_FUNCTION`: https://github.com/Stonepia/pytorch/blob/4de58719fbb5c681305622bd0e22997c9ece52b0/aten/src/ATen/record_function.h#L110

```c++
/**
 * RecordFunctionCallback represents a pair of callbacks to be used with
 * RecordFunction, members:
 *   start, end - the callbacks to run when entering and exiting the scope;
 *     optionally, the start callback may return an ObserverContext which will
 *     be passed to the end callback, use appropriate constructor accordingly.
 *   needs_inputs - whether the callbacks need the inputs passed from the
 *     observed function/range; NOTE: passing the inputs incurs an additional
 *     overhead;
 *   sampling_probability - if not 1.0, then the callback is probabilistically
 *     sampled to run; NOTE: start and end callbacks always run as a pair and
 *     are sampled together;
 *   scopes - types of scopes to execute the callbacks on (see RecordScope);
 *     passing empty set means the callbacks will be executed for all possible
 *     scope types;
 *   should_run - optional function that returns whether this callback should
 *     run; overwrites the effect of setting sampling_probability
 */
```

CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11071744714/job/30765931643

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 427f3c4 commit 77d819c

File tree

1 file changed: +4 −5 lines


benchmarks/xetla_kernel/python_main.cpp

Lines changed: 4 additions & 5 deletions
```diff
@@ -42,7 +42,7 @@ at::Tensor softmax(const at::Tensor &input, const at::Tensor &output,
   CHECK_INPUT(input);
   CHECK_INPUT(output);
 #ifdef USE_IPEX
-  RECORD_FUNCTION("xetla softmax", {input});
+  RECORD_FUNCTION("xetla softmax", {});
 #endif
 
   auto queue = get_current_sycl_queue();
@@ -62,7 +62,7 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b,
   CHECK_INPUT(c);
   CHECK_INPUT(acc);
 #ifdef USE_IPEX
-  RECORD_FUNCTION("xetla gemm", {a, b, c, acc});
+  RECORD_FUNCTION("xetla gemm", {});
 #endif
 
   auto queue = get_current_sycl_queue();
@@ -82,7 +82,7 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
   CHECK_INPUT(c);
   CHECK_INPUT(acc);
 #ifdef USE_IPEX
-  RECORD_FUNCTION("xetla stream_k_gemm", {a, b, c, acc});
+  RECORD_FUNCTION("xetla stream_k_gemm", {});
 #endif
 
   auto queue = get_current_sycl_queue();
@@ -119,8 +119,7 @@ void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
   CHECK_INPUT(m);
   CHECK_INPUT(l);
 #ifdef USE_IPEX
-  RECORD_FUNCTION("xetla fa",
-                  {num_batches, num_heads, head_size, num_queries, num_keys});
+  RECORD_FUNCTION("xetla fa", {});
 #endif
 
   auto queue = get_current_sycl_queue();
```

0 commit comments
