diff --git a/benchmarks/setup.py b/benchmarks/setup.py
index 1692415507..8924c96ef9 100644
--- a/benchmarks/setup.py
+++ b/benchmarks/setup.py
@@ -6,8 +6,6 @@
 # TODO: update once there is replacement for clean:
 # https://github.com/pypa/setuptools/discussions/2838
 from distutils import log # pylint: disable=[deprecated-module]
-from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
-from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]
 from setuptools import setup, Extension
 from setuptools.command.build_ext import build_ext as _build_ext
 
@@ -24,10 +22,10 @@ def __init__(self, name):
 
 class CMakeBuild():
 
-    def __init__(self, debug=False, dry_run=False):
+    def __init__(self, build_lib, build_temp, debug=False, dry_run=False):
         self.current_dir = os.path.abspath(os.path.dirname(__file__))
-        self.build_temp = self.current_dir + "/build/temp"
-        self.extdir = self.current_dir + "/triton_kernels_benchmark"
+        self.build_temp = build_temp
+        self.extdir = build_lib + "/triton_kernels_benchmark"
         self.build_type = self.get_build_type(debug)
         self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
         self.use_ipex = False
@@ -101,30 +99,20 @@ def build_extension(self):
         self.check_call(["cmake"] + build_args)
         self.check_call(["cmake"] + install_args)
 
-    def clean(self):
-        if os.path.exists(self.build_temp):
-            remove_tree(self.build_temp, dry_run=self.dry_run)
-        else:
-            log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
-                                                                              os.path.dirname(__file__)))
-
 
 class build_ext(_build_ext):
 
     def run(self):
-        cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
+        cmake = CMakeBuild(
+            build_lib=self.build_lib,
+            build_temp=self.build_temp,
+            debug=self.debug,
+            dry_run=self.dry_run,
+        )
         cmake.run()
         super().run()
 
 
-class clean(_clean):
-
-    def run(self):
-        cmake = CMakeBuild(dry_run=self.dry_run)
-        cmake.clean()
-        super().run()
-
-
 def get_git_commit_hash(length=8):
     try:
         cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
@@ -151,11 +139,10 @@ def get_git_commit_hash(length=8):
     package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
     cmdclass={
         "build_ext": build_ext,
-        "clean": clean,
     },
-    ext_modules=[CMakeExtension("triton_kernels_benchmark")],
-    extra_require={
-        "ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
+    ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
+    extras_require={
+        "ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"],
         "pytorch": ["torch>=2.6"],
     },
 )
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 606b67af52..10b13b801f 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -4,7 +4,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -262,7 +262,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
 
     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-        func = getattr(xetla_kernel, module_name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, module_name)
         out = torch.empty_like(q, device='xpu', dtype=dtype)
         size_score = Z * H * N_CTX * N_CTX
         size_attn_mask = Z * N_CTX * N_CTX
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index 3f17ac4a55..15f0d55b46 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -13,7 +13,7 @@
 from triton.runtime import driver
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 
 @torch.jit.script
@@ -140,7 +140,7 @@ def benchmark(M, N, provider):
 
     elif provider == "xetla":
         name = f"softmax_shape_{M}_{N}"
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
         out = torch.empty_like(x, device="xpu")
         xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 7e0b339dc6..ea4df2f80e 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -15,7 +15,7 @@
 
 import triton_kernels_benchmark as benchmark_suit
 from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -313,7 +313,7 @@ def benchmark(B, M, N, K, provider):
         # better performance.
         if (B, M, N, K) == (1, 3072, 4096, 3072):
             name = 'gemm_streamk_shape_3072_4096_3072'
-            func = getattr(xetla_kernel, name)
+            func = getattr(triton_kernels_benchmark.xetla_kernel, name)
             xetla_fn = lambda: func(a, b, c, acc, cnt)
             torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index b6443bf947..580cebfcc3 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -3,7 +3,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -165,7 +165,7 @@ def benchmark(M, N, K, provider):
         cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
 
         name = f'gemm_splitk_shape_{M}_{K}_{N}'
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index a495dca749..a542bbbad9 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -10,7 +10,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
-import xetla_kernel
+import triton_kernels_benchmark.xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -288,7 +288,7 @@ def benchmark(M, N, K, provider):
         cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
 
         name = f'gemm_streamk_shape_{M}_{K}_{N}'
-        func = getattr(xetla_kernel, name)
+        func = getattr(triton_kernels_benchmark.xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)