Skip to content

Commit 303c0ab

Browse files
authored
Implement __triton_launcher as pure DLL (#3251)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b018ed6 commit 303c0ab

File tree

1 file changed

+40
-27
lines changed

1 file changed

+40
-27
lines changed

third_party/intel/backend/driver.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,35 @@ def __del__(self):
168168
ctypes.windll.kernel32.FreeLibrary(handle)
169169

170170

171+
class TritonLauncher:
    """Thin ctypes wrapper around a compiled ``__triton_launcher`` shared library.

    The library is loaded with ``ctypes.PyDLL`` and must export a ``launch``
    symbol that takes one Python object and returns a Python object. That
    symbol is exposed as the ``launch`` attribute of this wrapper. The library
    is unloaded again when the wrapper is garbage collected.
    """

    def __init__(self, cache_path: str):
        """Load the shared library at ``cache_path`` and declare ``launch``'s signature.

        Raises whatever ``ctypes.PyDLL`` raises when the library cannot be loaded.
        """
        self.shared_library = ctypes.PyDLL(cache_path)
        # `launch` receives a single Python object and hands a Python object back.
        self.shared_library.launch.restype = ctypes.py_object
        self.shared_library.launch.argtypes = (ctypes.py_object, )

    def __getattribute__(self, name):
        # Forward `launch` straight to the function exported by the loaded
        # library; every other attribute is resolved the normal way.
        if name == "launch":
            shared_library = super().__getattribute__("shared_library")
            return getattr(shared_library, name)

        return super().__getattribute__(name)

    if os.name != 'nt':

        def __del__(self):
            # `__del__` also runs when `__init__` raised before the library was
            # stored on the instance; there is nothing to unload in that case.
            shared_library = self.__dict__.get("shared_library")
            if shared_library is None:
                return
            handle = shared_library._handle
            shared_library.dlclose.argtypes = (ctypes.c_void_p, )
            shared_library.dlclose(handle)
    else:

        def __del__(self):
            # `__del__` also runs when `__init__` raised before the library was
            # stored on the instance; there is nothing to unload in that case.
            shared_library = self.__dict__.get("shared_library")
            if shared_library is None:
                return
            handle = shared_library._handle
            ctypes.windll.kernel32.FreeLibrary.argtypes = (ctypes.c_uint64, )
            ctypes.windll.kernel32.FreeLibrary(handle)
199+
171200
def compile_module_from_src(src, name):
172201
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
173202
cache = get_cache_manager(key)
@@ -192,6 +221,8 @@ def compile_module_from_src(src, name):
192221

193222
if name == 'arch_utils':
194223
return ArchParser(cache_path)
224+
elif name == '__triton_launcher':
225+
return TritonLauncher(cache_path)
195226

196227
import importlib.util
197228
spec = importlib.util.spec_from_file_location(name, cache_path)
@@ -339,6 +370,12 @@ def format_of(ty):
339370
#include <sycl/sycl.hpp>
340371
{ "#include <ATen/record_function.h>" if COMPILATION_HELPER.inject_pytorch_dep else "" }
341372
373+
#if defined(_WIN32)
374+
#define EXPORT_FUNC __declspec(dllexport)
375+
#else
376+
#define EXPORT_FUNC __attribute__((visibility("default")))
377+
#endif
378+
342379
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
343380
#include <Python.h>
344381
#include <stdio.h>
@@ -466,8 +503,7 @@ def format_of(ty):
466503
}}
467504
// end sycl
468505
469-
static PyObject* launch(PyObject* self, PyObject* args) {{
470-
506+
extern "C" EXPORT_FUNC PyObject* launch(PyObject* args) {{
471507
int gridX, gridY, gridZ;
472508
PyObject *launch_enter_hook = NULL;
473509
PyObject *launch_exit_hook = NULL;
@@ -541,28 +577,6 @@ def format_of(ty):
541577
542578
Py_RETURN_NONE;
543579
}}
544-
545-
static PyMethodDef ModuleMethods[] = {{
546-
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
547-
{{NULL, NULL, 0, NULL}} // sentinel
548-
}};
549-
550-
static struct PyModuleDef ModuleDef = {{
551-
PyModuleDef_HEAD_INIT,
552-
\"__triton_launcher\",
553-
NULL, //documentation
554-
-1, //size
555-
ModuleMethods
556-
}};
557-
558-
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
559-
PyObject *m = PyModule_Create(&ModuleDef);
560-
if(m == NULL) {{
561-
return NULL;
562-
}}
563-
PyModule_AddFunctions(m, ModuleMethods);
564-
return m;
565-
}}
566580
"""
567581
return src
568582

@@ -635,15 +649,14 @@ def __init__(self, src, metadata):
635649
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
636650
self.signature = {idx: value for idx, value in src.signature.items()}
637651
src = make_launcher(self.constants, self.signature)
638-
mod = compile_module_from_src(src, "__triton_launcher")
639-
self.launch = mod.launch
652+
self.mod = compile_module_from_src(src, "__triton_launcher")
640653

641654
def __call__(self, *args, **kwargs):
642655
# Serialize KernelArguments for SPIR-V Runner
643656
serialize_kernel_args = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS', None)
644657
if serialize_kernel_args:
645658
serialize_args(args, self.constants, self.signature)
646-
self.launch(*args, **kwargs)
659+
self.mod.launch(args)
647660

648661

649662
class XPUDriver(DriverBase):

0 commit comments

Comments
 (0)