77import numpy as np
88
99import triton
10+ from triton ._internal_testing import is_cuda , is_xpu
1011from triton .backends .compiler import GPUTarget
1112from triton .backends .nvidia .driver import include_dir , library_dirs
13+ from triton .backends .intel .driver import COMPILATION_HELPER
1214
1315kernel_utils_src = """
1416import triton
@@ -97,21 +99,42 @@ def kernel(C, A, B, M, N, K,
9799}"""
98100
99101
100- def gen_kernel_library (dir , libname ):
101- c_files = glob .glob (os .path .join (dir , "*.c " ))
102+ def gen_kernel_library_xpu (dir , libname ):
103+ cpp_files = glob .glob (os .path .join (dir , "*.cpp " ))
102104 subprocess .run (
103- ["gcc " ] + c_files + ["-I" , include_dir [0 ], "-c" , "-fPIC" ],
105+ ["icpx " ] + cpp_files + ["-I" , COMPILATION_HELPER . include_dir [0 ], "-c" , "-fsycl " , "-fPIC" ],
104106 check = True ,
105107 cwd = dir ,
106108 )
107109 o_files = glob .glob (os .path .join (dir , "*.o" ))
108110
109- command = ["gcc " , * o_files , "-shared" , "-o" , libname ]
110- for lib_dir in library_dirs () :
111+ command = ["icpx" , "-fsycl" , "-lze_loader " , * o_files , "-shared" , "-o" , libname ]
112+ for lib_dir in COMPILATION_HELPER . library_dir :
111113 command .extend (["-L" , lib_dir ])
114+ if COMPILATION_HELPER .libsycl_dir :
115+ for lib_dir in COMPILATION_HELPER .libsycl_dir :
116+ command .extend (["-L" , lib_dir ])
112117 subprocess .run (command , check = True , cwd = dir )
113118
114119
120+ def gen_kernel_library (dir , libname ):
121+ if is_xpu ():
122+ gen_kernel_library_xpu (dir , libname )
123+ else :
124+ c_files = glob .glob (os .path .join (dir , "*.c" ))
125+ subprocess .run (
126+ ["gcc" ] + c_files + ["-I" , include_dir [0 ], "-c" , "-fPIC" ],
127+ check = True ,
128+ cwd = dir ,
129+ )
130+ o_files = glob .glob (os .path .join (dir , "*.o" ))
131+
132+ command = ["gcc" , * o_files , "-shared" , "-o" , libname ]
133+ for lib_dir in library_dirs ():
134+ command .extend (["-L" , lib_dir ])
135+ subprocess .run (command , check = True , cwd = dir )
136+
137+
115138def gen_test_bin (dir , M , N , K , exe = "test" , algo_id = 0 ):
116139 test_src = f"""
117140int main(int argc, char **argv) {{
@@ -171,15 +194,118 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
171194}}
172195"""
173196 src = test_utils_src + test_src
174- with open (os .path .join (dir , "test.c" ), "w" ) as file :
197+ if is_xpu ():
198+ src = f"""
199+ #include "kernel.h"
200+ #include <assert.h>
201+ #include <cmath>
202+ #include <cstddef>
203+ #include <level_zero/ze_api.h>
204+ #include <stdint.h>
205+ #include <stdio.h>
206+ #include <string.h>
207+ #include <sycl/sycl.hpp>
208+
209+ static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {{
210+ FILE *file = fopen(filename, "w");
211+ if (file == NULL) {{
212+ printf("Could not open file %s\\ n", filename);
213+ return;
214+ }}
215+ for (int i = 0; i < size; i++) {{
216+ fprintf(file, "%d", buffer[i]);
217+ if (i < size - 1) {{
218+ fprintf(file, ",");
219+ }}
220+ }}
221+ fclose(file);
222+ }}
223+
224+ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {{
225+ FILE *file = fopen(filename, "r");
226+ if (file == NULL) {{
227+ printf("Could not open file %s\\ n", filename);
228+ return;
229+ }}
230+ int index = 0;
231+ while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) {{
232+ index++;
233+ }}
234+ fclose(file);
235+ }}
236+ int main(int argc, char ** argv) {{
237+ int M = { M } , N = { N } , K = { K } ;
238+
239+ // initialize sycl handles
240+ sycl::queue q{{sycl::gpu_selector_v}};
241+ sycl::ext::intel::device_ptr<sycl::float16> A =
242+ sycl::malloc_device<sycl::float16>(M * K * 2, q);
243+ sycl::ext::intel::device_ptr<sycl::float16> B =
244+ sycl::malloc_device<sycl::float16>(K * N * 2, q);
245+ sycl::ext::intel::device_ptr<sycl::float16> C =
246+ sycl::malloc_device<sycl::float16>(M * N * 4, q);
247+
248+ // initialize input data
249+ int16_t hA[M * K];
250+ int16_t hB[K * N];
251+ memset(hA, 0, M * K * 2);
252+ memset(hB, 0, K * N * 2);
253+ read_csv_to_buffer(argv[1], hA, M * K);
254+ read_csv_to_buffer(argv[2], hB, K * N);
255+ q.memcpy(A, hA, M * K * 2).wait();
256+ q.memcpy(B, hB, K * N * 2).wait();
257+
258+ // launch kernel
259+ load_matmul_fp16();
260+ int32_t ret;
261+ int algo_id = { algo_id } ;
262+ if (algo_id == 0) {{
263+ ret = matmul_fp16_default(q, C, A, B, M, N, K, N, 1, K, 1, N, 1);
264+ }} else {{
265+ ret = matmul_fp16(q, C, A, B, M, N, K, N, 1, K, 1, N, 1, { algo_id } );
266+ }}
267+ if (ret != 0) fprintf(stderr, "kernel launch failed\\ n");
268+ assert(ret == 0);
269+
270+ q.wait();
271+
272+ // read data
273+ int32_t hC[M * N];
274+ memset(hC, 0, M * N * 4);
275+ q.memcpy(hC, C, M * N * 4).wait();
276+ write_buffer_to_csv(argv[3], hC, M * N);
277+
278+ // free sycl resources
279+ unload_matmul_fp16();
280+ sycl::free(A, q);
281+ sycl::free(B, q);
282+ sycl::free(C, q);
283+ }}
284+ """
285+ src_name = "test.c"
286+ if is_xpu ():
287+ src_name = "test.cpp"
288+ with open (os .path .join (dir , src_name ), "w" ) as file :
175289 file .write (src )
176290
177- command = ["gcc" , "test.c" ]
178- for inc_dir in include_dir :
179- command .extend (["-I" , inc_dir ])
180- for lib_dir in library_dirs ():
181- command .extend (["-L" , lib_dir ])
182- command .extend (["-l" , "cuda" , "-L" , dir , "-l" , "kernel" , "-o" , exe ])
291+ if is_cuda ():
292+ command = ["gcc" , "test.c" ]
293+ for inc_dir in include_dir :
294+ command .extend (["-I" , inc_dir ])
295+ for lib_dir in library_dirs ():
296+ command .extend (["-L" , lib_dir ])
297+ command .extend (["-l" , "cuda" , "-L" , dir , "-l" , "kernel" , "-o" , exe ])
298+
299+ if is_xpu ():
300+ command = ["icpx" , "test.cpp" ]
301+ for inc_dir in COMPILATION_HELPER .include_dir :
302+ command .extend (["-I" , inc_dir ])
303+ for lib_dir in COMPILATION_HELPER .library_dir :
304+ command .extend (["-L" , lib_dir ])
305+ if COMPILATION_HELPER .libsycl_dir :
306+ for lib_dir in COMPILATION_HELPER .libsycl_dir :
307+ command .extend (["-L" , lib_dir ])
308+ command .extend (["-fsycl" , "-lze_loader" , "-L" , dir , "-l" , "kernel" , "-o" , exe ])
183309 subprocess .run (command , check = True , cwd = dir )
184310
185311
@@ -283,6 +409,7 @@ def test_compile_link_matmul_no_specialization():
283409
284410 with tempfile .TemporaryDirectory () as tmp_dir :
285411 dtype = "fp16"
412+
286413 BM , BN , BK = 16 , 16 , 16
287414
288415 kernel_path = write_triton_kernels (tmp_dir , kernel_src , kernel_utils_src )
@@ -299,9 +426,8 @@ def test_compile_link_matmul_no_specialization():
299426
300427 # run test case
301428 env = os .environ .copy ()
302- env ["LD_LIBRARY_PATH" ] = tmp_dir
429+ env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env . get ( "LD_LIBRARY_PATH" , "" )
303430 subprocess .run (["./test" , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
304-
305431 # read data and compare against reference
306432 c = np .genfromtxt (c_path , delimiter = "," , dtype = np .int32 )
307433 c_tri = c .reshape ((M , N )).view (np .float32 )
@@ -330,7 +456,7 @@ def test_compile_link_matmul():
330456
331457 # run test case
332458 env = os .environ .copy ()
333- env ["LD_LIBRARY_PATH" ] = tmp_dir
459+ env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env . get ( "LD_LIBRARY_PATH" , "" )
334460 subprocess .run (["./test" , a_path , b_path , c_path ], env = env , check = True , cwd = tmp_dir )
335461
336462 # read data and compare against reference
@@ -361,7 +487,7 @@ def test_launcher_has_no_available_kernel():
361487
362488 # run test case
363489 env = os .environ .copy ()
364- env ["LD_LIBRARY_PATH" ] = tmp_dir
490+ env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env . get ( "LD_LIBRARY_PATH" , "" )
365491 result = subprocess .run (
366492 ["./test" , a_path , b_path , c_path ],
367493 env = env ,
@@ -410,7 +536,7 @@ def test_compile_link_autotune_matmul():
410536 gen_test_bin (tmp_dir , M , N , K , exe = test_name , algo_id = algo_id )
411537
412538 env = os .environ .copy ()
413- env ["LD_LIBRARY_PATH" ] = tmp_dir
539+ env ["LD_LIBRARY_PATH" ] = tmp_dir + ":" + env . get ( "LD_LIBRARY_PATH" , "" )
414540 subprocess .run (
415541 [f"./{ test_name } " , a_path , b_path , c_path ],
416542 check = True ,
@@ -440,3 +566,21 @@ def test_ttgir_to_ptx():
440566 ptx = k .asm ["ptx" ]
441567 assert ".target sm_80" in ptx
442568 assert ".address_size 64" in ptx
569+
570+
571+ def test_ttgir_to_spv ():
572+ src = """
573+ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} {
574+ tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
575+ tt.return
576+ }
577+ }
578+ """
579+ with tempfile .TemporaryDirectory () as tmp_dir :
580+ kernel_path = os .path .join (tmp_dir , "empty_kernel.ttgir" )
581+ with open (kernel_path , "w" ) as fp :
582+ fp .write (src )
583+ k = triton .compile (kernel_path , target = triton .runtime .driver .active .get_current_target ())
584+ spv = k .asm ['spvdis' ]
585+ assert "OpCapability KernelAttributesINTEL" in spv
586+ assert "SubgroupSize 32" in spv
0 commit comments