Skip to content

Commit 2cab9ed

Browse files
committed
[intel] update driver.py
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 36d8b1b commit 2cab9ed

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

third_party/intel/backend/driver.py

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,7 @@ def wait(self):
194194

195195

196196
def ty_to_cpp(ty):
197-
if ty[0] == '*' or ty == "none":
197+
if ty[0] == '*':
198198
return "void*"
199199
return {
200200
"i1": "int32_t",
@@ -215,10 +215,12 @@ def ty_to_cpp(ty):
215215
}[ty]
216216

217217

218-
def make_launcher(constants, signature, ids):
218+
def make_launcher(constants, signature):
219219

220220
def _extracted_type(ty):
221-
if ty[0] == '*' or ty == "none":
221+
if ty == "constexpr":
222+
return "PyObject*"
223+
if ty[0] == '*':
222224
return "PyObject*"
223225
if ty[0] == '[':
224226
if ty == "[]":
@@ -252,7 +254,6 @@ def format_of(ty):
252254
"uint64_t": "K",
253255
}[ty]
254256

255-
signature = {k: v for k, v in signature.items() if v != 'constexpr'}
256257
args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
257258
format = "iiiOOOOOO" + args_format
258259
signature = ','.join(signature.values()).replace('[', '').replace(']', '')
@@ -262,9 +263,22 @@ def format_of(ty):
262263

263264
# Record the end of regular arguments;
264265
# subsequent arguments are architecture-specific descriptors.
265-
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
266+
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
267+
internal_args_list = []
268+
for i, ty in signature.items():
269+
if ty[0] == "*":
270+
internal_args_list.append(f"ptr_info{i}.dev_ptr")
271+
elif ty != "constexpr":
272+
internal_args_list.append(f"_arg{i}")
266273

267274
# generate glue code
275+
newline = '\n '
276+
ptr_decls = [
277+
f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;"
278+
for i, ty in signature.items()
279+
if ty[0] == "*"
280+
]
281+
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
268282
src = f"""
269283
#include <cstddef>
270284
#include <string>
@@ -369,7 +383,7 @@ def format_of(ty):
369383
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
370384
{ 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER.inject_pytorch_dep else "" }
371385
372-
void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
386+
void *params[] = {{ {', '.join(params)} }};
373387
uint32_t num_params = sizeof(params)/sizeof(params[0]);
374388
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
375389
size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -387,7 +401,7 @@ def format_of(ty):
387401
assert(num_params == expected_num_params && "number of kernel param not matched");
388402
// Submit the imported kernel.
389403
auto cgf = [&](sycl::handler &cgh) {{
390-
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
404+
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if signature[i] != "constexpr"]))}
391405
if (shared_memory) {{
392406
using share_mem_t = sycl::local_accessor<int8_t, 1>;
393407
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -411,7 +425,7 @@ def format_of(ty):
411425
PyObject *py_obj_stream;
412426
PyObject* py_kernel;
413427
414-
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
428+
{newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
415429
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
416430
&kernel_metadata, &launch_metadata,
417431
&launch_enter_hook, &launch_exit_hook {args_list})) {{
@@ -460,8 +474,8 @@ def format_of(ty):
460474
if(kernel_ptr == nullptr) return NULL;
461475
sycl::kernel kernel = *kernel_ptr;
462476
463-
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
464-
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
477+
{newline.join(ptr_decls)}
478+
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
465479
466480
if(launch_exit_hook != Py_None){{
467481
PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -561,11 +575,11 @@ def serialize_args(args, constants, signature):
561575
class XPULauncher(object):
562576

563577
def __init__(self, src, metadata):
564-
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
565578
constants = src.constants if hasattr(src, "constants") else dict()
566-
self.constants = {idx: value for idx, value in constants.items()}
579+
arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
580+
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
567581
self.signature = {idx: value for idx, value in src.signature.items()}
568-
src = make_launcher(self.constants, self.signature, ids)
582+
src = make_launcher(self.constants, self.signature)
569583
mod = compile_module_from_src(src, "__triton_launcher")
570584
self.launch = mod.launch
571585

0 commit comments

Comments
 (0)