Skip to content

Commit 26d2947

Browse files
authored
Port and run tests in python/test/unit/tools (#2953)
1 parent b80c9d6 commit 26d2947

File tree

9 files changed

+579
-85
lines changed

9 files changed

+579
-85
lines changed

python/test/unit/tools/test_aot.py

Lines changed: 161 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import numpy as np
88

99
import triton
10+
from triton._internal_testing import is_cuda, is_xpu
1011
from triton.backends.compiler import GPUTarget
1112
from triton.backends.nvidia.driver import include_dir, library_dirs
13+
from triton.backends.intel.driver import COMPILATION_HELPER
1214

1315
kernel_utils_src = """
1416
import triton
@@ -97,21 +99,42 @@ def kernel(C, A, B, M, N, K,
9799
}"""
98100

99101

100-
def gen_kernel_library(dir, libname):
101-
c_files = glob.glob(os.path.join(dir, "*.c"))
102+
def gen_kernel_library_xpu(dir, libname):
103+
cpp_files = glob.glob(os.path.join(dir, "*.cpp"))
102104
subprocess.run(
103-
["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"],
105+
["icpx"] + cpp_files + ["-I", COMPILATION_HELPER.include_dir[0], "-c", "-fsycl", "-fPIC"],
104106
check=True,
105107
cwd=dir,
106108
)
107109
o_files = glob.glob(os.path.join(dir, "*.o"))
108110

109-
command = ["gcc", *o_files, "-shared", "-o", libname]
110-
for lib_dir in library_dirs():
111+
command = ["icpx", "-fsycl", "-lze_loader", *o_files, "-shared", "-o", libname]
112+
for lib_dir in COMPILATION_HELPER.library_dir:
111113
command.extend(["-L", lib_dir])
114+
if COMPILATION_HELPER.libsycl_dir:
115+
for lib_dir in COMPILATION_HELPER.libsycl_dir:
116+
command.extend(["-L", lib_dir])
112117
subprocess.run(command, check=True, cwd=dir)
113118

114119

120+
def gen_kernel_library(dir, libname):
121+
if is_xpu():
122+
gen_kernel_library_xpu(dir, libname)
123+
else:
124+
c_files = glob.glob(os.path.join(dir, "*.c"))
125+
subprocess.run(
126+
["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"],
127+
check=True,
128+
cwd=dir,
129+
)
130+
o_files = glob.glob(os.path.join(dir, "*.o"))
131+
132+
command = ["gcc", *o_files, "-shared", "-o", libname]
133+
for lib_dir in library_dirs():
134+
command.extend(["-L", lib_dir])
135+
subprocess.run(command, check=True, cwd=dir)
136+
137+
115138
def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
116139
test_src = f"""
117140
int main(int argc, char **argv) {{
@@ -171,15 +194,118 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
171194
}}
172195
"""
173196
src = test_utils_src + test_src
174-
with open(os.path.join(dir, "test.c"), "w") as file:
197+
if is_xpu():
198+
src = f"""
199+
#include "kernel.h"
200+
#include <assert.h>
201+
#include <cmath>
202+
#include <cstddef>
203+
#include <level_zero/ze_api.h>
204+
#include <stdint.h>
205+
#include <stdio.h>
206+
#include <string.h>
207+
#include <sycl/sycl.hpp>
208+
209+
static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {{
210+
FILE *file = fopen(filename, "w");
211+
if (file == NULL) {{
212+
printf("Could not open file %s\\n", filename);
213+
return;
214+
}}
215+
for (int i = 0; i < size; i++) {{
216+
fprintf(file, "%d", buffer[i]);
217+
if (i < size - 1) {{
218+
fprintf(file, ",");
219+
}}
220+
}}
221+
fclose(file);
222+
}}
223+
224+
static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {{
225+
FILE *file = fopen(filename, "r");
226+
if (file == NULL) {{
227+
printf("Could not open file %s\\n", filename);
228+
return;
229+
}}
230+
int index = 0;
231+
while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) {{
232+
index++;
233+
}}
234+
fclose(file);
235+
}}
236+
int main(int argc, char ** argv) {{
237+
int M = {M}, N = {N}, K = {K};
238+
239+
// initialize sycl handles
240+
sycl::queue q{{sycl::gpu_selector_v}};
241+
sycl::ext::intel::device_ptr<sycl::float16> A =
242+
sycl::malloc_device<sycl::float16>(M * K * 2, q);
243+
sycl::ext::intel::device_ptr<sycl::float16> B =
244+
sycl::malloc_device<sycl::float16>(K * N * 2, q);
245+
sycl::ext::intel::device_ptr<sycl::float16> C =
246+
sycl::malloc_device<sycl::float16>(M * N * 4, q);
247+
248+
// initialize input data
249+
int16_t hA[M * K];
250+
int16_t hB[K * N];
251+
memset(hA, 0, M * K * 2);
252+
memset(hB, 0, K * N * 2);
253+
read_csv_to_buffer(argv[1], hA, M * K);
254+
read_csv_to_buffer(argv[2], hB, K * N);
255+
q.memcpy(A, hA, M * K * 2).wait();
256+
q.memcpy(B, hB, K * N * 2).wait();
257+
258+
// launch kernel
259+
load_matmul_fp16();
260+
int32_t ret;
261+
int algo_id = {algo_id};
262+
if (algo_id == 0) {{
263+
ret = matmul_fp16_default(q, C, A, B, M, N, K, N, 1, K, 1, N, 1);
264+
}} else {{
265+
ret = matmul_fp16(q, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
266+
}}
267+
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
268+
assert(ret == 0);
269+
270+
q.wait();
271+
272+
// read data
273+
int32_t hC[M * N];
274+
memset(hC, 0, M * N * 4);
275+
q.memcpy(hC, C, M * N * 4).wait();
276+
write_buffer_to_csv(argv[3], hC, M * N);
277+
278+
// free sycl resources
279+
unload_matmul_fp16();
280+
sycl::free(A, q);
281+
sycl::free(B, q);
282+
sycl::free(C, q);
283+
}}
284+
"""
285+
src_name = "test.c"
286+
if is_xpu():
287+
src_name = "test.cpp"
288+
with open(os.path.join(dir, src_name), "w") as file:
175289
file.write(src)
176290

177-
command = ["gcc", "test.c"]
178-
for inc_dir in include_dir:
179-
command.extend(["-I", inc_dir])
180-
for lib_dir in library_dirs():
181-
command.extend(["-L", lib_dir])
182-
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])
291+
if is_cuda():
292+
command = ["gcc", "test.c"]
293+
for inc_dir in include_dir:
294+
command.extend(["-I", inc_dir])
295+
for lib_dir in library_dirs():
296+
command.extend(["-L", lib_dir])
297+
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])
298+
299+
if is_xpu():
300+
command = ["icpx", "test.cpp"]
301+
for inc_dir in COMPILATION_HELPER.include_dir:
302+
command.extend(["-I", inc_dir])
303+
for lib_dir in COMPILATION_HELPER.library_dir:
304+
command.extend(["-L", lib_dir])
305+
if COMPILATION_HELPER.libsycl_dir:
306+
for lib_dir in COMPILATION_HELPER.libsycl_dir:
307+
command.extend(["-L", lib_dir])
308+
command.extend(["-fsycl", "-lze_loader", "-L", dir, "-l", "kernel", "-o", exe])
183309
subprocess.run(command, check=True, cwd=dir)
184310

185311

@@ -283,6 +409,7 @@ def test_compile_link_matmul_no_specialization():
283409

284410
with tempfile.TemporaryDirectory() as tmp_dir:
285411
dtype = "fp16"
412+
286413
BM, BN, BK = 16, 16, 16
287414

288415
kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
@@ -299,9 +426,8 @@ def test_compile_link_matmul_no_specialization():
299426

300427
# run test case
301428
env = os.environ.copy()
302-
env["LD_LIBRARY_PATH"] = tmp_dir
429+
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
303430
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
304-
305431
# read data and compare against reference
306432
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
307433
c_tri = c.reshape((M, N)).view(np.float32)
@@ -330,7 +456,7 @@ def test_compile_link_matmul():
330456

331457
# run test case
332458
env = os.environ.copy()
333-
env["LD_LIBRARY_PATH"] = tmp_dir
459+
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
334460
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)
335461

336462
# read data and compare against reference
@@ -361,7 +487,7 @@ def test_launcher_has_no_available_kernel():
361487

362488
# run test case
363489
env = os.environ.copy()
364-
env["LD_LIBRARY_PATH"] = tmp_dir
490+
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
365491
result = subprocess.run(
366492
["./test", a_path, b_path, c_path],
367493
env=env,
@@ -410,7 +536,7 @@ def test_compile_link_autotune_matmul():
410536
gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)
411537

412538
env = os.environ.copy()
413-
env["LD_LIBRARY_PATH"] = tmp_dir
539+
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
414540
subprocess.run(
415541
[f"./{test_name}", a_path, b_path, c_path],
416542
check=True,
@@ -440,3 +566,21 @@ def test_ttgir_to_ptx():
440566
ptx = k.asm["ptx"]
441567
assert ".target sm_80" in ptx
442568
assert ".address_size 64" in ptx
569+
570+
571+
def test_ttgir_to_spv():
572+
src = """
573+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} {
574+
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
575+
tt.return
576+
}
577+
}
578+
"""
579+
with tempfile.TemporaryDirectory() as tmp_dir:
580+
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
581+
with open(kernel_path, "w") as fp:
582+
fp.write(src)
583+
k = triton.compile(kernel_path, target=triton.runtime.driver.active.get_current_target())
584+
spv = k.asm['spvdis']
585+
assert "OpCapability KernelAttributesINTEL" in spv
586+
assert "SubgroupSize 32" in spv

python/test/unit/tools/test_disasm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,19 @@ def kernel(X, i: tl.constexpr):
1919
sass = h.asm["sass"]
2020
# check that the sass has a store instruction.
2121
assert "STG.E" in sass
22+
23+
24+
def test_disam_spvbin():
25+
if not triton.runtime.driver.active.get_current_target().backend == "xpu":
26+
pytest.skip("Test requires XPU.")
27+
28+
@triton.jit
29+
def kernel(X, i: tl.constexpr):
30+
tl.store(X, i)
31+
32+
x = torch.empty(1, dtype=torch.int32, device='xpu')
33+
h = kernel[(1, )](x, i=12)
34+
assert x[0] == 12
35+
dis = h.asm["spvdis"]
36+
# check that the spvdis has a store instruction.
37+
assert "OpStore" in dis

python/triton/compiler/compiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..runtime.autotuner import OutOfResources
99
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
1010
from ..runtime.driver import driver
11-
from ..tools.disasm import get_sass
11+
from ..tools.disasm import get_sass, get_spvdis
1212
# TODO: this shouldn't be here
1313
from .code_generator import ast_to_ttir
1414
from pathlib import Path
@@ -175,6 +175,8 @@ def parse(full_name, ext, context):
175175
return Path(full_name).read_text()
176176
if ext == "cubin" or ext == "hsaco":
177177
return Path(full_name).read_bytes()
178+
if ext == "spv":
179+
return Path(full_name).read_bytes()
178180

179181

180182
def filter_traceback(e: BaseException):
@@ -339,6 +341,8 @@ def __missing__(self, key):
339341

340342
if key == "sass":
341343
value = get_sass(self["cubin"])
344+
if key == "spvdis":
345+
value = get_spvdis(self["spv"])
342346
else:
343347
raise KeyError("Unknown key: '%s'" % key)
344348

0 commit comments

Comments
 (0)