@@ -194,7 +194,7 @@ def wait(self):
194194
195195
196196def ty_to_cpp (ty ):
197- if ty [0 ] == '*' or ty == "none" :
197+ if ty [0 ] == '*' :
198198 return "void*"
199199 return {
200200 "i1" : "int32_t" ,
@@ -215,10 +215,12 @@ def ty_to_cpp(ty):
215215 }[ty ]
216216
217217
218- def make_launcher (constants , signature , ids ):
218+ def make_launcher (constants , signature ):
219219
220220 def _extracted_type (ty ):
221- if ty [0 ] == '*' or ty == "none" :
221+ if ty == "constexpr" :
222+ return "PyObject*"
223+ if ty [0 ] == '*' :
222224 return "PyObject*"
223225 if ty [0 ] == '[' :
224226 if ty == "[]" :
@@ -252,7 +254,6 @@ def format_of(ty):
252254 "uint64_t" : "K" ,
253255 }[ty ]
254256
255- signature = {k : v for k , v in signature .items () if v != 'constexpr' }
256257 args_format = '' .join ([format_of (_extracted_type (ty )) for ty in signature .values ()])
257258 format = "iiiOOOOOO" + args_format
258259 signature = ',' .join (signature .values ()).replace ('[' , '' ).replace (']' , '' )
@@ -262,9 +263,22 @@ def format_of(ty):
262263
263264 # Record the end of regular arguments;
264265 # subsequent arguments are architecture-specific descriptors.
265- arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items ())
266+ arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items () if ty != "constexpr" )
267+ internal_args_list = []
268+ for i , ty in signature .items ():
269+ if ty [0 ] == "*" :
270+ internal_args_list .append (f"ptr_info{ i } .dev_ptr" )
271+ elif ty != "constexpr" :
272+ internal_args_list .append (f"_arg{ i } " )
266273
267274 # generate glue code
275+ newline = '\n '
276+ ptr_decls = [
277+ f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } , stream); if (!ptr_info{ i } .valid) return NULL;"
278+ for i , ty in signature .items ()
279+ if ty [0 ] == "*"
280+ ]
281+ params = [f"&arg{ i } " for i , ty in signature .items () if ty != "constexpr" ]
268282 src = f"""
269283#include <cstddef>
270284#include <string>
@@ -369,7 +383,7 @@ def format_of(ty):
369383 std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
370384 { 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER .inject_pytorch_dep else "" }
371385
372- void *params[] = {{ { ', ' .join (f"&arg { i } " for i , ty in signature . items () if i not in constants and ty != "none" )} }};
386+ void *params[] = {{ { ', ' .join (params )} }};
373387 uint32_t num_params = sizeof(params)/sizeof(params[0]);
374388 uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
375389 size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -387,7 +401,7 @@ def format_of(ty):
387401 assert(num_params == expected_num_params && "number of kernel param not matched");
388402 // Submit the imported kernel.
389403 auto cgf = [&](sycl::handler &cgh) {{
390- { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if i not in constants and signature [i ] != "none " ]))}
404+ { " " .join (f'set_scalar_arg<{ ty_to_cpp (item )} >(cgh, { idx } , params[{ idx } ]);' for idx , item in enumerate ([signature [i ] for i in signature if signature [i ] != "constexpr " ]))}
391405 if (shared_memory) {{
392406 using share_mem_t = sycl::local_accessor<int8_t, 1>;
393407 share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -411,7 +425,7 @@ def format_of(ty):
411425 PyObject *py_obj_stream;
412426 PyObject* py_kernel;
413427
414- { ' ' .join ([f"{ _extracted_type (ty )} _arg{ i } ; " for i , ty in signature .items ()])}
428+ { newline .join ([f"{ _extracted_type (ty )} _arg{ i } ;" for i , ty in signature .items ()])}
415429 if(!PyArg_ParseTuple(args, \" { format } \" , &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
416430 &kernel_metadata, &launch_metadata,
417431 &launch_enter_hook, &launch_exit_hook { args_list } )) {{
@@ -460,8 +474,8 @@ def format_of(ty):
460474 if(kernel_ptr == nullptr) return NULL;
461475 sycl::kernel kernel = *kernel_ptr;
462476
463- { "; " .join ([ f"DevicePtrInfo ptr_info { i } = getPointer(_arg { i } , { i } , stream); if (!ptr_info { i } .valid) return NULL;" if ty [ 0 ] == "*" or ty == "none" else "" for i , ty in signature . items ()]) } ;
464- sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { ',' + ', ' .join (f"ptr_info { i } .dev_ptr" if ty [ 0 ] == "*" or ty == "none" else f"_arg { i } " for i , ty in signature . items ()) if len (signature ) > 0 else '' } );
477+ { newline .join (ptr_decls ) }
478+ sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { ',' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
465479
466480 if(launch_exit_hook != Py_None){{
467481 PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -561,11 +575,11 @@ def serialize_args(args, constants, signature):
561575class XPULauncher (object ):
562576
563577 def __init__ (self , src , metadata ):
564- ids = {"ids_of_const_exprs" : src .fn .constexprs if hasattr (src , "fn" ) else tuple ()}
565578 constants = src .constants if hasattr (src , "constants" ) else dict ()
566- self .constants = {idx : value for idx , value in constants .items ()}
579+ arg_idx = lambda x : (src .fn .arg_names .index (x ), ) if isinstance (x , str ) else x
580+ self .constants = {arg_idx (idx ): value for idx , value in constants .items ()}
567581 self .signature = {idx : value for idx , value in src .signature .items ()}
568- src = make_launcher (self .constants , self .signature , ids )
582+ src = make_launcher (self .constants , self .signature )
569583 mod = compile_module_from_src (src , "__triton_launcher" )
570584 self .launch = mod .launch
571585
0 commit comments