@@ -102,15 +102,17 @@ void freeKernelBundle(PyObject *p) {
102102 PyCapsule_GetPointer (p , "kernel_bundle "));
103103}
104104
105+ using Spills = int32_t ;
106+
105107template < typename L0_DEVICE , typename L0_CONTEXT >
106- std ::tuple < ze_module_handle_t , ze_kernel_handle_t , int32_t , int32_t >
108+ std ::tuple < ze_module_handle_t , ze_kernel_handle_t , Spills >
107109compileLevelZeroObjects (uint8_t * binary_ptr , const size_t binary_size ,
108110 const std ::string & kernel_name , L0_DEVICE l0_device ,
109- L0_CONTEXT l0_context , const std :: string & build_flags ,
110- const bool is_spv ) {
111+ L0_CONTEXT l0_context ,
112+ const std :: string & build_flags , const bool is_spv ) {
111113 auto l0_module =
112114 checkSyclErrors (create_module (l0_context , l0_device , binary_ptr ,
113- binary_size , build_flags .c_str (), is_spv ));
115+ binary_size , build_flags .data (), is_spv ));
114116
115117 // Retrieve the kernel properties (e.g. register spills).
116118 auto l0_kernel = checkSyclErrors (create_function (l0_module , kernel_name ));
@@ -121,20 +123,58 @@ compileLevelZeroObjects(uint8_t *binary_ptr, const size_t binary_size,
121123 checkSyclErrors (
122124 std ::make_tuple (NULL , zeKernelGetProperties (l0_kernel , & props )));
123125
124- int32_t n_spills = props .spillMemSize ;
125- const int32_t n_regs = 0 ;
126+ const int32_t n_spills = props .spillMemSize ;
126127
127- return std ::make_tuple (l0_module , l0_kernel , n_regs , n_spills );
128+ return std ::make_tuple (l0_module , l0_kernel , n_spills );
128129}
129130
/// Wraps the compiler build-flag string passed to module creation and answers
/// questions about the GRF (register file) mode it requests.
///
/// The object owns a copy of the flag string; `operator()` exposes it so it can
/// be forwarded to the module build call.
struct BuildFlags {
  /// The current, possibly amended, build-flag string.
  std::string build_flags_str;

  // Flag spellings searched for / appended. They are shared by all instances
  // (static) so that BuildFlags stays copy- and move-assignable.
  static inline const std::string LARGE_GRF_FLAG{"-cl-intel-256-GRF-per-thread"};
  static inline const std::string SMALL_GRF_FLAG{"-cl-intel-128-GRF-per-thread"};
  static inline const std::string AUTO_GRF_FLAG{
      "-cl-intel-enable-auto-large-GRF-mode"};

  /// @param build_flags NUL-terminated flag string; a null pointer is treated
  ///        as an empty flag string instead of invoking undefined behaviour.
  BuildFlags(const char *build_flags)
      : build_flags_str(build_flags != nullptr ? build_flags : "") {}

  /// @return the current flag string.
  const std::string &operator()() const { return build_flags_str; }

  /// @return 256 if the large-GRF flag is present, 128 if the small-GRF flag is
  ///         present (large takes precedence when both appear), otherwise 0.
  int32_t n_regs() const {
    if (hasFlag(LARGE_GRF_FLAG)) {
      return 256;
    }
    if (hasFlag(SMALL_GRF_FLAG)) {
      return 128;
    }
    // TODO: arguably we could return 128 if we find no flag instead of 0. For
    // now, stick with the conservative choice and alert the user only if a
    // specific GRF mode is specified.
    return 0;
  }

  /// @return true if any of the large, small or auto GRF flags is present.
  bool hasGRFSizeFlag() const {
    return hasFlag(LARGE_GRF_FLAG) || hasFlag(SMALL_GRF_FLAG) ||
           hasFlag(AUTO_GRF_FLAG);
  }

  /// Appends a space followed by the large-GRF flag to the flag string. No
  /// check is made for the flag already being present.
  void addLargeGRFSizeFlag() {
    build_flags_str.append(" ").append(LARGE_GRF_FLAG);
  }

private:
  /// Plain substring search for `flag` inside the flag string.
  bool hasFlag(const std::string &flag) const {
    return build_flags_str.find(flag) != std::string::npos;
  }
};
169+
130170static PyObject * loadBinary (PyObject * self , PyObject * args ) {
131- const char * name , * build_flags ;
171+ const char * name , * build_flags_ptr ;
132172 int shared ;
133173 PyObject * py_bytes ;
134174 int devId ;
135175
136- if (!PyArg_ParseTuple (args , "sSisi" , & name , & py_bytes , & shared , & build_flags ,
137- & devId )) {
176+ if (!PyArg_ParseTuple (args , "sSisi" , & name , & py_bytes , & shared ,
177+ & build_flags_ptr , & devId )) {
138178 std ::cerr << "loadBinary arg parse failed" << std ::endl ;
139179 return NULL ;
140180 }
@@ -144,6 +184,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
144184 return NULL ;
145185 }
146186
187+ BuildFlags build_flags (build_flags_ptr );
188+
147189 try {
148190
149191 const auto & sycl_l0_device_pair = g_sycl_l0_device_list [devId ];
@@ -164,24 +206,13 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
164206 isEnvValueBool (getStrEnv ("TRITON_XPU_GEN_NATIVE_CODE" ));
165207 const bool is_spv = use_native_code ? !(* use_native_code ) : true;
166208
167- auto [l0_module , l0_kernel , n_regs , n_spills ] =
209+ auto [l0_module , l0_kernel , n_spills ] =
168210 compileLevelZeroObjects (binary_ptr , binary_size , kernel_name , l0_device ,
169- l0_context , build_flags , is_spv );
211+ l0_context , build_flags () , is_spv );
170212
171213 if (is_spv ) {
172214 constexpr int32_t max_reg_spill = 1000 ;
173- std ::string build_flags_str (build_flags );
174- bool is_GRF_mode_specified = false;
175-
176- // Check whether the GRF mode is specified by the build flags.
177- if (build_flags_str .find ("-cl-intel-256-GRF-per-thread" ) !=
178- std ::string ::npos ||
179- build_flags_str .find ("-cl-intel-128-GRF-per-thread" ) !=
180- std ::string ::npos ||
181- build_flags_str .find ("-cl-intel-enable-auto-large-GRF-mode" ) !=
182- std ::string ::npos ) {
183- is_GRF_mode_specified = true;
184- }
215+ const bool is_GRF_mode_specified = build_flags .hasGRFSizeFlag ();
185216
186217 // If the register mode isn't set, and the number of spills is greater
187218 // than the threshold, recompile the kernel using large GRF mode.
@@ -193,18 +224,19 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
193224 << " spills, recompiling the kernel using large GRF mode"
194225 << std ::endl ;
195226
196- const std ::string new_build_flags =
197- build_flags_str .append (" -cl-intel-256-GRF-per-thread" );
227+ build_flags .addLargeGRFSizeFlag ();
198228
199- auto [l0_module , l0_kernel , n_regs , n_spills ] = compileLevelZeroObjects (
229+ auto [l0_module , l0_kernel , n_spills ] = compileLevelZeroObjects (
200230 binary_ptr , binary_size , kernel_name , l0_device , l0_context ,
201- new_build_flags , is_spv );
202-
231+ build_flags () , is_spv );
232+
203233 if (debugEnabled )
204234 std ::cout << "(I): Kernel has now " << n_spills << " spills"
205235 << std ::endl ;
206236 }
207237 }
238+
239+ auto n_regs = build_flags .n_regs ();
208240
209241 auto mod = new sycl ::kernel_bundle < sycl ::bundle_state ::executable > (
210242 sycl ::make_kernel_bundle < sycl ::backend ::ext_oneapi_level_zero ,
0 commit comments