Skip to content

Commit e69f985

Browse files
authored
Hot fix for benchmark_driver.py after #3043 (#3060)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2347dba commit e69f985

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from triton.backends.driver import DriverBase
99
from triton.runtime.cache import get_cache_manager
1010
from triton.runtime.build import _build, quiet
11+
from triton._utils import parse_list_string
1112

1213
import torch
1314

@@ -84,7 +85,7 @@ def get_sycl_queue(self):
8485

8586

8687
def ty_to_cpp(ty):
87-
if ty[0] == "*":
88+
if ty[0] == "*" or ty == "none":
8889
return "void*"
8990
return {
9091
"i1": "int32_t",
@@ -106,16 +107,27 @@ def ty_to_cpp(ty):
106107

107108

108109
def make_launcher(constants, signature, ids): # pylint: disable=unused-argument
109-
# Record the end of regular arguments;
110-
# subsequent arguments are architecture-specific descriptors.
111-
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
112110

113111
def _extracted_type(ty):
114-
if ty[0] == "*":
112+
if ty[0] == "*" or ty == "none":
115113
return "PyObject*"
114+
if ty[0] == "[":
115+
if ty == "[]":
116+
return "[]"
117+
tys = parse_list_string(ty)
118+
val = ",".join(map(_extracted_type, tys))
119+
return f"[{val}]"
116120
return ty_to_cpp(ty)
117121

118122
def format_of(ty):
123+
if ty == "void*":
124+
return "O"
125+
if ty[0] == "[":
126+
if ty == "[]":
127+
return "()"
128+
tys = parse_list_string(ty)
129+
val = "".join(map(format_of, tys))
130+
return f"({val})"
119131
return {
120132
"PyObject*": "O",
121133
"float": "f",
@@ -131,10 +143,18 @@ def format_of(ty):
131143
"uint64_t": "K",
132144
}[ty]
133145

146+
signature = {k: v for k, v in signature.items() if v != "constexpr"}
134147
args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()])
135148
fmt = "iiiOOOOOO" + args_format
149+
signature = ",".join(signature.values()).replace("[", "").replace("]", "")
150+
signature = list(filter(bool, signature.split(",")))
151+
signature = dict(enumerate(signature))
136152
args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""
137153

154+
# Record the end of regular arguments;
155+
# subsequent arguments are architecture-specific descriptors.
156+
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
157+
138158
# generate glue code
139159
src = f"""
140160
#include <cstddef>
@@ -229,33 +249,15 @@ def format_of(ty):
229249
return ptr_info;
230250
}}
231251
// start sycl
232-
static void set_scalar_arg(
233-
sycl::handler& cgh,
234-
int index,
235-
size_t size,
236-
const void* value) {{
237-
switch (size) {{
238-
case sizeof(uint8_t):
239-
cgh.set_arg(index, *static_cast<const uint8_t*>(value));
240-
break;
241-
case sizeof(uint16_t):
242-
cgh.set_arg(index, *static_cast<const uint16_t*>(value));
243-
break;
244-
case sizeof(uint32_t):
245-
cgh.set_arg(index, *static_cast<const uint32_t*>(value));
246-
break;
247-
case sizeof(uint64_t):
248-
cgh.set_arg(index, *static_cast<const uint64_t*>(value));
249-
break;
250-
default:
251-
assert(false && "wrong scalar size in sycl gen.");
252-
}}
252+
template <class T>
253+
static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
254+
cgh.set_arg(index, *static_cast<const T *>(value));
253255
}}
254256
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{
255257
256258
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
257259
RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
258-
void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
260+
void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
259261
uint32_t num_params = sizeof(params)/sizeof(params[0]);
260262
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
261263
size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -273,8 +275,7 @@ def format_of(ty):
273275
assert(num_params == expected_num_params && "number of kernel param not matched");
274276
// Submit the imported kernel.
275277
auto cgf = [&](sycl::handler &cgh) {{
276-
{" ".join(f"set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants]))}
277-
if (shared_memory) {{
278+
{" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))} if (shared_memory) {{
278279
using share_mem_t = sycl::local_accessor<int8_t, 1>;
279280
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
280281
cgh.set_arg(num_params, local_buffer);
@@ -336,8 +337,8 @@ def format_of(ty):
336337
if(kernel_ptr == nullptr) return NULL;
337338
sycl::kernel kernel = *kernel_ptr;
338339
339-
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
340-
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
340+
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
341+
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
341342
342343
if(launch_exit_hook != Py_None){{
343344
PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -440,9 +441,8 @@ class XPULauncher:
440441
def __init__(self, src, metadata): # pylint: disable=unused-argument
441442
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
442443
constants = src.constants if hasattr(src, "constants") else {}
443-
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
444-
self.constants = {cst_key(key): value for key, value in constants.items()}
445-
self.signature = {cst_key(key): value for key, value in src.signature.items()}
444+
self.constants = dict(constants.items())
445+
self.signature = dict(src.signature.items())
446446
src = make_launcher(self.constants, self.signature, ids)
447447
mod = compile_module_from_src(src, "__triton_launcher")
448448
self.launch = mod.launch

0 commit comments

Comments
 (0)