@@ -263,8 +263,15 @@ def format_of(ty):
263263 # Record the end of regular arguments;
264264 # subsequent arguments are architecture-specific descriptors.
265265 arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items ())
266+ internal_args_list = []
267+ for i , ty in signature .items ():
268+ if ty [0 ] == "*" or ty == "none" :
269+ internal_args_list .append (f"ptr_info{ i } .dev_ptr" )
270+ else :
271+ internal_args_list .append (f"_arg{ i } " )
266272
267273 # generate glue code
274+ params = [f"&arg{ i } " for i , ty in signature .items () if i not in constants and ty != "none" ]
268275 src = f"""
269276#include <cstddef>
270277#include <string>
@@ -369,7 +376,7 @@ def format_of(ty):
369376 std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
370377 { 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER .inject_pytorch_dep else "" }
371378
372- void *params[] = {{ { ', ' .join (f"&arg { i } " for i , ty in signature . items () if i not in constants and ty != "none" )} }};
379+ void *params[] = {{ { ', ' .join (params )} }};
373380 uint32_t num_params = sizeof(params)/sizeof(params[0]);
374381 uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
375382 size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -461,7 +468,7 @@ def format_of(ty):
461468 sycl::kernel kernel = *kernel_ptr;
462469
463470 { "; " .join ([f"DevicePtrInfo ptr_info{ i } = getPointer(_arg{ i } , { i } , stream); if (!ptr_info{ i } .valid) return NULL;" if ty [0 ] == "*" or ty == "none" else "" for i , ty in signature .items ()])} ;
464- sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { ',' + ', ' .join (f"ptr_info { i } .dev_ptr" if ty [ 0 ] == "*" or ty == "none" else f"_arg { i } " for i , ty in signature . items ()) if len (signature ) > 0 else '' } );
471+ sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel { ',' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
465472
466473 if(launch_exit_hook != Py_None){{
467474 PyObject* args = Py_BuildValue("(O)", launch_metadata);
0 commit comments