@@ -26,25 +26,14 @@ static std::vector<ze_device_handle_t> g_devices;
2626static std ::vector < std ::pair < sycl ::device , ze_device_handle_t >>
2727 g_sycl_l0_device_list ;
2828
29- static inline void gpuAssert (ze_result_t code ) {
30- if (code != ZE_RESULT_SUCCESS ) {
31- auto str = parseZeResultCode (code );
32- char err [1024 ] = {0 };
33- strncat (err , str .c_str (), std ::min (str .size (), size_t (1024 )));
34- PyGILState_STATE gil_state ;
35- gil_state = PyGILState_Ensure ();
36- PyErr_SetString (PyExc_RuntimeError , err );
37- PyGILState_Release (gil_state );
38- }
39- }
40-
4129template < typename T >
4230static inline T checkSyclErrors (const std ::tuple < T , ze_result_t > tuple ) {
43- gpuAssert ( std ::get < 1 > (tuple ) ) ;
44- if (PyErr_Occurred ())
45- return nullptr ;
46- else
31+ const auto code = std ::get < 1 > (tuple );
32+ if (code != ZE_RESULT_SUCCESS ) {
33+ throw std :: runtime_error ( parseZeResultCode ( code )) ;
34+ } else {
4735 return std ::get < 0 > (tuple );
36+ }
4837}
4938
5039static PyObject * getDeviceProperties (PyObject * self , PyObject * args ) {
@@ -113,6 +102,31 @@ void freeKernelBundle(PyObject *p) {
113102 PyCapsule_GetPointer (p , "kernel_bundle "));
114103}
115104
105+ template < typename L0_DEVICE , typename L0_CONTEXT >
106+ std ::tuple < ze_module_handle_t , ze_kernel_handle_t , int32_t , int32_t >
107+ compileLevelZeroObjects (uint8_t * binary_ptr , const size_t binary_size ,
108+ const std ::string & kernel_name , L0_DEVICE l0_device ,
109+ L0_CONTEXT l0_context , const std ::string & build_flags ,
110+ const bool is_spv ) {
111+ auto l0_module =
112+ checkSyclErrors (create_module (l0_context , l0_device , binary_ptr ,
113+ binary_size , build_flags .c_str (), is_spv ));
114+
115+ // Retrieve the kernel properties (e.g. register spills).
116+ auto l0_kernel = checkSyclErrors (create_function (l0_module , kernel_name ));
117+
118+ ze_kernel_properties_t props ;
119+ props .stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES ;
120+ props .pNext = nullptr ;
121+ checkSyclErrors (
122+ std ::make_tuple (NULL , zeKernelGetProperties (l0_kernel , & props )));
123+
124+ int32_t n_spills = props .spillMemSize ;
125+ const int32_t n_regs = 0 ;
126+
127+ return std ::make_tuple (l0_module , l0_kernel , n_regs , n_spills );
128+ }
129+
116130static PyObject * loadBinary (PyObject * self , PyObject * args ) {
117131 const char * name , * build_flags ;
118132 int shared ;
@@ -130,106 +144,97 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
130144 return NULL ;
131145 }
132146
133- const auto & sycl_l0_device_pair = g_sycl_l0_device_list [devId ];
134- const sycl ::device sycl_device = sycl_l0_device_pair .first ;
135-
136- std ::string kernel_name = name ;
137- const size_t binary_size = PyBytes_Size (py_bytes );
138-
139- uint8_t * binary_ptr = (uint8_t * )PyBytes_AsString (py_bytes );
140- const auto ctx = sycl_device .get_platform ().ext_oneapi_get_default_context ();
141- const auto l0_device =
142- sycl ::get_native < sycl ::backend ::ext_oneapi_level_zero > (sycl_device );
143- const auto l0_context =
144- sycl ::get_native < sycl ::backend ::ext_oneapi_level_zero > (ctx );
145-
146- const auto use_native_code =
147- isEnvValueBool (getStrEnv ("TRITON_XPU_GEN_NATIVE_CODE" ));
148- const bool is_spv = use_native_code ? !(* use_native_code ) : true;
149-
150- auto l0_module = checkSyclErrors (create_module (
151- l0_context , l0_device , binary_ptr , binary_size , build_flags , is_spv ));
152-
153- auto checkL0Errors = [& ](auto l0_module ) -> ze_kernel_handle_t {
154- if (PyErr_Occurred ()) {
155- // check for errors from module creation
156- return NULL ;
157- }
158- ze_kernel_handle_t l0_kernel =
159- checkSyclErrors (create_function (l0_module , kernel_name ));
160- if (PyErr_Occurred ()) {
161- // check for errors from kernel creation
162- return NULL ;
163- }
164- return l0_kernel ;
165- };
166-
167- // Retrieve the kernel properties (e.g. register spills).
168- ze_kernel_handle_t l0_kernel = checkL0Errors (l0_module );
169- ze_kernel_properties_t props ;
170- props .stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES ;
171- props .pNext = nullptr ;
172- gpuAssert (zeKernelGetProperties (l0_kernel , & props ));
173-
174- int32_t n_spills = props .spillMemSize ;
175- const int32_t n_regs = 0 ;
176-
177- if (is_spv ) {
178- constexpr int32_t max_reg_spill = 1000 ;
179- std ::string build_flags_str (build_flags );
180- bool is_GRF_mode_specified = false;
181-
182- // Check whether the GRF mode is specified by the build flags.
183- if (build_flags_str .find ("-cl-intel-256-GRF-per-thread" ) !=
184- std ::string ::npos ||
185- build_flags_str .find ("-cl-intel-128-GRF-per-thread" ) !=
186- std ::string ::npos ||
187- build_flags_str .find ("-cl-intel-enable-auto-large-GRF-mode" ) !=
188- std ::string ::npos ) {
189- is_GRF_mode_specified = true;
190- }
147+ try {
148+
149+ const auto & sycl_l0_device_pair = g_sycl_l0_device_list [devId ];
150+ const sycl ::device sycl_device = sycl_l0_device_pair .first ;
151+
152+ const std ::string kernel_name = name ;
153+ const size_t binary_size = PyBytes_Size (py_bytes );
154+
155+ uint8_t * binary_ptr = (uint8_t * )PyBytes_AsString (py_bytes );
156+ const auto ctx =
157+ sycl_device .get_platform ().ext_oneapi_get_default_context ();
158+ const auto l0_device =
159+ sycl ::get_native < sycl ::backend ::ext_oneapi_level_zero > (sycl_device );
160+ const auto l0_context =
161+ sycl ::get_native < sycl ::backend ::ext_oneapi_level_zero > (ctx );
162+
163+ const auto use_native_code =
164+ isEnvValueBool (getStrEnv ("TRITON_XPU_GEN_NATIVE_CODE" ));
165+ const bool is_spv = use_native_code ? !(* use_native_code ) : true;
166+
167+ auto [l0_module , l0_kernel , n_regs , n_spills ] =
168+ compileLevelZeroObjects (binary_ptr , binary_size , kernel_name , l0_device ,
169+ l0_context , build_flags , is_spv );
170+
171+ if (is_spv ) {
172+ constexpr int32_t max_reg_spill = 1000 ;
173+ std ::string build_flags_str (build_flags );
174+ bool is_GRF_mode_specified = false;
175+
176+ // Check whether the GRF mode is specified by the build flags.
177+ if (build_flags_str .find ("-cl-intel-256-GRF-per-thread" ) !=
178+ std ::string ::npos ||
179+ build_flags_str .find ("-cl-intel-128-GRF-per-thread" ) !=
180+ std ::string ::npos ||
181+ build_flags_str .find ("-cl-intel-enable-auto-large-GRF-mode" ) !=
182+ std ::string ::npos ) {
183+ is_GRF_mode_specified = true;
184+ }
191185
192- // If the register mode isn't set, and the number of spills is greater
193- // than the threshold, recompile the kernel using large GRF mode.
194- if (!is_GRF_mode_specified && n_spills > max_reg_spill ) {
195- const std ::optional < bool > debugEnabled =
186+ // If the register mode isn't set, and the number of spills is greater
187+ // than the threshold, recompile the kernel using large GRF mode.
188+ if (!is_GRF_mode_specified && n_spills > max_reg_spill ) {
189+ const std ::optional < bool > debugEnabled =
196190 isEnvValueBool (getStrEnv ("TRITON_DEBUG" ));
197- if (debugEnabled )
198- std ::cout << "(I): Detected " << n_spills
199- << " spills, recompiling kernel \"" << kernel_name
200- << "\" using large GRF mode" << std ::endl ;
201-
202- const std ::string new_build_flags =
203- build_flags_str .append (" -cl-intel-256-GRF-per-thread" );
204- l0_module = checkSyclErrors (
205- create_module (l0_context , l0_device , binary_ptr , binary_size ,
206- new_build_flags .c_str (), is_spv ));
207-
208- l0_kernel = checkL0Errors (l0_module );
209- gpuAssert (zeKernelGetProperties (l0_kernel , & props ));
210- n_spills = props .spillMemSize ;
211-
191+ if (debugEnabled )
192+ std ::cout << "(I): Detected " << n_spills
193+ << " spills, recompiling the kernel using large GRF mode"
194+ << std ::endl ;
195+
196+ const std ::string new_build_flags =
197+ build_flags_str .append (" -cl-intel-256-GRF-per-thread" );
198+
199+ auto [l0_module , l0_kernel , n_regs , n_spills ] = compileLevelZeroObjects (
200+ binary_ptr , binary_size , kernel_name , l0_device , l0_context ,
201+ new_build_flags , is_spv );
202+
212203 if (debugEnabled )
213204 std ::cout << "(I): Kernel has now " << n_spills << " spills"
214205 << std ::endl ;
206+ }
215207 }
216- }
217208
218- auto mod = new sycl ::kernel_bundle < sycl ::bundle_state ::executable > (
219- sycl ::make_kernel_bundle < sycl ::backend ::ext_oneapi_level_zero ,
220- sycl ::bundle_state ::executable > (
221- {l0_module , sycl ::ext ::oneapi ::level_zero ::ownership ::transfer },
222- ctx ));
223- sycl ::kernel * fun =
224- new sycl ::kernel (sycl ::make_kernel < sycl ::backend ::ext_oneapi_level_zero > (
225- {* mod , l0_kernel , sycl ::ext ::oneapi ::level_zero ::ownership ::transfer },
226- ctx ));
227- auto kernel_py =
228- PyCapsule_New (reinterpret_cast < void * > (fun ), "kernel" , freeKernel );
229- auto kernel_bundle_py = PyCapsule_New (reinterpret_cast < void * > (mod ),
230- "kernel_bundle" , freeKernelBundle );
231-
232- return Py_BuildValue ("(OOii)" , kernel_bundle_py , kernel_py , n_regs , n_spills );
209+ auto mod = new sycl ::kernel_bundle < sycl ::bundle_state ::executable > (
210+ sycl ::make_kernel_bundle < sycl ::backend ::ext_oneapi_level_zero ,
211+ sycl ::bundle_state ::executable > (
212+ {l0_module , sycl ::ext ::oneapi ::level_zero ::ownership ::transfer },
213+ ctx ));
214+ sycl ::kernel * fun = new sycl ::kernel (
215+ sycl ::make_kernel < sycl ::backend ::ext_oneapi_level_zero > (
216+ {* mod , l0_kernel ,
217+ sycl ::ext ::oneapi ::level_zero ::ownership ::transfer },
218+ ctx ));
219+ auto kernel_py =
220+ PyCapsule_New (reinterpret_cast < void * > (fun ), "kernel" , freeKernel );
221+ auto kernel_bundle_py = PyCapsule_New (reinterpret_cast < void * > (mod ),
222+ "kernel_bundle" , freeKernelBundle );
223+
224+ return Py_BuildValue ("(OOii)" , kernel_bundle_py , kernel_py , n_regs ,
225+ n_spills );
226+
227+ } catch (const std ::exception & e ) {
228+ char err [1024 ] = {0 };
229+ std ::string_view error_str (e .what ());
230+ strncat (err , error_str .data (), std ::min (error_str .size (), size_t (1024 )));
231+ PyGILState_STATE gil_state ;
232+ gil_state = PyGILState_Ensure ();
233+ PyErr_SetString (PyExc_RuntimeError , err );
234+ std ::cerr << "Error during Intel loadBinary: " << err << std ::endl ;
235+ PyGILState_Release (gil_state );
236+ return NULL ;
237+ }
233238}
234239
235240static PyObject * initContext (PyObject * self , PyObject * args ) {
0 commit comments