Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions benchmarks/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
# TODO: update once there is replacement for clean:
# https://github.com/pypa/setuptools/discussions/2838
from distutils import log # pylint: disable=[deprecated-module]
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
Expand All @@ -24,10 +22,10 @@ def __init__(self, name):

class CMakeBuild():

def __init__(self, debug=False, dry_run=False):
def __init__(self, build_lib, build_temp, debug=False, dry_run=False):
self.current_dir = os.path.abspath(os.path.dirname(__file__))
self.build_temp = self.current_dir + "/build/temp"
self.extdir = self.current_dir + "/triton_kernels_benchmark"
self.build_temp = build_temp
self.extdir = build_lib + "/triton_kernels_benchmark"
self.build_type = self.get_build_type(debug)
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
self.use_ipex = False
Expand Down Expand Up @@ -101,30 +99,20 @@ def build_extension(self):
self.check_call(["cmake"] + build_args)
self.check_call(["cmake"] + install_args)

def clean(self):
if os.path.exists(self.build_temp):
remove_tree(self.build_temp, dry_run=self.dry_run)
else:
log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
os.path.dirname(__file__)))


class build_ext(_build_ext):

def run(self):
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
cmake = CMakeBuild(
build_lib=self.build_lib,
build_temp=self.build_temp,
debug=self.debug,
dry_run=self.dry_run,
)
cmake.run()
super().run()


class clean(_clean):

def run(self):
cmake = CMakeBuild(dry_run=self.dry_run)
cmake.clean()
super().run()


def get_git_commit_hash(length=8):
try:
cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
Expand All @@ -151,11 +139,10 @@ def get_git_commit_hash(length=8):
package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
cmdclass={
"build_ext": build_ext,
"clean": clean,
},
ext_modules=[CMakeExtension("triton_kernels_benchmark")],
extra_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
extras_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"],
"pytorch": ["torch>=2.6"],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
import triton_kernels_benchmark.xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down Expand Up @@ -262,7 +262,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):

elif provider == 'xetla':
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
func = getattr(xetla_kernel, module_name)
func = getattr(triton_kernels_benchmark.xetla_kernel, module_name)
out = torch.empty_like(q, device='xpu', dtype=dtype)
size_score = Z * H * N_CTX * N_CTX
size_attn_mask = Z * N_CTX * N_CTX
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from triton.runtime import driver

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
import triton_kernels_benchmark.xetla_kernel


@torch.jit.script
Expand Down Expand Up @@ -140,7 +140,7 @@ def benchmark(M, N, provider):

elif provider == "xetla":
name = f"softmax_shape_{M}_{N}"
func = getattr(xetla_kernel, name)
func = getattr(triton_kernels_benchmark.xetla_kernel, name)
out = torch.empty_like(x, device="xpu")
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD

import xetla_kernel
import triton_kernels_benchmark.xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down Expand Up @@ -313,7 +313,7 @@ def benchmark(B, M, N, K, provider):
# better performance.
if (B, M, N, K) == (1, 3072, 4096, 3072):
name = 'gemm_streamk_shape_3072_4096_3072'
func = getattr(xetla_kernel, name)
func = getattr(triton_kernels_benchmark.xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
import triton_kernels_benchmark.xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down Expand Up @@ -165,7 +165,7 @@ def benchmark(M, N, K, provider):
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_splitk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
func = getattr(triton_kernels_benchmark.xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
import triton_kernels_benchmark.xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down Expand Up @@ -288,7 +288,7 @@ def benchmark(M, N, K, provider):
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)

name = f'gemm_streamk_shape_{M}_{K}_{N}'
func = getattr(xetla_kernel, name)
func = getattr(triton_kernels_benchmark.xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

Expand Down
Loading