|
1 | 1 | import os |
2 | 2 |
|
3 | | -from triton._utils import parse_list_string |
4 | | -from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER, ty_to_cpp, serialize_args |
| 3 | +from .benchmark_testing import BENCHMARKING_METHOD |
5 | 4 |
|
6 | | -# ------------------------ |
7 | | -# Utils |
8 | | -# ------------------------ |
9 | | - |
10 | | -COMPILATION_HELPER.inject_pytorch_dep() |
11 | | - |
12 | | -# ------------------------ |
13 | | -# Launcher |
14 | | -# ------------------------ |
15 | | - |
16 | | - |
17 | | -def make_launcher(constants, signature, ids): # pylint: disable=unused-argument |
18 | | - |
19 | | - def _extracted_type(ty): |
20 | | - if ty[0] == "*" or ty == "none": |
21 | | - return "PyObject*" |
22 | | - if ty[0] == "[": |
23 | | - if ty == "[]": |
24 | | - return "[]" |
25 | | - tys = parse_list_string(ty) |
26 | | - val = ",".join(map(_extracted_type, tys)) |
27 | | - return f"[{val}]" |
28 | | - return ty_to_cpp(ty) |
29 | | - |
30 | | - def format_of(ty): |
31 | | - if ty == "void*": |
32 | | - return "O" |
33 | | - if ty[0] == "[": |
34 | | - if ty == "[]": |
35 | | - return "()" |
36 | | - tys = parse_list_string(ty) |
37 | | - val = "".join(map(format_of, tys)) |
38 | | - return f"({val})" |
39 | | - return { |
40 | | - "PyObject*": "O", |
41 | | - "float": "f", |
42 | | - "double": "d", |
43 | | - "long": "l", |
44 | | - "int8_t": "b", |
45 | | - "int16_t": "h", |
46 | | - "int32_t": "i", |
47 | | - "int64_t": "L", |
48 | | - "uint8_t": "B", |
49 | | - "uint16_t": "H", |
50 | | - "uint32_t": "I", |
51 | | - "uint64_t": "K", |
52 | | - }[ty] |
53 | | - |
54 | | - signature = {k: v for k, v in signature.items() if v != "constexpr"} |
55 | | - args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()]) |
56 | | - fmt = "iiiOOOOOO" + args_format |
57 | | - signature = ",".join(signature.values()).replace("[", "").replace("]", "") |
58 | | - signature = list(filter(bool, signature.split(","))) |
59 | | - signature = dict(enumerate(signature)) |
60 | | - args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else "" |
61 | | - |
62 | | - # Record the end of regular arguments; |
63 | | - # subsequent arguments are architecture-specific descriptors. |
64 | | - arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) |
65 | | - |
66 | | - # generate glue code |
67 | | - src = f""" |
68 | | - #include <cstddef> |
69 | | - #include <string> |
70 | | - #include <iostream> |
71 | | - #include <iomanip> |
72 | | - #include <level_zero/ze_api.h> |
73 | | - #include <sycl/sycl.hpp> |
74 | | - #include <ATen/record_function.h> |
75 | | -
|
76 | | - #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION |
77 | | - #include <Python.h> |
78 | | - #include <stdio.h> |
79 | | - #include <numpy/arrayobject.h> |
80 | | -
|
81 | | - static inline void gpuAssert(ze_result_t code, const char *file, int line) |
82 | | - {{ |
83 | | - if (code != ZE_RESULT_SUCCESS) |
84 | | - {{ |
85 | | - const char* prefix = "Triton Error [ZE]: "; |
86 | | - std::string str = std::to_string(code); |
87 | | - char err[1024] = {{0}}; |
88 | | - strcat(err, prefix); |
89 | | - strcat(err, str.c_str()); |
90 | | - PyErr_SetString(PyExc_RuntimeError, err); |
91 | | - }} |
92 | | - }} |
93 | | -
|
94 | | - #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} |
95 | | -
|
96 | | - typedef struct _DevicePtrInfo {{ |
97 | | - void* dev_ptr; |
98 | | - bool valid; |
99 | | - }} DevicePtrInfo; |
100 | | -
|
101 | | - static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{ |
102 | | - if (!ptr_info->dev_ptr || !ptr_info->valid) {{ |
103 | | - return; |
104 | | - }} |
105 | | - auto context = queue.get_context(); |
106 | | - auto handle = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(context); |
107 | | - ze_memory_allocation_properties_t prop; |
108 | | - prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES; |
109 | | - prop.pNext = nullptr; |
110 | | - ze_device_handle_t device; |
111 | | - auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device); |
112 | | - if (res != ZE_RESULT_SUCCESS) {{ |
113 | | - PyErr_Format(PyExc_ValueError, |
114 | | - "Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res); |
115 | | - ptr_info->valid = false; |
116 | | - }} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{ |
117 | | - PyErr_Format(PyExc_ValueError, |
118 | | - "Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx); |
119 | | - ptr_info->valid = false; |
120 | | - }} |
121 | | - }} |
122 | | -
|
123 | | - static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{ |
124 | | - DevicePtrInfo ptr_info; |
125 | | - ptr_info.dev_ptr = 0; |
126 | | - ptr_info.valid = true; |
127 | | - if (PyLong_Check(obj)) {{ |
128 | | - ptr_info.dev_ptr = PyLong_AsVoidPtr(obj); |
129 | | - checkDevicePointer(&ptr_info, idx, queue); |
130 | | - return ptr_info; |
131 | | - }} |
132 | | - if (obj == Py_None) {{ |
133 | | - // valid nullptr |
134 | | - return ptr_info; |
135 | | - }} |
136 | | - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); |
137 | | - if(ptr){{ |
138 | | - PyObject *empty_tuple = PyTuple_New(0); |
139 | | - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); |
140 | | - Py_DECREF(empty_tuple); |
141 | | - Py_DECREF(ptr); |
142 | | - if (!PyLong_Check(ret)) {{ |
143 | | - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); |
144 | | - ptr_info.valid = false; |
145 | | - return ptr_info; |
146 | | - }} |
147 | | - ptr_info.dev_ptr = PyLong_AsVoidPtr(ret); |
148 | | - if(!ptr_info.dev_ptr) {{ |
149 | | - return ptr_info; |
150 | | - }} |
151 | | - checkDevicePointer(&ptr_info, idx, queue); |
152 | | - Py_DECREF(ret); // Thanks ChatGPT! |
153 | | - return ptr_info; |
154 | | - }} |
155 | | - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); |
156 | | - ptr_info.valid = false; |
157 | | - return ptr_info; |
158 | | - }} |
159 | | -// start sycl |
160 | | - template <class T> |
161 | | - static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{ |
162 | | - cgh.set_arg(index, *static_cast<const T *>(value)); |
163 | | - }} |
164 | | - static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{ |
165 | | -
|
166 | | - std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>(); |
167 | | - RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}}); |
168 | | - void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }}; |
169 | | - uint32_t num_params = sizeof(params)/sizeof(params[0]); |
170 | | - uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>(); |
171 | | - size_t global_range_x = gridX*threads_per_warp*num_warps; |
172 | | - size_t global_range_y = gridY; |
173 | | - size_t global_range_z = gridZ; |
174 | | - size_t local_range_x = num_warps*threads_per_warp; |
175 | | - size_t local_range_y = 1; |
176 | | - size_t local_range_z = 1; |
177 | | - sycl::range<3> global_range(global_range_z, global_range_y, global_range_x); |
178 | | - sycl::range<3> local_range(local_range_z, local_range_y, local_range_x); |
179 | | - sycl::nd_range<3> parallel_work_size(global_range, local_range); |
180 | | - if (shared_memory) {{ |
181 | | - expected_num_params -= 1; |
182 | | - }} |
183 | | - assert(num_params == expected_num_params && "number of kernel param not matched"); |
184 | | - // Submit the imported kernel. |
185 | | - auto cgf = [&](sycl::handler &cgh) {{ |
186 | | - {" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))} if (shared_memory) {{ |
187 | | - using share_mem_t = sycl::local_accessor<int8_t, 1>; |
188 | | - share_mem_t local_buffer = share_mem_t(shared_memory, cgh); |
189 | | - cgh.set_arg(num_params, local_buffer); |
190 | | - cgh.parallel_for(parallel_work_size, kernel_ptr); |
191 | | - }} else {{ |
192 | | - cgh.parallel_for(parallel_work_size, kernel_ptr); |
193 | | - }} |
194 | | - }}; |
195 | | - auto event = stream.submit(cgf); |
196 | | - }} |
197 | | -// end sycl |
198 | | - static PyObject* launch(PyObject* self, PyObject* args) {{ |
199 | | -
|
200 | | - int gridX, gridY, gridZ; |
201 | | - PyObject *launch_enter_hook = NULL; |
202 | | - PyObject *launch_exit_hook = NULL; |
203 | | - PyObject *kernel_metadata = NULL; |
204 | | - PyObject *launch_metadata = NULL; |
205 | | - PyObject *py_obj_stream; |
206 | | - PyObject *py_kernel; |
207 | | -
|
208 | | - {" ".join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} |
209 | | - if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel, |
210 | | - &kernel_metadata, &launch_metadata, |
211 | | - &launch_enter_hook, &launch_exit_hook {args_list})) {{ |
212 | | - return NULL; |
213 | | - }} |
214 | | -
|
215 | | - // extract kernel metadata |
216 | | - int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps")); |
217 | | - int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas")); |
218 | | - int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared")); |
219 | | - int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp")); |
220 | | -
|
221 | | - // extract cluster dims |
222 | | - PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims"); |
223 | | - if (!PyTuple_Check(kernel_metadata)) {{ |
224 | | - PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple"); |
225 | | - return NULL; |
226 | | - }} |
227 | | - int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0)); |
228 | | - int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1)); |
229 | | - int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2)); |
230 | | - // extract launch metadata |
231 | | - if (launch_enter_hook != Py_None){{ |
232 | | - PyObject* args = Py_BuildValue("(O)", launch_metadata); |
233 | | - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); |
234 | | - Py_DECREF(args); |
235 | | - if (!ret) |
236 | | - return NULL; |
237 | | - }} |
238 | | -
|
239 | | - void * pStream = PyLong_AsVoidPtr(py_obj_stream); |
240 | | - //error check |
241 | | - if(pStream == nullptr || py_kernel == nullptr) return NULL; |
242 | | -
|
243 | | - sycl::queue stream = *(static_cast<sycl::queue*>(pStream)); |
244 | | - sycl::kernel* kernel_ptr = reinterpret_cast<sycl::kernel*>(PyCapsule_GetPointer(py_kernel, "kernel")); |
245 | | - if(kernel_ptr == nullptr) return NULL; |
246 | | - sycl::kernel kernel = *kernel_ptr; |
247 | | -
|
248 | | - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; |
249 | | - sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""}); |
250 | | -
|
251 | | - if(launch_exit_hook != Py_None){{ |
252 | | - PyObject* args = Py_BuildValue("(O)", launch_metadata); |
253 | | - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); |
254 | | - Py_DECREF(args); |
255 | | - if (!ret) |
256 | | - return NULL; |
257 | | - }} |
258 | | - if (PyErr_Occurred()) {{ |
259 | | - return NULL; |
260 | | - }} |
261 | | -
|
262 | | - // return None |
263 | | - Py_INCREF(Py_None); |
264 | | - return Py_None; |
265 | | - }} |
266 | | -
|
267 | | - static PyMethodDef ModuleMethods[] = {{ |
268 | | - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, |
269 | | - {{NULL, NULL, 0, NULL}} // sentinel |
270 | | - }}; |
271 | | -
|
272 | | - static struct PyModuleDef ModuleDef = {{ |
273 | | - PyModuleDef_HEAD_INIT, |
274 | | - \"__triton_launcher\", |
275 | | - NULL, //documentation |
276 | | - -1, //size |
277 | | - ModuleMethods |
278 | | - }}; |
279 | | -
|
280 | | - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ |
281 | | - PyObject *m = PyModule_Create(&ModuleDef); |
282 | | - if(m == NULL) {{ |
283 | | - return NULL; |
284 | | - }} |
285 | | - PyModule_AddFunctions(m, ModuleMethods); |
286 | | - return m; |
287 | | - }} |
288 | | - """ |
289 | | - return src |
290 | | - |
291 | | - |
292 | | -class XPULauncher: |
293 | | - |
294 | | - def __init__(self, src, metadata): # pylint: disable=unused-argument |
295 | | - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} |
296 | | - constants = src.constants if hasattr(src, "constants") else {} |
297 | | - self.constants = dict(constants.items()) |
298 | | - self.signature = dict(src.signature.items()) |
299 | | - src = make_launcher(self.constants, self.signature, ids) |
300 | | - mod = compile_module_from_src(src, "__triton_launcher") |
301 | | - self.launch = mod.launch |
302 | | - |
303 | | - def __call__(self, *args, **kwargs): |
304 | | - # Serialize KernelArguments for SPIR-V Runner |
305 | | - serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None) |
306 | | - if serialize_kernel_args: |
307 | | - serialize_args(args, self.constants, self.signature) |
308 | | - self.launch(*args, **kwargs) |
| 5 | +if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER": |
| 6 | + os.environ["INJECT_PYTORCH"] = "True" |
0 commit comments