Skip to content

Commit 4b1101c

Browse files
author
Diptorup Deb
committed
Fix various issues pointed out by pylint.
1 parent 9bc32c3 commit 4b1101c

File tree

6 files changed

+156
-99
lines changed

6 files changed

+156
-99
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
"""Contains experimental features that are meant as engineering preview and not
6+
yet production ready.
7+
"""
8+
59
from numba.core.imputils import Registry
610

711
from .decorators import kernel
@@ -15,7 +19,10 @@
1519

1620

1721
@lower_constant(KernelDispatcherType)
18-
def dpex_dispatcher_const(context, builder, ty, pyval):
22+
def dpex_dispatcher_const(context):
23+
"""Dummy lowerer for a KernelDispatcherType object. It is added so that a
24+
KernelDispatcher can be passed as an argument to dpjit.
25+
"""
1926
return context.get_dummy_value()
2027

2128

numba_dpex/experimental/decorators.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ def _kernel_dispatcher(pyfunc, sigs=None):
3434

3535
if func_or_sig is None:
3636
return _kernel_dispatcher
37-
elif isinstance(func_or_sig, str):
37+
38+
if isinstance(func_or_sig, str):
3839
raise NotImplementedError(
3940
"Specifying signatures as string is not yet supported by numba-dpex"
4041
)
41-
elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig):
42+
43+
if isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig):
4244
# String signatures are not supported as passing usm_ndarray type as
4345
# a string is not possible. Numba's sigutils relies on the type being
4446
# available in Numba's `types.__dict__` and dpex types are not
@@ -64,13 +66,12 @@ def _specialized_kernel_dispatcher(pyfunc):
6466
)
6567

6668
return _specialized_kernel_dispatcher
67-
else:
68-
func = func_or_sig
69-
if not inspect.isfunction(func):
70-
raise ValueError(
71-
"Argument passed to the kernel decorator is neither a "
72-
"function object, nor a signature. If you are trying to "
73-
"specialize the kernel that takes a single argument, specify "
74-
"the return type as void explicitly."
75-
)
76-
return _kernel_dispatcher(func)
69+
func = func_or_sig
70+
if not inspect.isfunction(func):
71+
raise ValueError(
72+
"Argument passed to the kernel decorator is neither a "
73+
"function object, nor a signature. If you are trying to "
74+
"specialize the kernel that takes a single argument, specify "
75+
"the return type as void explicitly."
76+
)
77+
return _kernel_dispatcher(func)

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class _KernelCompiler(_FunctionCompiler):
3636
functions.
3737
"""
3838

39-
def _check_queue_equivalence_of_args(
39+
def check_queue_equivalence_of_args(
4040
self, py_func_name: str, args: [types.Type, ...]
4141
):
4242
"""Evaluates if all DpnpNdArray arguments passed to a kernel function
@@ -96,11 +96,11 @@ def compile(self, args, return_type):
9696
kcres = self._compile_cached(args, return_type)
9797
if kcres.status:
9898
return kcres
99-
else:
100-
raise kcres.cres_or_error
99+
100+
raise kcres.cres_or_error
101101

102102
def _compile_cached(
103-
self, kernel_args, return_type: types.Type
103+
self, args, return_type: types.Type
104104
) -> _KernelCompileResult:
105105
"""Compiles the kernel function to bitcode and generates a host-callable
106106
wrapper to submit the kernel to a SYCL queue.
@@ -132,16 +132,14 @@ def _compile_cached(
132132
CompileResult: A CompileResult object storing the LLVM library for
133133
the host-callable wrapper function.
134134
"""
135-
key = tuple(kernel_args), return_type
135+
key = tuple(args), return_type
136136
try:
137137
return _KernelCompileResult(False, self._failed_cache[key], None)
138138
except KeyError:
139139
pass
140140

141141
try:
142-
kernel_cres: CompileResult = self._compile_core(
143-
kernel_args, return_type
144-
)
142+
kernel_cres: CompileResult = self._compile_core(args, return_type)
145143

146144
kernel_library = kernel_cres.library
147145
kernel_fndesc = kernel_cres.fndesc
@@ -155,14 +153,15 @@ def _compile_cached(
155153
with open(
156154
kernel_cres.fndesc.llvm_func_name + ".ll",
157155
"w",
156+
encoding="UTF-8",
158157
) as f:
159158
f.write(kernel_cres.library._final_module.__str__())
160159

161160
except errors.TypingError as e:
162161
self._failed_cache[key] = e
163162
return _KernelCompileResult(False, e, None)
164-
else:
165-
return _KernelCompileResult(True, kernel_cres, kernel_module)
163+
164+
return _KernelCompileResult(True, kernel_cres, kernel_module)
166165

167166

168167
class KernelDispatcher(Dispatcher):
@@ -205,7 +204,7 @@ def __init__(
205204
can_fallback,
206205
exact_match_required=False,
207206
)
208-
# XXX: What does this function do exactly?
207+
209208
functools.update_wrapper(self, pyfunc)
210209

211210
self.targetoptions = targetoptions
@@ -247,7 +246,7 @@ def typeof_pyval(self, val):
247246
self._types_active_call.append(tp)
248247
return tp
249248

250-
def add_overload(self, cres, kernel_module):
249+
def add_bitcode_overload(self, cres, kernel_module):
251250
args = tuple(cres.signature.args)
252251
self.overloads[args] = kernel_module
253252

@@ -280,7 +279,7 @@ def cb_llvm(dur):
280279
args, return_type = sigutils.normalize_signature(sig)
281280

282281
try:
283-
self._compiler._check_queue_equivalence_of_args(
282+
self._compiler.check_queue_equivalence_of_args(
284283
self._kernel_name, args
285284
)
286285
except ExecutionQueueInferenceError as eqie:
@@ -294,11 +293,11 @@ def cb_llvm(dur):
294293
# FIXME: Enable caching
295294
# Add code to enable on disk caching of a binary spirv kernel
296295
self._cache_misses[sig] += 1
297-
ev_details = dict(
298-
dispatcher=self,
299-
args=args,
300-
return_type=return_type,
301-
)
296+
ev_details = {
297+
"dispatcher": self,
298+
"args": args,
299+
"return_type": return_type,
300+
}
302301
with ev.trigger_event("numba_dpex:compile", data=ev_details):
303302
try:
304303
kcres: _KernelCompileResult = self._compiler.compile(
@@ -312,7 +311,9 @@ def folded(args, kws):
312311
)[1]
313312

314313
raise e.bind_fold_arguments(folded)
315-
self.add_overload(kcres.cres_or_error, kcres.kernel_module)
314+
self.add_bitcode_overload(
315+
kcres.cres_or_error, kcres.kernel_module
316+
)
316317

317318
# FIXME: enable caching
318319

0 commit comments

Comments
 (0)