Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit d07f9af

Browse files
committed
Patch for with context (#96)
This modifications make jit() decorator use TargetDispatcher from dppl. Changes made in #57 by @AlexanderKalistratov and @1e-to.
1 parent cd2896c commit d07f9af

File tree

7 files changed

+67
-23
lines changed

7 files changed

+67
-23
lines changed

numba/core/decorators.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def bar(x, y):
149149
target = options.pop('target')
150150
warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning)
151151
else:
152-
target = options.pop('_target', 'cpu')
152+
target = options.pop('_target', None)
153153

154154
options['boundscheck'] = boundscheck
155155

@@ -183,27 +183,16 @@ def bar(x, y):
183183

184184

185185
def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args):
186-
dispatcher = registry.dispatcher_registry[target]
187-
188-
def wrapper(func):
189-
if extending.is_jitted(func):
190-
raise TypeError(
191-
"A jit decorator was called on an already jitted function "
192-
f"{func}. If trying to access the original python "
193-
f"function, use the {func}.py_func attribute."
194-
)
195-
196-
if not inspect.isfunction(func):
197-
raise TypeError(
198-
"The decorated object is not a function (got type "
199-
f"{type(func)})."
200-
)
201186

187+
def wrapper(func, dispatcher):
202188
if config.ENABLE_CUDASIM and target == 'cuda':
203189
from numba import cuda
204190
return cuda.jit(func)
205191
if config.DISABLE_JIT and not target == 'npyufunc':
206192
return func
193+
if target == 'dppl':
194+
from . import dppl
195+
return dppl.jit(func)
207196
disp = dispatcher(py_func=func, locals=locals,
208197
targetoptions=targetoptions,
209198
**dispatcher_args)
@@ -219,7 +208,42 @@ def wrapper(func):
219208
disp.disable_compile()
220209
return disp
221210

222-
return wrapper
211+
def __wrapper(func):
212+
if extending.is_jitted(func):
213+
raise TypeError(
214+
"A jit decorator was called on an already jitted function "
215+
f"{func}. If trying to access the original python "
216+
f"function, use the {func}.py_func attribute."
217+
)
218+
219+
if not inspect.isfunction(func):
220+
raise TypeError(
221+
"The decorated object is not a function (got type "
222+
f"{type(func)})."
223+
)
224+
225+
is_numba_dppy_present = False
226+
try:
227+
import numba_dppy.config as dppy_config
228+
229+
is_numba_dppy_present = dppy_config.dppy_present
230+
except ImportError:
231+
pass
232+
233+
if (not is_numba_dppy_present
234+
or target == 'npyufunc' or targetoptions.get('no_cpython_wrapper')
235+
or sigs or config.DISABLE_JIT or not targetoptions.get('nopython')):
236+
target_ = target
237+
if target_ is None:
238+
target_ = 'cpu'
239+
disp = registry.dispatcher_registry[target_]
240+
return wrapper(func, disp)
241+
242+
from numba_dppy.target_dispatcher import TargetDispatcher
243+
disp = TargetDispatcher(func, wrapper, target, targetoptions.get('parallel'))
244+
return disp
245+
246+
return __wrapper
223247

224248

225249
def generated_jit(function=None, target='cpu', cache=False,

numba/core/dispatcher.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,14 @@ def _set_uuid(self, u):
673673
self._recent.append(self)
674674

675675

676-
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
676+
import abc
677+
678+
class DispatcherMeta(abc.ABCMeta):
679+
def __instancecheck__(self, other):
680+
return type(type(other)) == DispatcherMeta
681+
682+
683+
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta):
677684
"""
678685
Implementation of user-facing dispatcher objects (i.e. created using
679686
the @jit decorator).
@@ -899,6 +906,9 @@ def get_function_type(self):
899906
cres = tuple(self.overloads.values())[0]
900907
return types.FunctionType(cres.signature)
901908

909+
def get_compiled(self):
910+
return self
911+
902912

903913
class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
904914
"""

numba/core/registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from numba.core.descriptors import TargetDescriptor
44
from numba.core import utils, typing, dispatcher, cpu
5+
from numba.core.compiler_lock import global_compiler_lock
56

67
# -----------------------------------------------------------------------------
78
# Default CPU target descriptors
@@ -26,16 +27,19 @@ class CPUTarget(TargetDescriptor):
2627
_nested = _NestedContext()
2728

2829
@utils.cached_property
30+
@global_compiler_lock
2931
def _toplevel_target_context(self):
3032
# Lazily-initialized top-level target context, for all threads
3133
return cpu.CPUContext(self.typing_context)
3234

3335
@utils.cached_property
36+
@global_compiler_lock
3437
def _toplevel_typing_context(self):
3538
# Lazily-initialized top-level typing context, for all threads
3639
return typing.Context()
3740

3841
@property
42+
@global_compiler_lock
3943
def target_context(self):
4044
"""
4145
The target context for CPU targets.
@@ -47,6 +51,7 @@ def target_context(self):
4751
return self._toplevel_target_context
4852

4953
@property
54+
@global_compiler_lock
5055
def typing_context(self):
5156
"""
5257
The typing context for CPU targets.
@@ -57,6 +62,7 @@ def typing_context(self):
5762
else:
5863
return self._toplevel_typing_context
5964

65+
@global_compiler_lock
6066
def nested_context(self, typing_context, target_context):
6167
"""
6268
A context manager temporarily replacing the contexts with the

numba/tests/test_dispatcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ def test_serialization(self):
398398
def foo(x):
399399
return x + 1
400400

401+
foo = foo.get_compiled()
402+
401403
self.assertEqual(foo(1), 2)
402404

403405
# get serialization memo

numba/tests/test_nrt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def alloc_nrt_memory():
249249
"""
250250
return np.empty(N, dtype)
251251

252+
alloc_nrt_memory = alloc_nrt_memory.get_compiled()
253+
252254
def keep_memory():
253255
return alloc_nrt_memory()
254256

numba/tests/test_record_dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,8 +803,8 @@ def test_record_arg_transform(self):
803803
self.assertIn('Array', transformed)
804804
self.assertNotIn('first', transformed)
805805
self.assertNotIn('second', transformed)
806-
# Length is usually 50 - 5 chars tolerance as above.
807-
self.assertLess(len(transformed), 50)
806+
# Length is usually 60 - 5 chars tolerance as above.
807+
self.assertLess(len(transformed), 60)
808808

809809
def test_record_two_arrays(self):
810810
"""

numba/tests/test_serialize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def test_reuse(self):
135135
136136
Note that "same function" is intentionally under-specified.
137137
"""
138-
func = closure(5)
138+
func = closure(5).get_compiled()
139139
pickled = pickle.dumps(func)
140-
func2 = closure(6)
140+
func2 = closure(6).get_compiled()
141141
pickled2 = pickle.dumps(func2)
142142

143143
f = pickle.loads(pickled)
@@ -152,7 +152,7 @@ def test_reuse(self):
152152
self.assertEqual(h(2, 3), 11)
153153

154154
# Now make sure the original object doesn't exist when deserializing
155-
func = closure(7)
155+
func = closure(7).get_compiled()
156156
func(42, 43)
157157
pickled = pickle.dumps(func)
158158
del func

0 commit comments

Comments
 (0)