Skip to content

Commit 1826d9e

Browse files
committed
Add 4k to XeTLA
1 parent fa6cc70 commit 1826d9e

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ PYBIND11_MODULE(xetla_kernel, m) {
317317
m.def("gemm_splitk_shape_3072_4096_3072",
318318
&bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
319319
"bf16_gemm_splitk (XeTLA)");
320+
m.def("gemm_splitk_shape_4096_4096_4096",
321+
&bf16_split_k_gemm<4096, 4096, 4096, kslicing_impl_t::global>,
322+
"bf16_gemm_splitk (XeTLA)");
320323
// flash_attn
321324
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
322325
"flash attn fwd (XeTLA)");

0 commit comments

Comments
 (0)