diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index a1b2700033..d0eddc9951 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -468,6 +468,12 @@ def raise_(err): self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary( self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, not self.metadata.generate_native_code, device) + # PyTorch could use the updated build flags in load binary. + if hasattr(driver.active.utils, "get_last_selected_build_flags"): + new_build_flags = driver.active.utils.get_last_selected_build_flags() + if new_build_flags != self.metadata.build_flags: + self.metadata = self.metadata._replace(build_flags=new_build_flags) + if hasattr(self.metadata, "threads_per_warp"): warp_size = self.metadata.threads_per_warp else: diff --git a/third_party/intel/backend/driver.c b/third_party/intel/backend/driver.c index 758999f8dc..32de69a334 100644 --- a/third_party/intel/backend/driver.c +++ b/third_party/intel/backend/driver.c @@ -192,6 +192,12 @@ sycl::context get_default_context(const sycl::device &sycl_device) { #endif } +static BuildFlags last_build_flag(""); + +extern "C" EXPORT_FUNC PyObject *get_last_selected_build_flags() { + return Py_BuildValue("s", last_build_flag().data()); +} + extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) { const char *name, *build_flags_ptr; int shared; @@ -309,7 +315,7 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) { PyCapsule_New(reinterpret_cast(fun), "kernel", freeKernel); auto kernel_bundle_py = PyCapsule_New(reinterpret_cast(mod), "kernel_bundle", freeKernelBundle); - + last_build_flag = build_flags; return Py_BuildValue("(OOiii)", kernel_bundle_py, kernel_py, n_regs, n_spills, n_max_threads); diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 96b74de7e3..9b12afc423 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -199,9 +199,11 @@ def __init__(self, cache_path: str): self.shared_library.get_device_properties.argtypes = (ctypes.c_int, ) self.shared_library.has_opencl_extension.restype = ctypes.py_object self.shared_library.has_opencl_extension.argtypes = (ctypes.c_int, ctypes.c_char_p) + self.shared_library.get_last_selected_build_flags.restype = ctypes.py_object def __getattribute__(self, name): - if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension"): + if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension", + "get_last_selected_build_flags"): shared_library = super().__getattribute__("shared_library") return getattr(shared_library, name) @@ -318,6 +320,7 @@ def __init__(self): self.device_count = mod.init_devices(self.get_sycl_queue()) self.wait_on_sycl_queue = mod.wait_on_sycl_queue self.has_opencl_extension = mod.has_opencl_extension + self.get_last_selected_build_flags = mod.get_last_selected_build_flags def get_current_device(self): import torch