From 79053e8bc6b21c9dc6f3e63a8cbe806bc8b4fa2f Mon Sep 17 00:00:00 2001 From: Pavel Chekin Date: Tue, 29 Oct 2024 07:44:36 -0700 Subject: [PATCH] Fix xetla import when built as wheel Fixes #2576. --- .pre-commit-config.yaml | 1 + benchmarks/setup.py | 37 ++++++------------- .../flash_attention_fwd_benchmark.py | 2 +- .../triton_kernels_benchmark/fused_softmax.py | 2 +- .../gemm_benchmark.py | 3 +- .../gemm_splitk_benchmark.py | 2 +- .../gemm_streamk_benchmark.py | 2 +- 7 files changed, 18 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccb5b9c795..0d20a03175 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -118,6 +118,7 @@ repos: - --disable=line-too-long # Disable import-error: not everything can be imported when pre-commit runs - --disable=import-error + - --disable=no-name-in-module # Disable unused-import: ruff has a corresponding check and supports "noqa: F401" - --disable=unused-import # Disable invalid_name: benchmarks use a lot of UPPER_SNAKE arguments diff --git a/benchmarks/setup.py b/benchmarks/setup.py index 1692415507..8924c96ef9 100644 --- a/benchmarks/setup.py +++ b/benchmarks/setup.py @@ -6,8 +6,6 @@ # TODO: update once there is replacement for clean: # https://github.com/pypa/setuptools/discussions/2838 from distutils import log # pylint: disable=[deprecated-module] -from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module] -from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module] from setuptools import setup, Extension from setuptools.command.build_ext import build_ext as _build_ext @@ -24,10 +22,10 @@ def __init__(self, name): class CMakeBuild(): - def __init__(self, debug=False, dry_run=False): + def __init__(self, build_lib, build_temp, debug=False, dry_run=False): self.current_dir = os.path.abspath(os.path.dirname(__file__)) - self.build_temp = self.current_dir + "/build/temp" - self.extdir = self.current_dir + "/triton_kernels_benchmark" + self.build_temp = build_temp + self.extdir = build_lib + "/triton_kernels_benchmark" self.build_type = self.get_build_type(debug) self.cmake_prefix_paths = [torch.utils.cmake_prefix_path] self.use_ipex = False @@ -101,30 +99,20 @@ def build_extension(self): self.check_call(["cmake"] + build_args) self.check_call(["cmake"] + install_args) - def clean(self): - if os.path.exists(self.build_temp): - remove_tree(self.build_temp, dry_run=self.dry_run) - else: - log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp, - os.path.dirname(__file__))) - class build_ext(_build_ext): def run(self): - cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run) + cmake = CMakeBuild( + build_lib=self.build_lib, + build_temp=self.build_temp, + debug=self.debug, + dry_run=self.dry_run, + ) cmake.run() super().run() -class clean(_clean): - - def run(self): - cmake = CMakeBuild(dry_run=self.dry_run) - cmake.clean() - super().run() - - def get_git_commit_hash(length=8): try: cmd = ["git", "rev-parse", f"--short={length}", "HEAD"] @@ -151,11 +139,10 @@ def get_git_commit_hash(length=8): package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={ "build_ext": build_ext, - "clean": clean, }, - ext_modules=[CMakeExtension("triton_kernels_benchmark")], - extra_require={ - "ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"], + ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")], + extras_require={ + "ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"], "pytorch": ["torch>=2.6"], }, ) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 606b67af52..83cca419ec 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -4,7 +4,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit -import xetla_kernel +from triton_kernels_benchmark import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 3f17ac4a55..3fa5983d7b 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -13,7 +13,7 @@ from triton.runtime import driver import triton_kernels_benchmark as benchmark_suit -import xetla_kernel +from triton_kernels_benchmark import xetla_kernel @torch.jit.script diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 7e0b339dc6..6aef756dcb 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -14,8 +14,7 @@ import triton_kernels_benchmark as benchmark_suit from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD - -import xetla_kernel +from triton_kernels_benchmark import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index b6443bf947..3433655303 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -3,7 +3,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit -import xetla_kernel +from triton_kernels_benchmark import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401 diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index a495dca749..1389eb9eb1 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -10,7 +10,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit -import xetla_kernel +from triton_kernels_benchmark import xetla_kernel if benchmark_suit.USE_IPEX_OPTION: import intel_extension_for_pytorch # type: ignore # noqa: F401