@@ -36,7 +36,7 @@ class _KernelCompiler(_FunctionCompiler):
36
36
functions.
37
37
"""
38
38
39
- def _check_queue_equivalence_of_args (
39
+ def check_queue_equivalence_of_args (
40
40
self , py_func_name : str , args : [types .Type , ...]
41
41
):
42
42
"""Evaluates if all DpnpNdArray arguments passed to a kernel function
@@ -96,11 +96,11 @@ def compile(self, args, return_type):
96
96
kcres = self ._compile_cached (args , return_type )
97
97
if kcres .status :
98
98
return kcres
99
- else :
100
- raise kcres .cres_or_error
99
+
100
+ raise kcres .cres_or_error
101
101
102
102
def _compile_cached (
103
- self , kernel_args , return_type : types .Type
103
+ self , args , return_type : types .Type
104
104
) -> _KernelCompileResult :
105
105
"""Compiles the kernel function to bitcode and generates a host-callable
106
106
wrapper to submit the kernel to a SYCL queue.
@@ -132,16 +132,14 @@ def _compile_cached(
132
132
CompileResult: A CompileResult object storing the LLVM library for
133
133
the host-callable wrapper function.
134
134
"""
135
- key = tuple (kernel_args ), return_type
135
+ key = tuple (args ), return_type
136
136
try :
137
137
return _KernelCompileResult (False , self ._failed_cache [key ], None )
138
138
except KeyError :
139
139
pass
140
140
141
141
try :
142
- kernel_cres : CompileResult = self ._compile_core (
143
- kernel_args , return_type
144
- )
142
+ kernel_cres : CompileResult = self ._compile_core (args , return_type )
145
143
146
144
kernel_library = kernel_cres .library
147
145
kernel_fndesc = kernel_cres .fndesc
@@ -155,14 +153,15 @@ def _compile_cached(
155
153
with open (
156
154
kernel_cres .fndesc .llvm_func_name + ".ll" ,
157
155
"w" ,
156
+ encoding = "UTF-8" ,
158
157
) as f :
159
158
f .write (kernel_cres .library ._final_module .__str__ ())
160
159
161
160
except errors .TypingError as e :
162
161
self ._failed_cache [key ] = e
163
162
return _KernelCompileResult (False , e , None )
164
- else :
165
- return _KernelCompileResult (True , kernel_cres , kernel_module )
163
+
164
+ return _KernelCompileResult (True , kernel_cres , kernel_module )
166
165
167
166
168
167
class KernelDispatcher (Dispatcher ):
@@ -205,7 +204,7 @@ def __init__(
205
204
can_fallback ,
206
205
exact_match_required = False ,
207
206
)
208
- # XXX: What does this function do exactly?
207
+
209
208
functools .update_wrapper (self , pyfunc )
210
209
211
210
self .targetoptions = targetoptions
@@ -247,7 +246,7 @@ def typeof_pyval(self, val):
247
246
self ._types_active_call .append (tp )
248
247
return tp
249
248
250
- def add_overload (self , cres , kernel_module ):
249
+ def add_bitcode_overload (self , cres , kernel_module ):
251
250
args = tuple (cres .signature .args )
252
251
self .overloads [args ] = kernel_module
253
252
@@ -280,7 +279,7 @@ def cb_llvm(dur):
280
279
args , return_type = sigutils .normalize_signature (sig )
281
280
282
281
try :
283
- self ._compiler ._check_queue_equivalence_of_args (
282
+ self ._compiler .check_queue_equivalence_of_args (
284
283
self ._kernel_name , args
285
284
)
286
285
except ExecutionQueueInferenceError as eqie :
@@ -294,11 +293,11 @@ def cb_llvm(dur):
294
293
# FIXME: Enable caching
295
294
# Add code to enable on-disk caching of a binary SPIR-V kernel
296
295
self ._cache_misses [sig ] += 1
297
- ev_details = dict (
298
- dispatcher = self ,
299
- args = args ,
300
- return_type = return_type ,
301
- )
296
+ ev_details = {
297
+ " dispatcher" : self ,
298
+ " args" : args ,
299
+ " return_type" : return_type ,
300
+ }
302
301
with ev .trigger_event ("numba_dpex:compile" , data = ev_details ):
303
302
try :
304
303
kcres : _KernelCompileResult = self ._compiler .compile (
@@ -312,7 +311,9 @@ def folded(args, kws):
312
311
)[1 ]
313
312
314
313
raise e .bind_fold_arguments (folded )
315
- self .add_overload (kcres .cres_or_error , kcres .kernel_module )
314
+ self .add_bitcode_overload (
315
+ kcres .cres_or_error , kcres .kernel_module
316
+ )
316
317
317
318
# FIXME: enable caching
318
319
0 commit comments