Skip to content

Commit 600a41f

Browse files
author
Diptorup Deb
authored
Merge pull request #1206 from IntelPython/fix/pylint_errors
Fix Pylint issues in the numba_dpex.experimental module.
2 parents 16d9c0a + 7f87223 commit 600a41f

File tree

3 files changed

+33
-59
lines changed

3 files changed

+33
-59
lines changed

numba_dpex/core/codegen.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ def _optimize_final_module(self):
5757
pmb.populate(pm)
5858
pm.run(self._final_module)
5959

60+
def optimize_final_module(self):
61+
"""Public member function to optimize the final LLVM module in the
62+
library. The function calls the protected overridden function.
63+
"""
64+
self._optimize_final_module()
65+
6066
def _finalize_specific(self):
6167
# Fix global naming
6268
for gv in self._final_module.global_variables:
@@ -68,6 +74,10 @@ def get_asm_str(self):
6874
# generated (in numba_dpex.compiler).
6975
return None
7076

77+
@property
78+
def final_module(self):
79+
return self._final_module
80+
7181

7282
class JITSPIRVCodegen(CPUCodegen):
7383
"""

numba_dpex/experimental/decorators.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .kernel_dispatcher import KernelDispatcher
1414

1515

16-
def kernel(func_or_sig=None, debug=False, cache=False, **options):
16+
def kernel(func_or_sig=None, **options):
1717
"""A decorator to define a kernel function.
1818
1919
A kernel function is conceptually equivalent to a SYCL kernel function, and
@@ -27,12 +27,9 @@ def kernel(func_or_sig=None, debug=False, cache=False, **options):
2727
# FIXME: The options need to be evaluated and checked here like it is
2828
# done in numba.core.decorators.jit
2929

30-
def _kernel_dispatcher(pyfunc, sigs=None):
30+
def _kernel_dispatcher(pyfunc):
3131
return KernelDispatcher(
3232
pyfunc=pyfunc,
33-
debug_flags=debug,
34-
enable_cache=cache,
35-
specialization_sigs=sigs,
3633
targetoptions=options,
3734
)
3835

@@ -64,9 +61,6 @@ def _kernel_dispatcher(pyfunc, sigs=None):
6461
def _specialized_kernel_dispatcher(pyfunc):
6562
return KernelDispatcher(
6663
pyfunc=pyfunc,
67-
debug_flags=debug,
68-
enable_cache=cache,
69-
specialization_sigs=func_or_sig,
7064
)
7165

7266
return _specialized_kernel_dispatcher

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 21 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
"""Implements a new numba dispatcher class and a compiler class to compile and
66
call numba_dpex.kernel decorated function.
77
"""
8-
import functools
9-
from collections import Counter, OrderedDict, namedtuple
8+
from collections import namedtuple
109
from contextlib import ExitStack
1110

1211
import numba.core.event as ev
13-
from numba.core import errors, sigutils, types, utils
14-
from numba.core.caching import NullCache
12+
from numba.core import errors, sigutils, types
1513
from numba.core.compiler import CompileResult
1614
from numba.core.compiler_lock import global_compiler_lock
17-
from numba.core.dispatcher import Dispatcher, _DispatcherBase, _FunctionCompiler
15+
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
1816
from numba.core.typing.typeof import Purpose, typeof
1917

2018
from numba_dpex import config, spirv_generator
@@ -84,12 +82,12 @@ def _compile_to_spirv(
8482

8583
# makes sure that the spir_func is completely inlined into the
8684
# spir_kernel wrapper
87-
kernel_library._optimize_final_module()
85+
kernel_library.optimize_final_module()
8886
# Compiled the LLVM IR to SPIR-V
8987
kernel_spirv_module = spirv_generator.llvm_to_spirv(
9088
kernel_targetctx,
91-
kernel_library._final_module,
92-
kernel_library._final_module.as_bitcode(),
89+
kernel_library.final_module,
90+
kernel_library.final_module.as_bitcode(),
9391
)
9492
return _KernelModule(
9593
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
@@ -158,7 +156,7 @@ def _compile_cached(
158156
"w",
159157
encoding="UTF-8",
160158
) as f:
161-
f.write(kernel_cres.library._final_module)
159+
f.write(kernel_cres.library.final_module)
162160

163161
except errors.TypingError as e:
164162
self._failed_cache[key] = e
@@ -187,57 +185,29 @@ class KernelDispatcher(Dispatcher):
187185
def __init__(
188186
self,
189187
pyfunc,
190-
debug_flags=None,
191-
compile_flags=None,
192-
specialization_sigs=None,
193-
enable_cache=True,
194-
locals={},
195-
targetoptions={},
196-
impl_kind="kernel",
188+
local_vars_to_numba_types=None,
189+
targetoptions=None,
197190
pipeline_class=kernel_compiler.KernelCompiler,
198191
):
192+
if targetoptions is None:
193+
targetoptions = {}
194+
195+
if local_vars_to_numba_types is None:
196+
local_vars_to_numba_types = {}
197+
199198
targetoptions["nopython"] = True
200199
targetoptions["experimental"] = True
201200

202201
self._kernel_name = pyfunc.__name__
203-
self.typingctx = self.targetdescr.typing_context
204-
self.targetctx = self.targetdescr.target_context
205-
206-
pysig = utils.pysignature(pyfunc)
207-
arg_count = len(pysig.parameters)
208-
209-
self.overloads = OrderedDict()
210-
211-
can_fallback = not targetoptions.get("nopython", False)
212202

213-
_DispatcherBase.__init__(
214-
self,
215-
arg_count,
216-
pyfunc,
217-
pysig,
218-
can_fallback,
219-
exact_match_required=False,
203+
super().__init__(
204+
py_func=pyfunc,
205+
locals=local_vars_to_numba_types,
206+
impl_kind="kernel",
207+
targetoptions=targetoptions,
208+
pipeline_class=pipeline_class,
220209
)
221210

222-
functools.update_wrapper(self, pyfunc)
223-
224-
self.targetoptions = targetoptions
225-
self.locals = locals
226-
self._cache = NullCache()
227-
compiler_class = self._impl_kinds[impl_kind]
228-
self._impl_kind = impl_kind
229-
self._compiler: _KernelCompiler = compiler_class(
230-
pyfunc, self.targetdescr, targetoptions, locals, pipeline_class
231-
)
232-
self._cache_hits = Counter()
233-
self._cache_misses = Counter()
234-
235-
self._type = types.Dispatcher(self)
236-
self.typingctx.insert_global(self, self._type)
237-
238-
# Remember target restriction
239-
self._required_target_backend = targetoptions.get("target_backend")
240-
241211
def typeof_pyval(self, val):
242212
"""
243213
Resolve the Numba type of Python value *val*.

0 commit comments

Comments
 (0)