|
5 | 5 | """Implements a new numba dispatcher class and a compiler class to compile and
|
6 | 6 | call numba_dpex.kernel decorated function.
|
7 | 7 | """
|
8 |
| -import functools |
9 |
| -from collections import Counter, OrderedDict, namedtuple |
| 8 | +from collections import namedtuple |
10 | 9 | from contextlib import ExitStack
|
11 | 10 |
|
12 | 11 | import numba.core.event as ev
|
13 |
| -from numba.core import errors, sigutils, types, utils |
14 |
| -from numba.core.caching import NullCache |
| 12 | +from numba.core import errors, sigutils, types |
15 | 13 | from numba.core.compiler import CompileResult
|
16 | 14 | from numba.core.compiler_lock import global_compiler_lock
|
17 |
| -from numba.core.dispatcher import Dispatcher, _DispatcherBase, _FunctionCompiler |
| 15 | +from numba.core.dispatcher import Dispatcher, _FunctionCompiler |
18 | 16 | from numba.core.typing.typeof import Purpose, typeof
|
19 | 17 |
|
20 | 18 | from numba_dpex import config, spirv_generator
|
@@ -84,12 +82,12 @@ def _compile_to_spirv(
|
84 | 82 |
|
85 | 83 | # makes sure that the spir_func is completely inlined into the
|
86 | 84 | # spir_kernel wrapper
|
87 |
| - kernel_library._optimize_final_module() |
| 85 | + kernel_library.optimize_final_module() |
88 | 86 | # Compiled the LLVM IR to SPIR-V
|
89 | 87 | kernel_spirv_module = spirv_generator.llvm_to_spirv(
|
90 | 88 | kernel_targetctx,
|
91 |
| - kernel_library._final_module, |
92 |
| - kernel_library._final_module.as_bitcode(), |
| 89 | + kernel_library.final_module, |
| 90 | + kernel_library.final_module.as_bitcode(), |
93 | 91 | )
|
94 | 92 | return _KernelModule(
|
95 | 93 | kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
|
@@ -158,7 +156,7 @@ def _compile_cached(
|
158 | 156 | "w",
|
159 | 157 | encoding="UTF-8",
|
160 | 158 | ) as f:
|
161 |
| - f.write(kernel_cres.library._final_module) |
| 159 | + f.write(kernel_cres.library.final_module) |
162 | 160 |
|
163 | 161 | except errors.TypingError as e:
|
164 | 162 | self._failed_cache[key] = e
|
@@ -187,57 +185,29 @@ class KernelDispatcher(Dispatcher):
|
187 | 185 | def __init__(
|
188 | 186 | self,
|
189 | 187 | pyfunc,
|
190 |
| - debug_flags=None, |
191 |
| - compile_flags=None, |
192 |
| - specialization_sigs=None, |
193 |
| - enable_cache=True, |
194 |
| - locals={}, |
195 |
| - targetoptions={}, |
196 |
| - impl_kind="kernel", |
| 188 | + local_vars_to_numba_types=None, |
| 189 | + targetoptions=None, |
197 | 190 | pipeline_class=kernel_compiler.KernelCompiler,
|
198 | 191 | ):
|
| 192 | + if targetoptions is None: |
| 193 | + targetoptions = {} |
| 194 | + |
| 195 | + if local_vars_to_numba_types is None: |
| 196 | + local_vars_to_numba_types = {} |
| 197 | + |
199 | 198 | targetoptions["nopython"] = True
|
200 | 199 | targetoptions["experimental"] = True
|
201 | 200 |
|
202 | 201 | self._kernel_name = pyfunc.__name__
|
203 |
| - self.typingctx = self.targetdescr.typing_context |
204 |
| - self.targetctx = self.targetdescr.target_context |
205 |
| - |
206 |
| - pysig = utils.pysignature(pyfunc) |
207 |
| - arg_count = len(pysig.parameters) |
208 |
| - |
209 |
| - self.overloads = OrderedDict() |
210 |
| - |
211 |
| - can_fallback = not targetoptions.get("nopython", False) |
212 | 202 |
|
213 |
| - _DispatcherBase.__init__( |
214 |
| - self, |
215 |
| - arg_count, |
216 |
| - pyfunc, |
217 |
| - pysig, |
218 |
| - can_fallback, |
219 |
| - exact_match_required=False, |
| 203 | + super().__init__( |
| 204 | + py_func=pyfunc, |
| 205 | + locals=local_vars_to_numba_types, |
| 206 | + impl_kind="kernel", |
| 207 | + targetoptions=targetoptions, |
| 208 | + pipeline_class=pipeline_class, |
220 | 209 | )
|
221 | 210 |
|
222 |
| - functools.update_wrapper(self, pyfunc) |
223 |
| - |
224 |
| - self.targetoptions = targetoptions |
225 |
| - self.locals = locals |
226 |
| - self._cache = NullCache() |
227 |
| - compiler_class = self._impl_kinds[impl_kind] |
228 |
| - self._impl_kind = impl_kind |
229 |
| - self._compiler: _KernelCompiler = compiler_class( |
230 |
| - pyfunc, self.targetdescr, targetoptions, locals, pipeline_class |
231 |
| - ) |
232 |
| - self._cache_hits = Counter() |
233 |
| - self._cache_misses = Counter() |
234 |
| - |
235 |
| - self._type = types.Dispatcher(self) |
236 |
| - self.typingctx.insert_global(self, self._type) |
237 |
| - |
238 |
| - # Remember target restriction |
239 |
| - self._required_target_backend = targetoptions.get("target_backend") |
240 |
| - |
241 | 211 | def typeof_pyval(self, val):
|
242 | 212 | """
|
243 | 213 | Resolve the Numba type of Python value *val*.
|
|
0 commit comments