Skip to content

Commit d78af61

Browse files
committed
Return n_regs using grf size build str flag
1 parent 7e71ba1 commit d78af61

File tree

1 file changed

+61
-29
lines changed

1 file changed

+61
-29
lines changed

third_party/intel/backend/driver.c

Lines changed: 61 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -102,15 +102,17 @@ void freeKernelBundle(PyObject *p) {
102102
PyCapsule_GetPointer(p, "kernel_bundle"));
103103
}
104104

105+
using Spills = int32_t;
106+
105107
template <typename L0_DEVICE, typename L0_CONTEXT>
106-
std::tuple<ze_module_handle_t, ze_kernel_handle_t, int32_t, int32_t>
108+
std::tuple<ze_module_handle_t, ze_kernel_handle_t, Spills>
107109
compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
108110
const std::string &kernel_name, L0_DEVICE l0_device,
109-
L0_CONTEXT l0_context, const std::string &build_flags,
110-
const bool is_spv) {
111+
L0_CONTEXT l0_context,
112+
const std::string& build_flags, const bool is_spv) {
111113
auto l0_module =
112114
checkSyclErrors(create_module(l0_context, l0_device, binary_ptr,
113-
binary_size, build_flags.c_str(), is_spv));
115+
binary_size, build_flags.data(), is_spv));
114116

115117
// Retrieve the kernel properties (e.g. register spills).
116118
auto l0_kernel = checkSyclErrors(create_function(l0_module, kernel_name));
@@ -121,20 +123,58 @@ compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
121123
checkSyclErrors(
122124
std::make_tuple(NULL, zeKernelGetProperties(l0_kernel, &props)));
123125

124-
int32_t n_spills = props.spillMemSize;
125-
const int32_t n_regs = 0;
126+
const int32_t n_spills = props.spillMemSize;
126127

127-
return std::make_tuple(l0_module, l0_kernel, n_regs, n_spills);
128+
return std::make_tuple(l0_module, l0_kernel, n_spills);
128129
}
129130

131+
/// Wraps the build-flag string handed to the module compiler and answers
/// questions about the GRF-size options it contains.
///
/// The flag spellings are shared by every instance, so they are static data
/// rather than per-object members; this also keeps the type copy- and
/// move-assignable.
struct BuildFlags {
  /// The current, space-separated build-flag string.
  std::string build_flags_str;

  static inline const std::string LARGE_GRF_FLAG{
      "-cl-intel-256-GRF-per-thread"};
  static inline const std::string SMALL_GRF_FLAG{
      "-cl-intel-128-GRF-per-thread"};
  static inline const std::string AUTO_GRF_FLAG{
      "-cl-intel-enable-auto-large-GRF-mode"};

  /// Builds the wrapper from a C string.
  /// A null pointer is treated as "no flags": constructing a std::string
  /// from a null pointer is undefined behaviour, so it is guarded here.
  BuildFlags(const char *build_flags)
      : build_flags_str(build_flags != nullptr ? build_flags : "") {}

  /// Returns the current build-flag string.
  const std::string &operator()() const { return build_flags_str; }

  /// Returns the register count implied by an explicit GRF-size flag:
  /// 256 for the large flag, 128 for the small flag, 0 otherwise.
  /// The large flag is checked first, so it wins if both are present.
  int32_t n_regs() const {
    if (contains(LARGE_GRF_FLAG)) {
      return 256;
    }
    if (contains(SMALL_GRF_FLAG)) {
      return 128;
    }
    // TODO: arguably we could return 128 if we find no flag instead of 0. For
    // now, stick with the conservative choice and alert the user only if a
    // specific GRF mode is specified.
    return 0;
  }

  /// Returns true when any of the large, small, or auto GRF flags is present.
  bool hasGRFSizeFlag() const {
    return contains(LARGE_GRF_FLAG) || contains(SMALL_GRF_FLAG) ||
           contains(AUTO_GRF_FLAG);
  }

  /// Appends the large GRF flag, separated by a single space.
  /// No check is made for an existing GRF flag; callers are expected to
  /// consult hasGRFSizeFlag() first.
  void addLargeGRFSizeFlag() {
    build_flags_str.append(" ").append(LARGE_GRF_FLAG);
  }

private:
  /// Substring test on the current flag string.
  bool contains(const std::string &flag) const {
    return build_flags_str.find(flag) != std::string::npos;
  }
};
169+
130170
static PyObject *loadBinary(PyObject *self, PyObject *args) {
131-
const char *name, *build_flags;
171+
const char *name, *build_flags_ptr;
132172
int shared;
133173
PyObject *py_bytes;
134174
int devId;
135175

136-
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared, &build_flags,
137-
&devId)) {
176+
if (!PyArg_ParseTuple(args, "sSisi", &name, &py_bytes, &shared,
177+
&build_flags_ptr, &devId)) {
138178
std::cerr << "loadBinary arg parse failed" << std::endl;
139179
return NULL;
140180
}
@@ -144,6 +184,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
144184
return NULL;
145185
}
146186

187+
BuildFlags build_flags(build_flags_ptr);
188+
147189
try {
148190

149191
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[devId];
@@ -164,24 +206,13 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
164206
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
165207
const bool is_spv = use_native_code ? !(*use_native_code) : true;
166208

167-
auto [l0_module, l0_kernel, n_regs, n_spills] =
209+
auto [l0_module, l0_kernel, n_spills] =
168210
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
169-
l0_context, build_flags, is_spv);
211+
l0_context, build_flags(), is_spv);
170212

171213
if (is_spv) {
172214
constexpr int32_t max_reg_spill = 1000;
173-
std::string build_flags_str(build_flags);
174-
bool is_GRF_mode_specified = false;
175-
176-
// Check whether the GRF mode is specified by the build flags.
177-
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
178-
std::string::npos ||
179-
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
180-
std::string::npos ||
181-
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
182-
std::string::npos) {
183-
is_GRF_mode_specified = true;
184-
}
215+
const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();
185216

186217
// If the register mode isn't set, and the number of spills is greater
187218
// than the threshold, recompile the kernel using large GRF mode.
@@ -193,18 +224,19 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
193224
<< " spills, recompiling the kernel using large GRF mode"
194225
<< std::endl;
195226

196-
const std::string new_build_flags =
197-
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
227+
build_flags.addLargeGRFSizeFlag();
198228

199-
auto [l0_module, l0_kernel, n_regs, n_spills] = compileLevelZeroObjects(
229+
auto [l0_module, l0_kernel, n_spills] = compileLevelZeroObjects(
200230
binary_ptr, binary_size, kernel_name, l0_device, l0_context,
201-
new_build_flags, is_spv);
202-
231+
build_flags(), is_spv);
232+
203233
if (debugEnabled)
204234
std::cout << "(I): Kernel has now " << n_spills << " spills"
205235
<< std::endl;
206236
}
207237
}
238+
239+
auto n_regs = build_flags.n_regs();
208240

209241
auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
210242
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,

0 commit comments

Comments
 (0)