Skip to content

Commit 09646ce

Browse files
committed
Update
1 parent 923e9cc commit 09646ce

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,13 @@ PYBIND11_MODULE(xetla_kernel, m) {
309309
"bf16_gemm_streamk (XeTLA)");
310310
// gemm split k
311311
m.def("gemm_splitk_shape_512_32768_8192",
312-
&bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::none>,
312+
&bf16_split_k_gemm<512, 32768, 8192, kslicing_impl_t::global>,
313313
"bf16_gemm_splitk (XeTLA)");
314314
m.def("gemm_splitk_shape_1024_28672_8192",
315-
&bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::none>,
315+
&bf16_split_k_gemm<1024, 28672, 8192, kslicing_impl_t::global>,
316316
"bf16_gemm_splitk (XeTLA)");
317317
m.def("gemm_splitk_shape_3072_4096_3072",
318-
&bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::none>,
318+
&bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
319319
"bf16_gemm_splitk (XeTLA)");
320320
// flash_attn
321321
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
ae46a690bac364a93437e248418636c2a8423134
1+
b9e489ca6a776694a898044a3f2ae023a98db03d

0 commit comments

Comments
 (0)