@@ -168,6 +168,35 @@ def __del__(self):
168168 ctypes .windll .kernel32 .FreeLibrary (handle )
169169
170170
171+ class TritonLauncher :
172+
173+ def __init__ (self , cache_path : str ):
174+ self .shared_library = ctypes .PyDLL (cache_path )
175+ # breakpoint()
176+ self .shared_library .launch .restype = ctypes .py_object
177+ self .shared_library .launch .argtypes = (ctypes .py_object , )
178+
179+ def __getattribute__ (self , name ):
180+ if name == "launch" :
181+ shared_library = super ().__getattribute__ ("shared_library" )
182+ return getattr (shared_library , name )
183+
184+ return super ().__getattribute__ (name )
185+
186+ if os .name != 'nt' :
187+
188+ def __del__ (self ):
189+ handle = self .shared_library ._handle
190+ self .shared_library .dlclose .argtypes = (ctypes .c_void_p , )
191+ self .shared_library .dlclose (handle )
192+ else :
193+
194+ def __del__ (self ):
195+ handle = self .shared_library ._handle
196+ ctypes .windll .kernel32 .FreeLibrary .argtypes = (ctypes .c_uint64 , )
197+ ctypes .windll .kernel32 .FreeLibrary (handle )
198+
199+
171200def compile_module_from_src (src , name ):
172201 key = hashlib .sha256 (src .encode ("utf-8" )).hexdigest ()
173202 cache = get_cache_manager (key )
@@ -192,6 +221,8 @@ def compile_module_from_src(src, name):
192221
193222 if name == 'arch_utils' :
194223 return ArchParser (cache_path )
224+ elif name == '__triton_launcher' :
225+ return TritonLauncher (cache_path )
195226
196227 import importlib .util
197228 spec = importlib .util .spec_from_file_location (name , cache_path )
@@ -339,6 +370,12 @@ def format_of(ty):
339370#include <sycl/sycl.hpp>
340371{ "#include <ATen/record_function.h>" if COMPILATION_HELPER .inject_pytorch_dep else "" }
341372
373+ #if defined(_WIN32)
374+ #define EXPORT_FUNC __declspec(dllexport)
375+ #else
376+ #define EXPORT_FUNC __attribute__((visibility("default")))
377+ #endif
378+
342379#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
343380#include <Python.h>
344381#include <stdio.h>
@@ -466,8 +503,7 @@ def format_of(ty):
466503}}
467504// end sycl
468505
469- static PyObject* launch(PyObject* self, PyObject* args) {{
470-
506+ extern "C" EXPORT_FUNC PyObject* launch(PyObject* args) {{
471507 int gridX, gridY, gridZ;
472508 PyObject *launch_enter_hook = NULL;
473509 PyObject *launch_exit_hook = NULL;
@@ -541,28 +577,6 @@ def format_of(ty):
541577
542578 Py_RETURN_NONE;
543579}}
544-
545- static PyMethodDef ModuleMethods[] = {{
546- {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
547- {{NULL, NULL, 0, NULL}} // sentinel
548- }};
549-
550- static struct PyModuleDef ModuleDef = {{
551- PyModuleDef_HEAD_INIT,
552- \" __triton_launcher\" ,
553- NULL, //documentation
554- -1, //size
555- ModuleMethods
556- }};
557-
558- PyMODINIT_FUNC PyInit___triton_launcher(void) {{
559- PyObject *m = PyModule_Create(&ModuleDef);
560- if(m == NULL) {{
561- return NULL;
562- }}
563- PyModule_AddFunctions(m, ModuleMethods);
564- return m;
565- }}
566580"""
567581 return src
568582
@@ -635,15 +649,14 @@ def __init__(self, src, metadata):
635649 self .constants = {arg_idx (idx ): value for idx , value in constants .items ()}
636650 self .signature = {idx : value for idx , value in src .signature .items ()}
637651 src = make_launcher (self .constants , self .signature )
638- mod = compile_module_from_src (src , "__triton_launcher" )
639- self .launch = mod .launch
652+ self .mod = compile_module_from_src (src , "__triton_launcher" )
640653
641654 def __call__ (self , * args , ** kwargs ):
642655 # Serialize KernelArguments for SPIR-V Runner
643656 serialize_kernel_args = os .getenv ('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS' , None )
644657 if serialize_kernel_args :
645658 serialize_args (args , self .constants , self .signature )
646- self .launch (* args , ** kwargs )
659+ self .mod . launch (args )
647660
648661
649662class XPUDriver (DriverBase ):
0 commit comments