Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ repos:
- --disable=line-too-long
# Disable import-error: not everything can be imported when pre-commit runs
- --disable=import-error
- --disable=no-name-in-module
# Disable unused-import: ruff has a corresponding check and supports "noqa: F401"
- --disable=unused-import
# Disable invalid_name: benchmarks use a lot of UPPER_SNAKE arguments
Expand Down
37 changes: 12 additions & 25 deletions benchmarks/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
# TODO: update once there is replacement for clean:
# https://github.com/pypa/setuptools/discussions/2838
from distutils import log # pylint: disable=[deprecated-module]
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
Expand All @@ -24,10 +22,10 @@ def __init__(self, name):

class CMakeBuild():

def __init__(self, debug=False, dry_run=False):
def __init__(self, build_lib, build_temp, debug=False, dry_run=False):
self.current_dir = os.path.abspath(os.path.dirname(__file__))
self.build_temp = self.current_dir + "/build/temp"
self.extdir = self.current_dir + "/triton_kernels_benchmark"
self.build_temp = build_temp
self.extdir = build_lib + "/triton_kernels_benchmark"
self.build_type = self.get_build_type(debug)
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
self.use_ipex = False
Expand Down Expand Up @@ -101,30 +99,20 @@ def build_extension(self):
self.check_call(["cmake"] + build_args)
self.check_call(["cmake"] + install_args)

def clean(self):
if os.path.exists(self.build_temp):
remove_tree(self.build_temp, dry_run=self.dry_run)
else:
log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
os.path.dirname(__file__)))


class build_ext(_build_ext):

def run(self):
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
cmake = CMakeBuild(
build_lib=self.build_lib,
build_temp=self.build_temp,
debug=self.debug,
dry_run=self.dry_run,
)
cmake.run()
super().run()


class clean(_clean):

def run(self):
cmake = CMakeBuild(dry_run=self.dry_run)
cmake.clean()
super().run()


def get_git_commit_hash(length=8):
try:
cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
Expand All @@ -151,11 +139,10 @@ def get_git_commit_hash(length=8):
package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
cmdclass={
"build_ext": build_ext,
"clean": clean,
},
ext_modules=[CMakeExtension("triton_kernels_benchmark")],
extra_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
extras_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"],
"pytorch": ["torch>=2.6"],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from triton.runtime import driver

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel


@torch.jit.script
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD

import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down