Skip to content

Commit 36aa1cc

Browse files
authored
[NFC] Small intel driver refactor (#3102)
Align with nvidia driver Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 5e6f0e7 commit 36aa1cc

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

third_party/intel/backend/driver.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,15 @@ def format_of(ty):
263263
# Record the end of regular arguments;
264264
# subsequent arguments are architecture-specific descriptors.
265265
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
266+
internal_args_list = []
267+
for i, ty in signature.items():
268+
if ty[0] == "*" or ty == "none":
269+
internal_args_list.append(f"ptr_info{i}.dev_ptr")
270+
else:
271+
internal_args_list.append(f"_arg{i}")
266272

267273
# generate glue code
274+
params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
268275
src = f"""
269276
#include <cstddef>
270277
#include <string>
@@ -369,7 +376,7 @@ def format_of(ty):
369376
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
370377
{ 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER.inject_pytorch_dep else "" }
371378
372-
void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
379+
void *params[] = {{ {', '.join(params)} }};
373380
uint32_t num_params = sizeof(params)/sizeof(params[0]);
374381
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
375382
size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -461,7 +468,7 @@ def format_of(ty):
461468
sycl::kernel kernel = *kernel_ptr;
462469
463470
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
464-
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
471+
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
465472
466473
if(launch_exit_hook != Py_None){{
467474
PyObject* args = Py_BuildValue("(O)", launch_metadata);

0 commit comments

Comments
 (0)