Skip to content

Commit bb38a43

Browse files
authored
Reuse kernels launcher from main intel driver for benchmarks (#3070)
Closes #2540 Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 6919b06 commit bb38a43

File tree

3 files changed

+11
-319
lines changed

3 files changed

+11
-319
lines changed
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,2 @@
11
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD # type: ignore # noqa: F401
2-
3-
if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
4-
from triton.runtime import driver
5-
from . import benchmark_driver
6-
# replace the launcher with the profiler hook.
7-
driver.active.launcher_cls = benchmark_driver.XPULauncher
2+
from . import benchmark_driver # type: ignore # noqa: F401
Lines changed: 3 additions & 305 deletions
Original file line numberDiff line numberDiff line change
@@ -1,308 +1,6 @@
11
import os
22

3-
from triton._utils import parse_list_string
4-
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER, ty_to_cpp, serialize_args
3+
from .benchmark_testing import BENCHMARKING_METHOD
54

6-
# ------------------------
7-
# Utils
8-
# ------------------------
9-
10-
COMPILATION_HELPER.inject_pytorch_dep()
11-
12-
# ------------------------
13-
# Launcher
14-
# ------------------------
15-
16-
17-
def make_launcher(constants, signature, ids): # pylint: disable=unused-argument
18-
19-
def _extracted_type(ty):
20-
if ty[0] == "*" or ty == "none":
21-
return "PyObject*"
22-
if ty[0] == "[":
23-
if ty == "[]":
24-
return "[]"
25-
tys = parse_list_string(ty)
26-
val = ",".join(map(_extracted_type, tys))
27-
return f"[{val}]"
28-
return ty_to_cpp(ty)
29-
30-
def format_of(ty):
31-
if ty == "void*":
32-
return "O"
33-
if ty[0] == "[":
34-
if ty == "[]":
35-
return "()"
36-
tys = parse_list_string(ty)
37-
val = "".join(map(format_of, tys))
38-
return f"({val})"
39-
return {
40-
"PyObject*": "O",
41-
"float": "f",
42-
"double": "d",
43-
"long": "l",
44-
"int8_t": "b",
45-
"int16_t": "h",
46-
"int32_t": "i",
47-
"int64_t": "L",
48-
"uint8_t": "B",
49-
"uint16_t": "H",
50-
"uint32_t": "I",
51-
"uint64_t": "K",
52-
}[ty]
53-
54-
signature = {k: v for k, v in signature.items() if v != "constexpr"}
55-
args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()])
56-
fmt = "iiiOOOOOO" + args_format
57-
signature = ",".join(signature.values()).replace("[", "").replace("]", "")
58-
signature = list(filter(bool, signature.split(",")))
59-
signature = dict(enumerate(signature))
60-
args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""
61-
62-
# Record the end of regular arguments;
63-
# subsequent arguments are architecture-specific descriptors.
64-
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
65-
66-
# generate glue code
67-
src = f"""
68-
#include <cstddef>
69-
#include <string>
70-
#include <iostream>
71-
#include <iomanip>
72-
#include <level_zero/ze_api.h>
73-
#include <sycl/sycl.hpp>
74-
#include <ATen/record_function.h>
75-
76-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
77-
#include <Python.h>
78-
#include <stdio.h>
79-
#include <numpy/arrayobject.h>
80-
81-
static inline void gpuAssert(ze_result_t code, const char *file, int line)
82-
{{
83-
if (code != ZE_RESULT_SUCCESS)
84-
{{
85-
const char* prefix = "Triton Error [ZE]: ";
86-
std::string str = std::to_string(code);
87-
char err[1024] = {{0}};
88-
strcat(err, prefix);
89-
strcat(err, str.c_str());
90-
PyErr_SetString(PyExc_RuntimeError, err);
91-
}}
92-
}}
93-
94-
#define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
95-
96-
typedef struct _DevicePtrInfo {{
97-
void* dev_ptr;
98-
bool valid;
99-
}} DevicePtrInfo;
100-
101-
static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{
102-
if (!ptr_info->dev_ptr || !ptr_info->valid) {{
103-
return;
104-
}}
105-
auto context = queue.get_context();
106-
auto handle = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(context);
107-
ze_memory_allocation_properties_t prop;
108-
prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES;
109-
prop.pNext = nullptr;
110-
ze_device_handle_t device;
111-
auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device);
112-
if (res != ZE_RESULT_SUCCESS) {{
113-
PyErr_Format(PyExc_ValueError,
114-
"Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res);
115-
ptr_info->valid = false;
116-
}} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{
117-
PyErr_Format(PyExc_ValueError,
118-
"Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx);
119-
ptr_info->valid = false;
120-
}}
121-
}}
122-
123-
static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{
124-
DevicePtrInfo ptr_info;
125-
ptr_info.dev_ptr = 0;
126-
ptr_info.valid = true;
127-
if (PyLong_Check(obj)) {{
128-
ptr_info.dev_ptr = PyLong_AsVoidPtr(obj);
129-
checkDevicePointer(&ptr_info, idx, queue);
130-
return ptr_info;
131-
}}
132-
if (obj == Py_None) {{
133-
// valid nullptr
134-
return ptr_info;
135-
}}
136-
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
137-
if(ptr){{
138-
PyObject *empty_tuple = PyTuple_New(0);
139-
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
140-
Py_DECREF(empty_tuple);
141-
Py_DECREF(ptr);
142-
if (!PyLong_Check(ret)) {{
143-
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
144-
ptr_info.valid = false;
145-
return ptr_info;
146-
}}
147-
ptr_info.dev_ptr = PyLong_AsVoidPtr(ret);
148-
if(!ptr_info.dev_ptr) {{
149-
return ptr_info;
150-
}}
151-
checkDevicePointer(&ptr_info, idx, queue);
152-
Py_DECREF(ret); // Thanks ChatGPT!
153-
return ptr_info;
154-
}}
155-
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
156-
ptr_info.valid = false;
157-
return ptr_info;
158-
}}
159-
// start sycl
160-
template <class T>
161-
static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
162-
cgh.set_arg(index, *static_cast<const T *>(value));
163-
}}
164-
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{
165-
166-
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
167-
RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
168-
void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
169-
uint32_t num_params = sizeof(params)/sizeof(params[0]);
170-
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
171-
size_t global_range_x = gridX*threads_per_warp*num_warps;
172-
size_t global_range_y = gridY;
173-
size_t global_range_z = gridZ;
174-
size_t local_range_x = num_warps*threads_per_warp;
175-
size_t local_range_y = 1;
176-
size_t local_range_z = 1;
177-
sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
178-
sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
179-
sycl::nd_range<3> parallel_work_size(global_range, local_range);
180-
if (shared_memory) {{
181-
expected_num_params -= 1;
182-
}}
183-
assert(num_params == expected_num_params && "number of kernel param not matched");
184-
// Submit the imported kernel.
185-
auto cgf = [&](sycl::handler &cgh) {{
186-
{" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))} if (shared_memory) {{
187-
using share_mem_t = sycl::local_accessor<int8_t, 1>;
188-
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
189-
cgh.set_arg(num_params, local_buffer);
190-
cgh.parallel_for(parallel_work_size, kernel_ptr);
191-
}} else {{
192-
cgh.parallel_for(parallel_work_size, kernel_ptr);
193-
}}
194-
}};
195-
auto event = stream.submit(cgf);
196-
}}
197-
// end sycl
198-
static PyObject* launch(PyObject* self, PyObject* args) {{
199-
200-
int gridX, gridY, gridZ;
201-
PyObject *launch_enter_hook = NULL;
202-
PyObject *launch_exit_hook = NULL;
203-
PyObject *kernel_metadata = NULL;
204-
PyObject *launch_metadata = NULL;
205-
PyObject *py_obj_stream;
206-
PyObject *py_kernel;
207-
208-
{" ".join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
209-
if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
210-
&kernel_metadata, &launch_metadata,
211-
&launch_enter_hook, &launch_exit_hook {args_list})) {{
212-
return NULL;
213-
}}
214-
215-
// extract kernel metadata
216-
int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps"));
217-
int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas"));
218-
int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared"));
219-
int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp"));
220-
221-
// extract cluster dims
222-
PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims");
223-
if (!PyTuple_Check(kernel_metadata)) {{
224-
PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple");
225-
return NULL;
226-
}}
227-
int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0));
228-
int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1));
229-
int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2));
230-
// extract launch metadata
231-
if (launch_enter_hook != Py_None){{
232-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
233-
PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
234-
Py_DECREF(args);
235-
if (!ret)
236-
return NULL;
237-
}}
238-
239-
void * pStream = PyLong_AsVoidPtr(py_obj_stream);
240-
//error check
241-
if(pStream == nullptr || py_kernel == nullptr) return NULL;
242-
243-
sycl::queue stream = *(static_cast<sycl::queue*>(pStream));
244-
sycl::kernel* kernel_ptr = reinterpret_cast<sycl::kernel*>(PyCapsule_GetPointer(py_kernel, "kernel"));
245-
if(kernel_ptr == nullptr) return NULL;
246-
sycl::kernel kernel = *kernel_ptr;
247-
248-
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
249-
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
250-
251-
if(launch_exit_hook != Py_None){{
252-
PyObject* args = Py_BuildValue("(O)", launch_metadata);
253-
PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
254-
Py_DECREF(args);
255-
if (!ret)
256-
return NULL;
257-
}}
258-
if (PyErr_Occurred()) {{
259-
return NULL;
260-
}}
261-
262-
// return None
263-
Py_INCREF(Py_None);
264-
return Py_None;
265-
}}
266-
267-
static PyMethodDef ModuleMethods[] = {{
268-
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
269-
{{NULL, NULL, 0, NULL}} // sentinel
270-
}};
271-
272-
static struct PyModuleDef ModuleDef = {{
273-
PyModuleDef_HEAD_INIT,
274-
\"__triton_launcher\",
275-
NULL, //documentation
276-
-1, //size
277-
ModuleMethods
278-
}};
279-
280-
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
281-
PyObject *m = PyModule_Create(&ModuleDef);
282-
if(m == NULL) {{
283-
return NULL;
284-
}}
285-
PyModule_AddFunctions(m, ModuleMethods);
286-
return m;
287-
}}
288-
"""
289-
return src
290-
291-
292-
class XPULauncher:
293-
294-
def __init__(self, src, metadata): # pylint: disable=unused-argument
295-
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
296-
constants = src.constants if hasattr(src, "constants") else {}
297-
self.constants = dict(constants.items())
298-
self.signature = dict(src.signature.items())
299-
src = make_launcher(self.constants, self.signature, ids)
300-
mod = compile_module_from_src(src, "__triton_launcher")
301-
self.launch = mod.launch
302-
303-
def __call__(self, *args, **kwargs):
304-
# Serialize KernelArguments for SPIR-V Runner
305-
serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
306-
if serialize_kernel_args:
307-
serialize_args(args, self.constants, self.signature)
308-
self.launch(*args, **kwargs)
5+
if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
6+
os.environ["INJECT_PYTORCH"] = "True"

third_party/intel/backend/driver.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ class CompilationHelper:
7070
_include_dir: list[str]
7171
libraries: list[str]
7272

73-
# for benchmarks
74-
_build_with_pytorch_dep: bool = False
75-
7673
def __init__(self):
7774
self._library_dir = None
7875
self._include_dir = None
@@ -81,11 +78,9 @@ def __init__(self):
8178
if os.name != "nt":
8279
self.libraries += ["sycl"]
8380

81+
@property
8482
def inject_pytorch_dep(self):
85-
# must be called before any cached properties (if pytorch is needed)
86-
if self._build_with_pytorch_dep is False:
87-
self._build_with_pytorch_dep = True
88-
self.libraries += ['torch']
83+
return os.environ.get("INJECT_PYTORCH", "False") == "True"
8984

9085
@cached_property
9186
def _compute_compilation_options_lazy(self):
@@ -103,7 +98,7 @@ def _compute_compilation_options_lazy(self):
10398
include_dir += [os.path.join(dirname, "include")]
10499
library_dir += [os.path.join(dirname, "lib")]
105100

106-
if self._build_with_pytorch_dep:
101+
if self.inject_pytorch_dep:
107102
import torch
108103

109104
torch_path = torch.utils.cmake_prefix_path
@@ -112,6 +107,7 @@ def _compute_compilation_options_lazy(self):
112107
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
113108
]
114109
library_dir += [os.path.join(torch_path, "../../lib")]
110+
self.libraries += ['torch']
115111

116112
self._library_dir = library_dir
117113
self._include_dir = include_dir
@@ -276,6 +272,7 @@ def format_of(ty):
276272
#include <iomanip>
277273
#include <level_zero/ze_api.h>
278274
#include <sycl/sycl.hpp>
275+
{ "#include <ATen/record_function.h>" if COMPILATION_HELPER.inject_pytorch_dep else "" }
279276
280277
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
281278
#include <Python.h>
@@ -370,6 +367,8 @@ def format_of(ty):
370367
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
371368
372369
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
370+
{ 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER.inject_pytorch_dep else "" }
371+
373372
void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
374373
uint32_t num_params = sizeof(params)/sizeof(params[0]);
375374
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();

0 commit comments

Comments (0)