
Commit 622abfc

Fix xetla import for wheel
1 parent 1bc283c commit 622abfc
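
The bare import xetla_kernel resolves only when the compiled module sits directly on sys.path, as it does when the benchmarks run from the source checkout; in a built wheel, xetla_kernel is installed as a submodule of the triton_kernels_benchmark package, so every call site switches to the package-qualified form below. A minimal sketch of an import that would tolerate both layouts, shown purely for illustration (the commit itself just uses the qualified import):

# Illustrative fallback only -- not part of this commit.
try:
    # wheel layout: xetla_kernel is a submodule of the installed package
    import triton_kernels_benchmark.xetla_kernel as xetla_kernel
except ImportError:
    # source-tree layout: xetla_kernel sits directly on sys.path
    import xetla_kernel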

File tree

5 files changed: +10 −10 lines changed


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -262,7 +262,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
 
     elif provider == 'xetla':
        module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-        func = getattr(xetla_kernel, module_name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, module_name)
        out = torch.empty_like(q, device='xpu', dtype=dtype)
        size_score = Z * H * N_CTX * N_CTX
        size_attn_mask = Z * N_CTX * N_CTX
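
Because import triton_kernels_benchmark.xetla_kernel binds the top-level package name and exposes the submodule as an attribute of it, the dynamic lookup by kernel name keeps working unchanged; a short sketch of that lookup pattern, with an example name built the same way the benchmark builds it:

# Standard Python behaviour: importing a submodule binds the package name
# and makes the submodule reachable as an attribute of the package.
import triton_kernels_benchmark.xetla_kernel

module_name = 'flash_attn_causal_true'  # example: f'flash_attn_causal_{CAUSAL}'.lower() with CAUSAL=True
func = getattr(triton_kernels_benchmark.xetla_kernel, module_name)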

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from triton.runtime import driver
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 
 @torch.jit.script
@@ -140,7 +140,7 @@ def benchmark(M, N, provider):
 
     elif provider == "xetla":
        name = f"softmax_shape_{M}_{N}"
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
        out = torch.empty_like(x, device="xpu")
        xetla_fn = lambda: func(x, out, 0)
        torch_fn = lambda: torch.softmax(x, axis=-1)

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import triton_kernels_benchmark as benchmark_suit
 from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
 
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -313,7 +313,7 @@ def benchmark(B, M, N, K, provider):
        # better performance.
        if (B, M, N, K) == (1, 3072, 4096, 3072):
            name = 'gemm_streamk_shape_3072_4096_3072'
-            func = getattr(xetla_kernel, name)
+            func = getattr(triton_kernels_benchmark.xetla_kernel, name)
            xetla_fn = lambda: func(a, b, c, acc, cnt)
            torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -165,7 +165,7 @@ def benchmark(M, N, K, provider):
        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
 
        name = f'gemm_splitk_shape_{M}_{K}_{N}'
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
        xetla_fn = lambda: func(a, b, c, acc, cnt)
        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -288,7 +288,7 @@ def benchmark(M, N, K, provider):
        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
 
        name = f'gemm_streamk_shape_{M}_{K}_{N}'
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
        xetla_fn = lambda: func(a, b, c, acc, cnt)
        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
