Skip to content

Commit de5302d

Browse files
authored
Return n_regs for binaries compiled explicitly with a register size mode option (#2391)
Intel Data Center Max GPUs will dynamically scale the number of hardware threads available per XVE depending on the specified GRF mode. With small GRF mode (default), a single hardware thread can access 128 GRF registers and each XVE engine has 8 hardware threads. In large GRF mode, a single hardware thread can access 256 GRF registers but each XVE engine only has 4 hardware threads. There is also an auto mode. ([see the docs for more info](https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2024-2/small-register-mode-vs-large-register-mode.html)) This PR adds support for populating the `n_regs` parameter returned from loading a binary with information about the selected GRF mode. Because L0 does not return the number of registers and our register size info does not work like NVIDIA, the semantics are a bit different from upstream Triton. We _only_ return a value if the user has specified a small or large GRF mode build flag. The purpose of returning `n_regs` in upstream Triton/Torch Inductor is b/c NVIDIA can dynamically adjust occupancy of a SM based on the register pressure per warp. This means high register pressure can result in fewer running warps which reduces parallelism and performance. Theoretically, you can have many different "GRF modes" on a NVIDIA GPU as you adjust SM occupancy. For Intel GPUs, the choice is binary - large or small - and the performance penalty for register spills in small always outweighs any parallelism gains (at least, in our testing so far). It is not clear that returning 128 is actionable as further reductions in register usage will not effect occupancy - only the large GRF mode effects occupancy. So, I focused on making sure large GRF mode was properly handled and other cases were handled as we were able, with any ambiguous case returning 0 (which will cause torch inductor to skip any register-specific optimization). The approach to returning GRF size is dependent on parsing the build flags passed to the binary loader. Because the build flags are modified in the `make_spv` step during generation of native code instead of a SPIRV file, this approach should work for the native code POC recently merged in #2148. Note that I had to introduce exceptions into our `driver.c` code to make the error handling acceptable. This cleaned up a lot of the code, and I believe should be acceptable both because we already depend on c++ in `driver.c` (just not in the external signatures) and because exceptions are used in other parts of the Triton codebase. I marked this as a draft PR because I would like to do a bit more testing, but it is ready for review. Close #1641
1 parent 26baece commit de5302d

File tree

2 files changed

+148
-112
lines changed

2 files changed

+148
-112
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def make_spv(src, metadata, options):
338338
if os.path.exists(flog.name):
339339
with open(flog.name) as log_file:
340340
log = log_file.read().strip()
341-
if 'spilled' in log:
341+
if 'spilled' in log and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
342342
"""
343343
The exact message is something like:
344344
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217

third_party/intel/backend/driver.c

Lines changed: 147 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,13 @@ static std::vector<ze_device_handle_t> g_devices;
2626
static std::vector<std::pair<sycl::device, ze_device_handle_t>>
2727
g_sycl_l0_device_list;
2828

29-
static inline void gpuAssert(ze_result_t code) {
30-
if (code != ZE_RESULT_SUCCESS) {
31-
auto str = parseZeResultCode(code);
32-
char err[1024] = {0};
33-
strncat(err, str.c_str(), std::min(str.size(), size_t(1024)));
34-
PyGILState_STATE gil_state;
35-
gil_state = PyGILState_Ensure();
36-
PyErr_SetString(PyExc_RuntimeError, err);
37-
PyGILState_Release(gil_state);
38-
}
39-
}
40-
4129
template <typename T>
4230
static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
43-
gpuAssert(std::get<1>(tuple));
44-
if (PyErr_Occurred())
45-
return nullptr;
46-
else
47-
return std::get<0>(tuple);
31+
const auto code = std::get<1>(tuple);
32+
if (code != ZE_RESULT_SUCCESS) {
33+
throw std::runtime_error(parseZeResultCode(code));
34+
}
35+
return std::get<0>(tuple);
4836
}
4937

5038
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
@@ -113,14 +101,79 @@ void freeKernelBundle(PyObject *p) {
113101
PyCapsule_GetPointer(p, "kernel_bundle"));
114102
}
115103

104+
using Spills = int32_t;
105+
106+
template <typename L0_DEVICE, typename L0_CONTEXT>
107+
std::tuple<ze_module_handle_t, ze_kernel_handle_t, Spills>
108+
compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
109+
const std::string &kernel_name, L0_DEVICE l0_device,
110+
L0_CONTEXT l0_context, const std::string &build_flags,
111+
const bool is_spv) {
112+
auto l0_module =
113+
checkSyclErrors(create_module(l0_context, l0_device, binary_ptr,
114+
binary_size, build_flags.data(), is_spv));
115+
116+
// Retrieve the kernel properties (e.g. register spills).
117+
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));
118+
119+
ze_kernel_properties_t props;
120+
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
121+
props.pNext = nullptr;
122+
checkSyclErrors(
123+
std::make_tuple(NULL, zeKernelGetProperties(l0_kernel, &props)));
124+
125+
const int32_t n_spills = props.spillMemSize;
126+
127+
return std::make_tuple(l0_module, l0_kernel, n_spills);
128+
}
129+
130+
struct BuildFlags {
131+
std::string build_flags_str;
132+
133+
const char *LARGE_GRF_FLAG{"-cl-intel-256-GRF-per-thread"};
134+
const char *SMALL_GRF_FLAG{"-cl-intel-128-GRF-per-thread"};
135+
const char *AUTO_GRF_FLAG{"-cl-intel-enable-auto-large-GRF-mode"};
136+
137+
BuildFlags(const char *build_flags) : build_flags_str(build_flags) {}
138+
139+
const std::string &operator()() const { return build_flags_str; }
140+
141+
int32_t n_regs() const {
142+
if (build_flags_str.find(LARGE_GRF_FLAG) != std::string::npos) {
143+
return 256;
144+
}
145+
if (build_flags_str.find(SMALL_GRF_FLAG) != std::string::npos) {
146+
return 128;
147+
}
148+
// TODO: arguably we could return 128 if we find no flag instead of 0. For
149+
// now, stick with the conservative choice and alert the user only if a
150+
// specific GRF mode is specified.
151+
return 0;
152+
}
153+
154+
const bool hasGRFSizeFlag() const {
155+
if (build_flags_str.find(LARGE_GRF_FLAG) != std::string::npos ||
156+
build_flags_str.find(SMALL_GRF_FLAG) != std::string::npos ||
157+
build_flags_str.find(AUTO_GRF_FLAG) != std::string::npos) {
158+
return true;
159+
}
160+
161+
return false;
162+
}
163+
164+
void addLargeGRFSizeFlag() {
165+
build_flags_str = build_flags_str.append(" ").append(LARGE_GRF_FLAG);
166+
}
167+
};
168+
116169
static PyObject *loadBinary(PyObject *self, PyObject *args) {
117-
const char *name, *build_flags;
170+
const char *name, *build_flags_ptr;
118171
int shared;
119172
PyObject *py_bytes;
120173
int devId;
121174

122-
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared, &build_flags,
123-
&devId)) {
175+
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared,
176+
&build_flags_ptr, &devId)) {
124177
std::cerr << "loadBinary arg parse failed" << std::endl;
125178
return NULL;
126179
}
@@ -130,106 +183,89 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
130183
return NULL;
131184
}
132185

133-
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
134-
const sycl::device sycl_device = sycl_l0_device_pair.first;
186+
BuildFlags build_flags(build_flags_ptr);
135187

136-
std::string kernel_name = name;
137-
const size_t binary_size = PyBytes_Size(py_bytes);
188+
try {
138189

139-
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
140-
const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context();
141-
const auto l0_device =
142-
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
143-
const auto l0_context =
144-
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
190+
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
191+
const sycl::device sycl_device = sycl_l0_device_pair.first;
145192

146-
const auto use_native_code =
147-
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
148-
const bool is_spv = use_native_code ? !(*use_native_code) : true;
193+
const std::string kernel_name = name;
194+
const size_t binary_size = PyBytes_Size(py_bytes);
149195

150-
auto l0_module = checkSyclErrors(create_module(
151-
l0_context, l0_device, binary_ptr, binary_size, build_flags, is_spv));
196+
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
197+
const auto ctx =
198+
sycl_device.get_platform().ext_oneapi_get_default_context();
199+
const auto l0_device =
200+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
201+
const auto l0_context =
202+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
152203

153-
auto checkL0Errors = [&](auto l0_module) -> ze_kernel_handle_t {
154-
if (PyErr_Occurred()) {
155-
// check for errors from module creation
156-
return NULL;
157-
}
158-
ze_kernel_handle_t l0_kernel =
159-
checkSyclErrors(create_function(l0_module, kernel_name));
160-
if (PyErr_Occurred()) {
161-
// check for errors from kernel creation
162-
return NULL;
163-
}
164-
return l0_kernel;
165-
};
204+
const auto use_native_code =
205+
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
206+
const bool is_spv = use_native_code ? !(*use_native_code) : true;
166207

167-
// Retrieve the kernel properties (e.g. register spills).
168-
ze_kernel_handle_t l0_kernel = checkL0Errors(l0_module);
169-
ze_kernel_properties_t props;
170-
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
171-
props.pNext = nullptr;
172-
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
173-
174-
int32_t n_spills = props.spillMemSize;
175-
const int32_t n_regs = 0;
176-
177-
if (is_spv) {
178-
constexpr int32_t max_reg_spill = 1000;
179-
std::string build_flags_str(build_flags);
180-
bool is_GRF_mode_specified = false;
181-
182-
// Check whether the GRF mode is specified by the build flags.
183-
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
184-
std::string::npos ||
185-
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
186-
std::string::npos ||
187-
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
188-
std::string::npos) {
189-
is_GRF_mode_specified = true;
190-
}
208+
auto [l0_module, l0_kernel, n_spills] =
209+
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
210+
l0_context, build_flags(), is_spv);
211+
212+
if (is_spv) {
213+
constexpr int32_t max_reg_spill = 1000;
214+
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();
215+
216+
// If the register mode isn't set, and the number of spills is greater
217+
// than the threshold, recompile the kernel using large GRF mode.
218+
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
219+
const std::optional<bool> debugEnabled =
220+
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
221+
if (debugEnabled)
222+
std::cout << "(I): Detected " << n_spills
223+
<< " spills, recompiling the kernel using large GRF mode"
224+
<< std::endl;
225+
226+
build_flags.addLargeGRFSizeFlag();
191227

192-
// If the register mode isn't set, and the number of spills is greater
193-
// than the threshold, recompile the kernel using large GRF mode.
194-
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
195-
const std::optional<bool> debugEnabled =
196-
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
197-
if (debugEnabled)
198-
std::cout << "(I): Detected " << n_spills
199-
<< " spills, recompiling kernel \"" << kernel_name
200-
<< "\" using large GRF mode" << std::endl;
201-
202-
const std::string new_build_flags =
203-
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
204-
l0_module = checkSyclErrors(
205-
create_module(l0_context, l0_device, binary_ptr, binary_size,
206-
new_build_flags.c_str(), is_spv));
207-
208-
l0_kernel = checkL0Errors(l0_module);
209-
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
210-
n_spills = props.spillMemSize;
211-
212-
if (debugEnabled)
213-
std::cout << "(I): Kernel has now " << n_spills << " spills"
214-
<< std::endl;
228+
auto [l0_module, l0_kernel, n_spills] = compileLevelZeroObjects(
229+
binary_ptr, binary_size, kernel_name, l0_device, l0_context,
230+
build_flags(), is_spv);
231+
232+
if (debugEnabled)
233+
std::cout << "(I): Kernel has now " << n_spills << " spills"
234+
<< std::endl;
235+
}
215236
}
216-
}
217237

218-
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
219-
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
220-
sycl::bundle_state::executable>(
221-
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
222-
ctx));
223-
sycl::kernel *fun =
224-
new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
225-
{*mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
226-
ctx));
227-
auto kernel_py =
228-
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
229-
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
230-
"kernel_bundle", freeKernelBundle);
231-
232-
return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs, n_spills);
238+
auto n_regs = build_flags.n_regs();
239+
240+
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
241+
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
242+
sycl::bundle_state::executable>(
243+
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
244+
ctx));
245+
sycl::kernel *fun = new sycl::kernel(
246+
sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
247+
{*mod, l0_kernel,
248+
sycl::ext::oneapi::level_zero::ownership::transfer},
249+
ctx));
250+
auto kernel_py =
251+
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
252+
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
253+
"kernel_bundle", freeKernelBundle);
254+
255+
return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs,
256+
n_spills);
257+
258+
} catch (const std::exception &e) {
259+
char err[1024] = {0};
260+
std::string_view error_str(e.what());
261+
strncat(err, error_str.data(), std::min(error_str.size(), size_t(1024)));
262+
PyGILState_STATE gil_state;
263+
gil_state = PyGILState_Ensure();
264+
PyErr_SetString(PyExc_RuntimeError, err);
265+
std::cerr << "Error during Intel loadBinary: " << err << std::endl;
266+
PyGILState_Release(gil_state);
267+
return NULL;
268+
}
233269
}
234270

235271
static PyObject *initContext(PyObject *self, PyObject *args) {

0 commit comments

Comments
 (0)