Skip to content

Commit 9bda03d

Browse files
authored
Emit informational msgs regarding recompilation with different GRF mode under TRITON_DEBUG (#2385)
Fixes #2251 --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 2a4b054 commit 9bda03d

File tree

2 files changed

+63906
-4
lines changed

2 files changed

+63906
-4
lines changed

third_party/intel/backend/driver.c

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,13 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
192192
// If the register mode isn't set, and the number of spills is greater
193193
// than the threshold, recompile the kernel using large GRF mode.
194194
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
195-
std::cout << "(I): Detected " << n_spills
196-
<< " spills, recompiling the kernel using large GRF mode"
197-
<< std::endl;
195+
const std::optional<bool> debugEnabled =
196+
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
197+
if (debugEnabled)
198+
std::cout << "(I): Detected " << n_spills
199+
<< " spills, recompiling kernel \"" << kernel_name
200+
<< "\" using large GRF mode" << std::endl;
201+
198202
const std::string new_build_flags =
199203
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
200204
l0_module = checkSyclErrors(
@@ -204,7 +208,10 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
204208
l0_kernel = checkL0Errors(l0_module);
205209
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
206210
n_spills = props.spillMemSize;
207-
std::cout << "(I): Kernel has now " << n_spills << " spills" << std::endl;
211+
212+
if (debugEnabled)
213+
std::cout << "(I): Kernel has now " << n_spills << " spills"
214+
<< std::endl;
208215
}
209216
}
210217

0 commit comments

Comments
 (0)