Skip to content

Commit 6fa6f7e

Browse files
authored
Merge branch 'main' into fast-sub-group-transpose
2 parents adcd22d + e48642c commit 6fa6f7e

File tree

19 files changed

+332
-207
lines changed

19 files changed

+332
-207
lines changed

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0a2685160140656e3e53818611dd2c65c4397be5
1+
8321eec009c8c79145ebccd51fdfc336e5f8b848

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@ build-*/
66
python/build/
77
python/dist/
88
python/triton*.egg-info/
9+
python/*.whl
910

1011
python/triton/_C/*.pyd
1112
python/triton/_C/*.so
1213
python/triton/_C/*.dylib
1314

15+
benchmarks/dist
16+
benchmarks/*.egg-info/
17+
benchmarks/**/*.so
18+
19+
# Logs
20+
inductor_log/
21+
1422
# Backends copied from submodules
1523
python/triton/backends/
1624
!python/triton/backends/__init__.py

.pre-commit-config.yaml

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: ruff
2323
files: '^python/.*'
2424
args: ["--fix", "--line-length", "120"]
25-
stages: [commit, push, manual]
25+
stages: [pre-commit, pre-push, manual]
2626
exclude: |
2727
(?x)(
2828
^python/triton/runtime/.*|
@@ -35,49 +35,49 @@ repos:
3535
hooks:
3636
- id: yapf
3737
args: ["-p", "-i"]
38-
stages: [commit, push, manual]
38+
stages: [pre-commit, pre-push, manual]
3939
exclude: "python/test/unit/language/test_line_info.py"
4040

4141
- repo: https://github.com/pre-commit/mirrors-clang-format
4242
rev: v16.0.6
4343
hooks:
4444
- id: clang-format
45-
stages: [commit, push, manual]
45+
stages: [pre-commit, pre-push, manual]
4646

4747
# Expand YAML anchors in files used by github workflows, because github can't
4848
# do this itself. This lets us use anchors, which avoids code duplication.
49-
# - repo: local
50-
# hooks:
51-
# - id: expand-yaml-anchors
52-
# name: Expand YAML anchors
53-
# language: golang
54-
# additional_dependencies: [github.com/mikefarah/yq/v4@latest]
55-
# entry: >
56-
# bash -c '
57-
# OUT=".github/workflows/integration-tests.yml"
58-
# IN="$OUT.in"
59-
# echo "# AUTOGENERATED by pre-commit, modify the .in file instead." > "$OUT" &&
60-
# echo >> "$OUT"
61-
# yq "explode(.)" "$IN" >> "$OUT"
62-
# '
63-
# files: ^.github/workflows/integration-tests.yml.*
64-
# pass_filenames: false
49+
- repo: local
50+
hooks:
51+
- id: expand-yaml-anchors
52+
name: Expand YAML anchors
53+
language: golang
54+
additional_dependencies: [github.com/mikefarah/yq/v4@latest]
55+
entry: >
56+
bash -c '
57+
OUT=".github/workflows/integration-tests.yml"
58+
IN="$OUT.in"
59+
echo "# AUTOGENERATED by pre-commit, modify the .in file instead." > "$OUT" &&
60+
echo >> "$OUT"
61+
yq "explode(.)" "$IN" >> "$OUT"
62+
'
63+
files: ^.github/workflows/integration-tests.yml.*
64+
pass_filenames: false
6565

6666
- repo: https://github.com/PyCQA/bandit
6767
rev: '1.7.9'
6868
hooks:
6969
- id: bandit
7070
files: '^(benchmarks|scripts|third_party/intel)/.*\.py$'
7171
args: ["-c", "bandit.yaml", "-s", "B404,B603,B607"]
72-
stages: [commit, push, manual]
72+
stages: [pre-commit, pre-push, manual]
7373

7474
- repo: https://github.com/astral-sh/ruff-pre-commit
7575
rev: v0.1.3
7676
hooks:
7777
- id: ruff
7878
files: '^(benchmarks|third_party/intel|scripts)/.*'
7979
args: ["--fix", "--line-length", "120"]
80-
stages: [commit, push, manual]
80+
stages: [pre-commit, pre-push, manual]
8181

8282
- repo: https://github.com/pycqa/pylint
8383
rev: v3.2.6
@@ -105,7 +105,7 @@ repos:
105105
- --disable=too-many-locals
106106
- --disable=too-many-statements
107107
- --disable=too-many-arguments
108-
stages: [commit, push, manual]
108+
stages: [pre-commit, pre-push, manual]
109109

110110
- id: pylint
111111
name: pylint for benchmarks
@@ -136,7 +136,7 @@ repos:
136136
- --disable=too-many-statements
137137
- --disable=too-many-arguments
138138
- --disable=fixme
139-
stages: [commit, push, manual]
139+
stages: [pre-commit, pre-push, manual]
140140

141141

142142
exclude: |

benchmarks/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ if(NOT WIN32)
1010
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
1111
endif()
1212

13-
find_package(Python3 COMPONENTS Interpreter)
13+
find_package(Python3 REQUIRED
14+
COMPONENTS Development.Module)
1415
find_package(Torch REQUIRED)
1516
find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
17+
find_package(XeTLALibrary REQUIRED)
1618

1719
if(USE_IPEX)
1820
string(APPEND CMAKE_CXX_FLAGS " -DUSE_IPEX")

benchmarks/cmake/FindXeTLALibrary.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
include(FetchContent)
44

55
if (NOT XeTLALibrary_FOUND)
6+
# TODO: switch to FetchContent_MakeAvailable once XeTLA supports it
7+
cmake_policy(SET CMP0169 OLD)
68

79
set(XeTLALibrary_SOURCE_DIR
810
"${CMAKE_CURRENT_BINARY_DIR}/XeTLALibrary")
911
message(STATUS "XeTLALibrary is not specified. Will try to download
1012
XeTLA library from https://github.com/intel/xetla into
1113
${XeTLALibrary_SOURCE_DIR}")
12-
file(READ xetla-library.conf XeTLALibrary_TAG)
14+
file(READ xetla_kernel/xetla-library.conf XeTLALibrary_TAG)
1315
# Strip the potential trailing newline from tag
1416
string(STRIP "${XeTLALibrary_TAG}" XeTLALibrary_TAG)
1517
FetchContent_Declare(xetla-library

benchmarks/setup.py

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,135 @@
11
import os
2-
import re
32
import shutil
43
import subprocess
5-
import sysconfig
64
import sys
75

8-
from setuptools import setup
6+
# TODO: update once there is a replacement for clean:
7+
# https://github.com/pypa/setuptools/discussions/2838
8+
from distutils import log # pylint: disable=[deprecated-module]
9+
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
10+
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]
11+
12+
from setuptools import setup, Extension
13+
from setuptools.command.build_ext import build_ext as _build_ext
914

1015
import torch
1116

12-
ipex_cmake_prefix_path = ""
13-
USE_IPEX_OPTION = os.getenv("USE_IPEX", "1")
14-
if USE_IPEX_OPTION == "1":
15-
import intel_extension_for_pytorch
16-
ipex_cmake_prefix_path = f";{intel_extension_for_pytorch.cmake_prefix_path}"
17+
18+
class CMakeExtension(Extension):
19+
20+
def __init__(self, name):
21+
# don't invoke the original build_ext for this special extension
22+
super().__init__(name, sources=[])
1723

1824

1925
class CMakeBuild():
2026

21-
def __init__(self):
27+
def __init__(self, debug=False, dry_run=False):
2228
self.current_dir = os.path.abspath(os.path.dirname(__file__))
2329
self.build_temp = self.current_dir + "/build/temp"
2430
self.extdir = self.current_dir + "/triton_kernels_benchmark"
31+
self.build_type = self.get_build_type(debug)
32+
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
33+
self.use_ipex = False
34+
self.dry_run = dry_run
35+
36+
def get_build_type(self, debug):
37+
DEBUG_OPTION = os.getenv("DEBUG", "0")
38+
return "Debug" if debug or (DEBUG_OPTION == "1") else "Release"
2539

2640
def run(self):
27-
try:
28-
out = subprocess.check_output(["cmake", "--version"])
29-
except OSError as error:
30-
raise RuntimeError("CMake must be installed") from error
41+
self.check_ipex()
42+
self.build_extension()
3143

32-
match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", out.decode())
33-
cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
34-
if (cmake_major, cmake_minor) < (3, 18):
35-
raise RuntimeError("CMake >= 3.18.0 is required")
44+
def check_ipex(self):
45+
self.use_ipex = os.getenv("USE_IPEX", "1") == "1"
46+
if not self.use_ipex:
47+
return
48+
try:
49+
import intel_extension_for_pytorch
50+
except ImportError:
51+
log.warn("ipex is not installed trying to build without ipex")
52+
self.use_ipex = False
53+
return
54+
self.cmake_prefix_paths.append(intel_extension_for_pytorch.cmake_prefix_path)
3655

37-
self.build_extension()
56+
def check_call(self, *popenargs, **kwargs):
57+
log.info(" ".join(popenargs[0]))
58+
if not self.dry_run:
59+
subprocess.check_call(*popenargs, **kwargs)
3860

3961
def build_extension(self):
4062
ninja_dir = shutil.which("ninja")
4163
# create build directories
4264
if not os.path.exists(self.build_temp):
4365
os.makedirs(self.build_temp)
44-
# python directories
45-
python_include_dir = sysconfig.get_path("platinclude")
4666
cmake_args = [
4767
"-G",
4868
"Ninja", # Ninja is much faster than make
4969
"-DCMAKE_MAKE_PROGRAM=" +
5070
ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path
51-
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}{ipex_cmake_prefix_path}",
52-
f"-DUSE_IPEX={USE_IPEX_OPTION}",
53-
"-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
54-
"-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir,
55-
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir,
56-
"-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
57-
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
58-
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
71+
"-DCMAKE_PREFIX_PATH=" + ";".join(self.cmake_prefix_paths),
72+
f"-DUSE_IPEX={int(self.use_ipex)}",
73+
"-DCMAKE_INSTALL_PREFIX=" + self.extdir,
74+
"-DPython3_ROOT_DIR:FILEPATH=" + sys.exec_prefix,
75+
"-DCMAKE_VERBOSE_MAKEFILE=TRUE",
5976
"-DCMAKE_C_COMPILER=icx",
6077
"-DCMAKE_CXX_COMPILER=icpx",
78+
"-DCMAKE_BUILD_TYPE=" + self.build_type,
79+
"-S",
80+
self.current_dir,
81+
"-B",
82+
self.build_temp,
6183
]
6284

63-
# configuration
64-
build_type = "Debug"
65-
build_args = ["--config", build_type]
66-
cmake_args += ["-DCMAKE_BUILD_TYPE=" + build_type]
6785
max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
68-
build_args += ["-j" + max_jobs]
86+
build_args = [
87+
"--build",
88+
self.build_temp,
89+
"-j" + max_jobs,
90+
]
91+
92+
install_args = [
93+
"--build",
94+
self.build_temp,
95+
"--target",
96+
"install",
97+
]
6998

7099
env = os.environ.copy()
71-
cmake_dir = self.build_temp
72-
subprocess.check_call(["cmake", self.current_dir] + cmake_args, cwd=cmake_dir, env=env)
73-
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
100+
self.check_call(["cmake"] + cmake_args, env=env)
101+
self.check_call(["cmake"] + build_args)
102+
self.check_call(["cmake"] + install_args)
103+
104+
def clean(self):
105+
if os.path.exists(self.build_temp):
106+
remove_tree(self.build_temp, dry_run=self.dry_run)
107+
else:
108+
log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
109+
os.path.dirname(__file__)))
110+
74111

112+
class build_ext(_build_ext):
113+
114+
def run(self):
115+
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
116+
cmake.run()
117+
super().run()
118+
119+
120+
class clean(_clean):
121+
122+
def run(self):
123+
cmake = CMakeBuild(dry_run=self.dry_run)
124+
cmake.clean()
125+
super().run()
75126

76-
cmake = CMakeBuild()
77-
cmake.run()
78127

79128
setup(name="triton-kernels-benchmark", packages=[
80129
"triton_kernels_benchmark",
81130
], package_dir={
82131
"triton_kernels_benchmark": "triton_kernels_benchmark",
83-
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.so"]})
132+
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
133+
"build_ext": build_ext,
134+
"clean": clean,
135+
}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
213213

214214
function_events = prof.events()
215215

216-
functions = []
216+
all_functions = []
217217
if isinstance(kernel_name, str):
218218
kernel_name = [kernel_name]
219219
for ker_name in kernel_name:
220-
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
220+
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
221+
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
222+
all_functions.append(functions)
221223
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
222224

223-
assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
224225
# Make the time to the milliseconds.
225-
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
226+
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
227+
dtype=torch.float)
226228
return _summarize_statistics(times, quantiles, return_mode)
227229

228230

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
309309
acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
310310
cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
311311
name = f'gemm_shape_{B}_{M}_{K}_{N}'
312+
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
313+
# better performance.
314+
if (B, M, N, K) == (1, 3072, 4096, 3072):
315+
name = 'gemm_streamk_shape_3072_4096_3072'
312316
func = getattr(xetla_kernel, name)
313317
xetla_fn = lambda: func(a, b, c, acc, cnt)
314318
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
@@ -338,6 +342,7 @@ def benchmark(B, M, N, K, provider):
338342
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
339343
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
340344
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
345+
'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run',
341346
}
342347

343348
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')

0 commit comments

Comments
 (0)