88from triton .backends .driver import DriverBase
99from triton .runtime .cache import get_cache_manager
1010from triton .runtime .build import _build , quiet
11+ from triton ._utils import parse_list_string
1112
1213import torch
1314
@@ -84,7 +85,7 @@ def get_sycl_queue(self):
8485
8586
8687def ty_to_cpp (ty ):
87- if ty [0 ] == "*" :
88+ if ty [0 ] == "*" or ty == "none" :
8889 return "void*"
8990 return {
9091 "i1" : "int32_t" ,
@@ -106,16 +107,27 @@ def ty_to_cpp(ty):
106107
107108
108109def make_launcher (constants , signature , ids ): # pylint: disable=unused-argument
109- # Record the end of regular arguments;
110- # subsequent arguments are architecture-specific descriptors.
111- arg_decls = ", " .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items ())
112110
113111 def _extracted_type (ty ):
114- if ty [0 ] == "*" :
112+ if ty [0 ] == "*" or ty == "none" :
115113 return "PyObject*"
114+ if ty [0 ] == "[" :
115+ if ty == "[]" :
116+ return "[]"
117+ tys = parse_list_string (ty )
118+ val = "," .join (map (_extracted_type , tys ))
119+ return f"[{ val } ]"
116120 return ty_to_cpp (ty )
117121
118122 def format_of (ty ):
123+ if ty == "void*" :
124+ return "O"
125+ if ty [0 ] == "[" :
126+ if ty == "[]" :
127+ return "()"
128+ tys = parse_list_string (ty )
129+ val = "" .join (map (format_of , tys ))
130+ return f"({ val } )"
119131 return {
120132 "PyObject*" : "O" ,
121133 "float" : "f" ,
@@ -131,10 +143,18 @@ def format_of(ty):
131143 "uint64_t" : "K" ,
132144 }[ty ]
133145
146+ signature = {k : v for k , v in signature .items () if v != "constexpr" }
134147 args_format = "" .join ([format_of (_extracted_type (ty )) for ty in signature .values ()])
135148 fmt = "iiiOOOOOO" + args_format
149+ signature = "," .join (signature .values ()).replace ("[" , "" ).replace ("]" , "" )
150+ signature = list (filter (bool , signature .split ("," )))
151+ signature = dict (enumerate (signature ))
136152 args_list = ", " + ", " .join (f"&_arg{ i } " for i , ty in signature .items ()) if len (signature ) > 0 else ""
137153
154+ # Record the end of regular arguments;
155+ # subsequent arguments are architecture-specific descriptors.
156+ arg_decls = ", " .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items ())
157+
138158 # generate glue code
139159 src = f"""
140160 #include <cstddef>
@@ -229,33 +249,15 @@ def format_of(ty):
229249 return ptr_info;
230250 }}
231251// start sycl
232- static void set_scalar_arg(
233- sycl::handler& cgh,
234- int index,
235- size_t size,
236- const void* value) {{
237- switch (size) {{
238- case sizeof(uint8_t):
239- cgh.set_arg(index, *static_cast<const uint8_t*>(value));
240- break;
241- case sizeof(uint16_t):
242- cgh.set_arg(index, *static_cast<const uint16_t*>(value));
243- break;
244- case sizeof(uint32_t):
245- cgh.set_arg(index, *static_cast<const uint32_t*>(value));
246- break;
247- case sizeof(uint64_t):
248- cgh.set_arg(index, *static_cast<const uint64_t*>(value));
249- break;
250- default:
251- assert(false && "wrong scalar size in sycl gen.");
252- }}
252+ template <class T>
253+ static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
254+ cgh.set_arg(index, *static_cast<const T *>(value));
253255 }}
254256 static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr { ", " + arg_decls if len (arg_decls ) > 0 else "" } ) {{
255257
256258 std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
257259 RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
258- void *params[] = {{ { ", " .join (f"&arg{ i } " for i in signature .keys () if i not in constants )} }};
260+ void *params[] = {{ { ", " .join (f"&arg{ i } " for i , ty in signature .items () if i not in constants and ty != "none" )} }};
259261 uint32_t num_params = sizeof(params)/sizeof(params[0]);
260262 uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
261263 size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -273,8 +275,7 @@ def format_of(ty):
273275 assert(num_params == expected_num_params && "number of kernel param not matched");
274276 // Submit the imported kernel.
275277 auto cgf = [&](sycl::handler &cgh) {{
276- { " " .join (f"set_scalar_arg(cgh, { idx } , sizeof({ ty_to_cpp (item )} ), params[{ idx } ]);" for idx , item in enumerate ([signature [i ] for i in signature if i not in constants ]))}
277- if (shared_memory) {{
278+ { " " .join (f"set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);" for idx , item in enumerate ([signature [i ] for i in signature if i not in constants and signature [i ] != "none" ]))} if (shared_memory) {{
278279 using share_mem_t = sycl::local_accessor<int8_t, 1>;
279280 share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
280281 cgh.set_arg(num_params, local_buffer);
@@ -336,8 +337,8 @@ def format_of(ty):
336337 if(kernel_ptr == nullptr) return NULL;
337338 sycl::kernel kernel = *kernel_ptr;
338339
339- { "; " .join ([f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } , stream); if (!ptr_info{ i } .valid) return NULL;" if ty [0 ] == "*" else "" for i , ty in signature .items ()])} ;
340- sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { "," + ", " .join (f"ptr_info{ i } .dev_ptr" if ty [0 ]== "*" else f"_arg{ i } " for i , ty in signature .items ()) if len (signature ) > 0 else "" } );
340+ { "; " .join ([f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } , stream); if (!ptr_info{ i } .valid) return NULL;" if ty [0 ] == "*" or ty == "none" else "" for i , ty in signature .items ()])} ;
341+ sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { "," + ", " .join (f"ptr_info{ i } .dev_ptr" if ty [0 ]== "*" or ty == "none" else f"_arg{ i } " for i , ty in signature .items ()) if len (signature ) > 0 else "" } );
341342
342343 if(launch_exit_hook != Py_None){{
343344 PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -440,9 +441,8 @@ class XPULauncher:
440441 def __init__ (self , src , metadata ): # pylint: disable=unused-argument
441442 ids = {"ids_of_const_exprs" : src .fn .constexprs if hasattr (src , "fn" ) else tuple ()}
442443 constants = src .constants if hasattr (src , "constants" ) else {}
443- cst_key = lambda i : src .fn .arg_names .index (i ) if isinstance (i , str ) else i
444- self .constants = {cst_key (key ): value for key , value in constants .items ()}
445- self .signature = {cst_key (key ): value for key , value in src .signature .items ()}
444+ self .constants = dict (constants .items ())
445+ self .signature = dict (src .signature .items ())
446446 src = make_launcher (self .constants , self .signature , ids )
447447 mod = compile_module_from_src (src , "__triton_launcher" )
448448 self .launch = mod .launch
0 commit comments