@@ -194,7 +194,7 @@ def wait(self):
194194
195195
196196def ty_to_cpp (ty ):
197- if ty [0 ] == '*' or ty == "none" :
197+ if ty [0 ] == '*' :
198198 return "void*"
199199 return {
200200 "i1" : "int32_t" ,
@@ -215,10 +215,12 @@ def ty_to_cpp(ty):
215215 }[ty ]
216216
217217
218- def make_launcher (constants , signature , ids ):
218+ def make_launcher (constants , signature ):
219219
220220 def _extracted_type (ty ):
221- if ty [0 ] == '*' or ty == "none" :
221+ if ty == "constexpr" :
222+ return "PyObject*"
223+ if ty [0 ] == '*' :
222224 return "PyObject*"
223225 if ty [0 ] == '[' :
224226 if ty == "[]" :
@@ -252,7 +254,6 @@ def format_of(ty):
252254 "uint64_t" : "K" ,
253255 }[ty ]
254256
255- signature = {k : v for k , v in signature .items () if v != 'constexpr' }
256257 args_format = '' .join ([format_of (_extracted_type (ty )) for ty in signature .values ()])
257258 format = "iiiOOOOOO" + args_format
258259 signature = ',' .join (signature .values ()).replace ('[' , '' ).replace (']' , '' )
@@ -262,16 +263,22 @@ def format_of(ty):
262263
263264 # Record the end of regular arguments;
264265 # subsequent arguments are architecture-specific descriptors.
265- arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items ())
266+ arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items () if ty != "constexpr" )
266267 internal_args_list = []
267268 for i , ty in signature .items ():
268- if ty [0 ] == "*" or ty == "none" :
269+ if ty [0 ] == "*" :
269270 internal_args_list .append (f"ptr_info{ i } .dev_ptr" )
270- else :
271+ elif ty != "constexpr" :
271272 internal_args_list .append (f"_arg{ i } " )
272273
273274 # generate glue code
274- params = [f"&arg{ i } " for i , ty in signature .items () if i not in constants and ty != "none" ]
275+ newline = '\n '
276+ ptr_decls = [
277+ f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } , stream); if (!ptr_info{ i } .valid) return NULL;"
278+ for i , ty in signature .items ()
279+ if ty [0 ] == "*"
280+ ]
281+ params = [f"&arg{ i } " for i , ty in signature .items () if ty != "constexpr" ]
275282 src = f"""
276283#include <cstddef>
277284#include <string>
@@ -394,7 +401,7 @@ def format_of(ty):
394401 assert(num_params == expected_num_params && "number of kernel param not matched");
395402 // Submit the imported kernel.
396403 auto cgf = [&](sycl::handler &cgh) {{
397- { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if i not in constants and signature [i ] != "none " ]))}
404+ { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if signature [i ] != "constexpr " ]))}
398405 if (shared_memory) {{
399406 using share_mem_t = sycl::local_accessor<int8_t, 1>;
400407 share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -418,7 +425,7 @@ def format_of(ty):
418425 PyObject *py_obj_stream;
419426 PyObject* py_kernel;
420427
421- { ' ' .join ([f"{ _extracted_type (ty )} _arg{ i } ; " for i , ty in signature .items ()])}
428+ { newline .join ([f"{ _extracted_type (ty )} _arg{ i } ;" for i , ty in signature .items ()])}
422429 if(!PyArg_ParseTuple(args, \" { format } \" , &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
423430 &kernel_metadata, &launch_metadata,
424431 &launch_enter_hook, &launch_exit_hook { args_list } )) {{
@@ -467,7 +474,7 @@ def format_of(ty):
467474 if(kernel_ptr == nullptr) return NULL;
468475 sycl::kernel kernel = *kernel_ptr;
469476
470- { "; " .join ([ f"DevicePtrInfo ptr_info { i } = getPointer(_arg { i } , { i } , stream); if (!ptr_info { i } .valid) return NULL;" if ty [ 0 ] == "*" or ty == "none" else "" for i , ty in signature . items ()]) } ;
477+ { newline .join (ptr_decls ) }
471478 sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { ',' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
472479
473480 if(launch_exit_hook != Py_None){{
@@ -568,11 +575,11 @@ def serialize_args(args, constants, signature):
568575class XPULauncher (object ):
569576
570577 def __init__ (self , src , metadata ):
571- ids = {"ids_of_const_exprs" : src .fn .constexprs if hasattr (src , "fn" ) else tuple ()}
572578 constants = src .constants if hasattr (src , "constants" ) else dict ()
573- self .constants = {idx : value for idx , value in constants .items ()}
579+ arg_idx = lambda x : (src .fn .arg_names .index (x ), ) if isinstance (x , str ) else x
580+ self .constants = {arg_idx (idx ): value for idx , value in constants .items ()}
574581 self .signature = {idx : value for idx , value in src .signature .items ()}
575- src = make_launcher (self .constants , self .signature , ids )
582+ src = make_launcher (self .constants , self .signature )
576583 mod = compile_module_from_src (src , "__triton_launcher" )
577584 self .launch = mod .launch
578585
0 commit comments