Skip to content

Commit f970584

Browse files
authored
Merge pull request #1 from mortacious/feature/numba_boxing
Feature/numba boxing
2 parents c465ce4 + 76bf7af commit f970584

File tree

1 file changed

+85
-54
lines changed

1 file changed

+85
-54
lines changed

numba_progress/progress.py

Lines changed: 85 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,25 @@
44
from tqdm import tqdm
55
from threading import Thread, Event
66

7-
from numba.core import types
8-
from numba.experimental import structref
9-
from numba.extending import overload_method
7+
from numba.extending import overload_method, typeof_impl, as_numba_type, models, register_model, \
8+
make_attribute_wrapper, overload_attribute, unbox, NativeValue, box
109
from .numba_atomic import atomic_add
11-
10+
from numba import types
11+
from numba.core import cgutils
12+
from numba.core.boxing import unbox_array
1213

1314
__all__ = ['ProgressBar']
1415

1516

16-
@structref.register
17-
class ProgressProxyType(types.StructRef):
18-
def preprocess_fields(self, fields):
19-
# We don't want the struct to take Literal types.
20-
return tuple((name, types.unliteral(typ)) for name, typ in fields)
21-
22-
23-
class _ProgressProxy(structref.StructRefProxy):
24-
def __new__(cls, hook=None):
25-
hook = np.zeros(1, dtype=np.uint64)
26-
return structref.StructRefProxy.__new__(cls, hook)
27-
28-
def update(self, n=1):
29-
return _ProgressProxy_update(self, n)
30-
31-
@property
32-
def value(self):
33-
return _ProgressProxy_value(self)
34-
35-
36-
@nb.njit()
37-
def _ProgressProxy_update(self, n=1):
38-
return self.update(n)
39-
40-
41-
@nb.njit()
42-
def _ProgressProxy_value(self):
43-
return self.hook[0]
44-
45-
46-
structref.define_proxy(_ProgressProxy, ProgressProxyType,
47-
["hook"])
48-
49-
50-
@overload_method(ProgressProxyType, "update", jit_options={"nogil": True})
51-
def _ol_update(self, n=1):
52-
def _update_impl(self, n=1):
53-
atomic_add(self.hook, 0, n)
54-
return _update_impl
55-
56-
5717
class ProgressBar(object):
5818
"""
5919
Wraps the tqdm progress bar enabling it to be updated from within a numba nopython function.
6020
It works by spawning a separate thread that updates the tqdm progress bar based on an atomic counter which can be
6121
accessed within the numba function. The progress bar works with parallel as well as sequential numba functions.
22+
23+
Note: As this Class contains python objects not useable or convertable into numba, it will be boxed as a
24+
proxy object, that only exposes the minimum subset of functionality to update the progress bar. Attempts
25+
to return or create a ProgressBar within a numba function will result in an error.
6226
6327
Parameters
6428
----------
@@ -79,7 +43,7 @@ def __init__(self, file=None, update_interval=0.1, **kwargs):
7943
file = sys.stdout
8044
self._last_value = 0
8145
self._tqdm = tqdm(iterable=None, file=file, **kwargs)
82-
self._numba_proxy = _ProgressProxy()
46+
self.hook = np.zeros(1, dtype=np.uint64)
8347
self._updater_thread = None
8448
self._exit_event = Event()
8549
self.update_interval = update_interval
@@ -97,18 +61,15 @@ def close(self):
9761
self._tqdm.close()
9862

9963
@property
100-
def numba_proxy(self):
101-
"""
102-
Returns the proxy object that can be used from within a numba function.
103-
"""
104-
return self._numba_proxy
64+
def value(self):
65+
return self.hook[0]
10566

10667
def update(self, n=1):
107-
self._numba_proxy.update(n)
68+
atomic_add(self.hook, 0, n)
10869
self._update_tqdm()
10970

11071
def _update_tqdm(self):
111-
value = self._numba_proxy.value
72+
value = self.value
11273
diff = value - self._last_value
11374
self._last_value = value
11475
self._tqdm.update(diff)
@@ -121,7 +82,77 @@ def _update_function(self):
12182
self._exit_event.wait(self.update_interval)
12283

12384
def __enter__(self):
124-
return self._numba_proxy
85+
return self
12586

12687
def __exit__(self, exc_type, exc_val, exc_tb):
12788
self.close()
89+
90+
91+
# Numba Native Implementation for the ProgressBar Class
92+
93+
class ProgressBarType(types.Type):
94+
def __init__(self):
95+
super().__init__(name='ProgressBar')
96+
97+
98+
progressbar_type = ProgressBarType()
99+
100+
101+
@typeof_impl.register(ProgressBar)
102+
def typeof_index(val, c):
103+
return progressbar_type
104+
105+
106+
as_numba_type.register(ProgressBar, progressbar_type)
107+
108+
109+
@register_model(ProgressBarType)
110+
class ProgressBarModel(models.StructModel):
111+
def __init__(self, dmm, fe_type):
112+
members = [
113+
('hook', types.Array(types.uint64, 1, 'C')),
114+
]
115+
models.StructModel.__init__(self, dmm, fe_type, members)
116+
117+
118+
# make the hook attribute accessible
119+
make_attribute_wrapper(ProgressBarType, 'hook', 'hook')
120+
121+
122+
@overload_attribute(ProgressBarType, 'value')
123+
def get_value(progress_bar):
124+
def getter(progress_bar):
125+
return progress_bar.hook[0]
126+
return getter
127+
128+
129+
@unbox(ProgressBarType)
130+
def unbox_progressbar(typ, obj, c):
131+
"""
132+
Convert a ProgressBar to it's native representation (proxy object)
133+
"""
134+
hook_obj = c.pyapi.object_getattr_string(obj, 'hook')
135+
progress_bar = cgutils.create_struct_proxy(typ)(c.context, c.builder)
136+
progress_bar.hook = unbox_array(types.Array(types.uint64, 1, 'C'), hook_obj, c).value
137+
c.pyapi.decref(hook_obj)
138+
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
139+
return NativeValue(progress_bar._getvalue(), is_error=is_error)
140+
141+
142+
@box(ProgressBarType)
143+
def box_progressbar(typ, val, c):
144+
raise TypeError("Native representation of ProgressBar cannot be converted back to a python object "
145+
"as it contains internal python state.")
146+
147+
148+
@overload_method(ProgressBarType, "update", jit_options={"nogil": True})
149+
def _ol_update(self, n=1):
150+
"""
151+
Numpy implementation of the update method.
152+
"""
153+
if isinstance(self, ProgressBarType):
154+
def _update_impl(self, n=1):
155+
atomic_add(self.hook, 0, n)
156+
return _update_impl
157+
158+

0 commit comments

Comments
 (0)