Skip to content

Commit 94454a8

Browse files
committed
Refactor loadBinary to support n_regs
1 parent fe45283 commit 94454a8

File tree

1 file changed

+114
-109
lines changed

1 file changed

+114
-109
lines changed

third_party/intel/backend/driver.c

Lines changed: 114 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,14 @@ static std::vector<ze_device_handle_t> g_devices;
2626
static std::vector<std::pair<sycl::device, ze_device_handle_t>>
2727
g_sycl_l0_device_list;
2828

29-
static inline void gpuAssert(ze_result_t code) {
30-
if (code != ZE_RESULT_SUCCESS) {
31-
auto str = parseZeResultCode(code);
32-
char err[1024] = {0};
33-
strncat(err, str.c_str(), std::min(str.size(), size_t(1024)));
34-
PyGILState_STATE gil_state;
35-
gil_state = PyGILState_Ensure();
36-
PyErr_SetString(PyExc_RuntimeError, err);
37-
PyGILState_Release(gil_state);
38-
}
39-
}
40-
4129
template <typename T>
4230
static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
43-
gpuAssert(std::get<1>(tuple));
44-
if (PyErr_Occurred())
45-
return nullptr;
46-
else
31+
const auto code = std::get<1>(tuple);
32+
if (code != ZE_RESULT_SUCCESS) {
33+
throw std::runtime_error(parseZeResultCode(code));
34+
} else {
4735
return std::get<0>(tuple);
36+
}
4837
}
4938

5039
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
@@ -113,6 +102,31 @@ void freeKernelBundle(PyObject *p) {
113102
PyCapsule_GetPointer(p, "kernel_bundle"));
114103
}
115104

105+
template <typename L0_DEVICE, typename L0_CONTEXT>
106+
std::tuple<ze_module_handle_t, ze_kernel_handle_t, int32_t, int32_t>
107+
compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
108+
const std::string &kernel_name, L0_DEVICE l0_device,
109+
L0_CONTEXT l0_context, const std::string &build_flags,
110+
const bool is_spv) {
111+
auto l0_module =
112+
checkSyclErrors(create_module(l0_context, l0_device, binary_ptr,
113+
binary_size, build_flags.c_str(), is_spv));
114+
115+
// Retrieve the kernel properties (e.g. register spills).
116+
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));
117+
118+
ze_kernel_properties_t props;
119+
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
120+
props.pNext = nullptr;
121+
checkSyclErrors(
122+
std::make_tuple(NULL, zeKernelGetProperties(l0_kernel, &props)));
123+
124+
int32_t n_spills = props.spillMemSize;
125+
const int32_t n_regs = 0;
126+
127+
return std::make_tuple(l0_module, l0_kernel, n_regs, n_spills);
128+
}
129+
116130
static PyObject *loadBinary(PyObject *self, PyObject *args) {
117131
const char *name, *build_flags;
118132
int shared;
@@ -130,106 +144,97 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
130144
return NULL;
131145
}
132146

133-
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
134-
const sycl::device sycl_device = sycl_l0_device_pair.first;
135-
136-
std::string kernel_name = name;
137-
const size_t binary_size = PyBytes_Size(py_bytes);
138-
139-
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
140-
const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context();
141-
const auto l0_device =
142-
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
143-
const auto l0_context =
144-
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
145-
146-
const auto use_native_code =
147-
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
148-
const bool is_spv = use_native_code ? !(*use_native_code) : true;
149-
150-
auto l0_module = checkSyclErrors(create_module(
151-
l0_context, l0_device, binary_ptr, binary_size, build_flags, is_spv));
152-
153-
auto checkL0Errors = [&](auto l0_module) -> ze_kernel_handle_t {
154-
if (PyErr_Occurred()) {
155-
// check for errors from module creation
156-
return NULL;
157-
}
158-
ze_kernel_handle_t l0_kernel =
159-
checkSyclErrors(create_function(l0_module, kernel_name));
160-
if (PyErr_Occurred()) {
161-
// check for errors from kernel creation
162-
return NULL;
163-
}
164-
return l0_kernel;
165-
};
166-
167-
// Retrieve the kernel properties (e.g. register spills).
168-
ze_kernel_handle_t l0_kernel = checkL0Errors(l0_module);
169-
ze_kernel_properties_t props;
170-
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
171-
props.pNext = nullptr;
172-
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
173-
174-
int32_t n_spills = props.spillMemSize;
175-
const int32_t n_regs = 0;
176-
177-
if (is_spv) {
178-
constexpr int32_t max_reg_spill = 1000;
179-
std::string build_flags_str(build_flags);
180-
bool is_GRF_mode_specified = false;
181-
182-
// Check whether the GRF mode is specified by the build flags.
183-
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
184-
std::string::npos ||
185-
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
186-
std::string::npos ||
187-
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
188-
std::string::npos) {
189-
is_GRF_mode_specified = true;
190-
}
147+
try {
148+
149+
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
150+
const sycl::device sycl_device = sycl_l0_device_pair.first;
151+
152+
const std::string kernel_name = name;
153+
const size_t binary_size = PyBytes_Size(py_bytes);
154+
155+
uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
156+
const auto ctx =
157+
sycl_device.get_platform().ext_oneapi_get_default_context();
158+
const auto l0_device =
159+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
160+
const auto l0_context =
161+
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
162+
163+
const auto use_native_code =
164+
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
165+
const bool is_spv = use_native_code ? !(*use_native_code) : true;
166+
167+
auto [l0_module, l0_kernel, n_regs, n_spills] =
168+
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
169+
l0_context, build_flags, is_spv);
170+
171+
if (is_spv) {
172+
constexpr int32_t max_reg_spill = 1000;
173+
std::string build_flags_str(build_flags);
174+
bool is_GRF_mode_specified = false;
175+
176+
// Check whether the GRF mode is specified by the build flags.
177+
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
178+
std::string::npos ||
179+
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
180+
std::string::npos ||
181+
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
182+
std::string::npos) {
183+
is_GRF_mode_specified = true;
184+
}
191185

192-
// If the register mode isn't set, and the number of spills is greater
193-
// than the threshold, recompile the kernel using large GRF mode.
194-
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
195-
const std::optional<bool> debugEnabled =
186+
// If the register mode isn't set, and the number of spills is greater
187+
// than the threshold, recompile the kernel using large GRF mode.
188+
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
189+
const std::optional<bool> debugEnabled =
196190
isEnvValueBool(getStrEnv("TRITON_DEBUG"));
197-
if (debugEnabled)
198-
std::cout << "(I): Detected " << n_spills
199-
<< " spills, recompiling kernel \"" << kernel_name
200-
<< "\" using large GRF mode" << std::endl;
201-
202-
const std::string new_build_flags =
203-
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
204-
l0_module = checkSyclErrors(
205-
create_module(l0_context, l0_device, binary_ptr, binary_size,
206-
new_build_flags.c_str(), is_spv));
207-
208-
l0_kernel = checkL0Errors(l0_module);
209-
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
210-
n_spills = props.spillMemSize;
211-
191+
if (debugEnabled)
192+
std::cout << "(I): Detected " << n_spills
193+
<< " spills, recompiling the kernel using large GRF mode"
194+
<< std::endl;
195+
196+
const std::string new_build_flags =
197+
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
198+
199+
auto [l0_module, l0_kernel, n_regs, n_spills] = compileLevelZeroObjects(
200+
binary_ptr, binary_size, kernel_name, l0_device, l0_context,
201+
new_build_flags, is_spv);
202+
212203
if (debugEnabled)
213204
std::cout << "(I): Kernel has now " << n_spills << " spills"
214205
<< std::endl;
206+
}
215207
}
216-
}
217208

218-
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
219-
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
220-
sycl::bundle_state::executable>(
221-
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
222-
ctx));
223-
sycl::kernel *fun =
224-
new sycl::kernel(sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
225-
{*mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
226-
ctx));
227-
auto kernel_py =
228-
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
229-
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
230-
"kernel_bundle", freeKernelBundle);
231-
232-
return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs, n_spills);
209+
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
210+
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
211+
sycl::bundle_state::executable>(
212+
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer},
213+
ctx));
214+
sycl::kernel *fun = new sycl::kernel(
215+
sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
216+
{*mod, l0_kernel,
217+
sycl::ext::oneapi::level_zero::ownership::transfer},
218+
ctx));
219+
auto kernel_py =
220+
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
221+
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
222+
"kernel_bundle", freeKernelBundle);
223+
224+
return Py_BuildValue("(OOii)", kernel_bundle_py, kernel_py, n_regs,
225+
n_spills);
226+
227+
} catch (const std::exception &e) {
228+
char err[1024] = {0};
229+
std::string_view error_str(e.what());
230+
strncat(err, error_str.data(), std::min(error_str.size(), size_t(1024)));
231+
PyGILState_STATE gil_state;
232+
gil_state = PyGILState_Ensure();
233+
PyErr_SetString(PyExc_RuntimeError, err);
234+
std::cerr << "Error during Intel loadBinary: " << err << std::endl;
235+
PyGILState_Release(gil_state);
236+
return NULL;
237+
}
233238
}
234239

235240
static PyObject *initContext(PyObject *self, PyObject *args) {

0 commit comments

Comments
 (0)