Skip to content

Commit b3419c1

Browse files
committed
[#5153] This is a workaround for the feature requirement in #5153. The IGC build flags are updated when large GRF mode is used while loading a SPIR-V kernel whose register spill size is > 1000.
Signed-off-by: etaf <[email protected]> Co-authored-by: Lu,Chengjun <[email protected]>
1 parent 3de5f93 commit b3419c1

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

python/test/unit/intel/test_regressions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_regression_4441(device, tmp_path: pathlib.Path):
5050
# L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation
5151
# Error during Intel loadBinary: Triton Error [ZE]: 0x70000004
5252
# RuntimeError: Triton Error [ZE]: 0x70000004
53-
module, function, n_regs, n_spills, n_max_threads = driver.active.utils.load_binary(
53+
module, function, n_regs, n_spills, n_max_threads, _ = driver.active.utils.load_binary(
5454
kernel.name, kernel.kernel, kernel.metadata.shared, kernel.metadata.build_flags,
5555
not kernel.metadata.generate_native_code, device)
5656

@@ -1911,6 +1911,6 @@ def test_regression_5374(device, tmp_path: pathlib.Path):
19111911
# L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation
19121912
# Error during Intel loadBinary: Triton Error [ZE]: 0x70000004
19131913
# RuntimeError: Triton Error [ZE]: 0x70000004
1914-
module, function, n_regs, n_spills, n_max_threads = driver.active.utils.load_binary(
1914+
module, function, n_regs, n_spills, n_max_threads, _ = driver.active.utils.load_binary(
19151915
kernel.name, kernel.kernel, kernel.metadata.shared, kernel.metadata.build_flags,
19161916
not kernel.metadata.generate_native_code, device)

python/triton/compiler/compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,14 @@ def raise_(err):
465465
if knobs.runtime.kernel_load_start_hook is not None:
466466
knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
467467
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
468-
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
468+
# FIXME: remove the workaround for updating the build flags in loading binary
469+
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads, new_build_flags = driver.active.utils.load_binary(
469470
self.name, self.kernel, self.metadata.shared, self.metadata.build_flags,
470471
not self.metadata.generate_native_code, device)
472+
473+
if new_build_flags != self.metadata.build_flags:
474+
self.metadata = self.metadata._replace(build_flags=new_build_flags)
475+
471476
if hasattr(self.metadata, "threads_per_warp"):
472477
warp_size = self.metadata.threads_per_warp
473478
else:

third_party/intel/backend/driver.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
310310
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
311311
"kernel_bundle", freeKernelBundle);
312312

313-
return Py_BuildValue("(OOiii)", kernel_bundle_py, kernel_py, n_regs,
314-
n_spills, n_max_threads);
313+
return Py_BuildValue("(OOiiis)", kernel_bundle_py, kernel_py, n_regs,
314+
n_spills, n_max_threads, build_flags().data());
315315

316316
} catch (const std::exception &e) {
317317
PyGILState_STATE gil_state;

0 commit comments

Comments
 (0)