44from tqdm import tqdm
55from threading import Thread , Event
66
7- from numba .core import types
8- from numba .experimental import structref
9- from numba .extending import overload_method
7+ from numba .extending import overload_method , typeof_impl , as_numba_type , models , register_model , \
8+ make_attribute_wrapper , overload_attribute , unbox , NativeValue , box
109from .numba_atomic import atomic_add
11-
10+ from numba import types
11+ from numba .core import cgutils
12+ from numba .core .boxing import unbox_array
1213
1314__all__ = ['ProgressBar' ]
1415
1516
16- @structref .register
17- class ProgressProxyType (types .StructRef ):
18- def preprocess_fields (self , fields ):
19- # We don't want the struct to take Literal types.
20- return tuple ((name , types .unliteral (typ )) for name , typ in fields )
21-
22-
23- class _ProgressProxy (structref .StructRefProxy ):
24- def __new__ (cls , hook = None ):
25- hook = np .zeros (1 , dtype = np .uint64 )
26- return structref .StructRefProxy .__new__ (cls , hook )
27-
28- def update (self , n = 1 ):
29- return _ProgressProxy_update (self , n )
30-
31- @property
32- def value (self ):
33- return _ProgressProxy_value (self )
34-
35-
36- @nb .njit ()
37- def _ProgressProxy_update (self , n = 1 ):
38- return self .update (n )
39-
40-
41- @nb .njit ()
42- def _ProgressProxy_value (self ):
43- return self .hook [0 ]
44-
45-
46- structref .define_proxy (_ProgressProxy , ProgressProxyType ,
47- ["hook" ])
48-
49-
50- @overload_method (ProgressProxyType , "update" , jit_options = {"nogil" : True })
51- def _ol_update (self , n = 1 ):
52- def _update_impl (self , n = 1 ):
53- atomic_add (self .hook , 0 , n )
54- return _update_impl
55-
56-
5717class ProgressBar (object ):
5818 """
5919 Wraps the tqdm progress bar enabling it to be updated from within a numba nopython function.
6020 It works by spawning a separate thread that updates the tqdm progress bar based on an atomic counter which can be
6121 accessed within the numba function. The progress bar works with parallel as well as sequential numba functions.
22+
23+ Note: As this Class contains python objects not useable or convertable into numba, it will be boxed as a
24+ proxy object, that only exposes the minimum subset of functionality to update the progress bar. Attempts
25+ to return or create a ProgressBar within a numba function will result in an error.
6226
6327 Parameters
6428 ----------
@@ -79,7 +43,7 @@ def __init__(self, file=None, update_interval=0.1, **kwargs):
7943 file = sys .stdout
8044 self ._last_value = 0
8145 self ._tqdm = tqdm (iterable = None , file = file , ** kwargs )
82- self ._numba_proxy = _ProgressProxy ( )
46+ self .hook = np . zeros ( 1 , dtype = np . uint64 )
8347 self ._updater_thread = None
8448 self ._exit_event = Event ()
8549 self .update_interval = update_interval
@@ -97,18 +61,15 @@ def close(self):
9761 self ._tqdm .close ()
9862
9963 @property
100- def numba_proxy (self ):
101- """
102- Returns the proxy object that can be used from within a numba function.
103- """
104- return self ._numba_proxy
64+ def value (self ):
65+ return self .hook [0 ]
10566
10667 def update (self , n = 1 ):
107- self ._numba_proxy . update ( n )
68+ atomic_add ( self .hook , 0 , n )
10869 self ._update_tqdm ()
10970
11071 def _update_tqdm (self ):
111- value = self ._numba_proxy . value
72+ value = self .value
11273 diff = value - self ._last_value
11374 self ._last_value = value
11475 self ._tqdm .update (diff )
@@ -121,7 +82,77 @@ def _update_function(self):
12182 self ._exit_event .wait (self .update_interval )
12283
12384 def __enter__ (self ):
124- return self . _numba_proxy
85+ return self
12586
12687 def __exit__ (self , exc_type , exc_val , exc_tb ):
12788 self .close ()
89+
90+
91+ # Numba Native Implementation for the ProgressBar Class
92+
93+ class ProgressBarType (types .Type ):
94+ def __init__ (self ):
95+ super ().__init__ (name = 'ProgressBar' )
96+
97+
98+ progressbar_type = ProgressBarType ()
99+
100+
101+ @typeof_impl .register (ProgressBar )
102+ def typeof_index (val , c ):
103+ return progressbar_type
104+
105+
106+ as_numba_type .register (ProgressBar , progressbar_type )
107+
108+
109+ @register_model (ProgressBarType )
110+ class ProgressBarModel (models .StructModel ):
111+ def __init__ (self , dmm , fe_type ):
112+ members = [
113+ ('hook' , types .Array (types .uint64 , 1 , 'C' )),
114+ ]
115+ models .StructModel .__init__ (self , dmm , fe_type , members )
116+
117+
118+ # make the hook attribute accessible
119+ make_attribute_wrapper (ProgressBarType , 'hook' , 'hook' )
120+
121+
122+ @overload_attribute (ProgressBarType , 'value' )
123+ def get_value (progress_bar ):
124+ def getter (progress_bar ):
125+ return progress_bar .hook [0 ]
126+ return getter
127+
128+
129+ @unbox (ProgressBarType )
130+ def unbox_progressbar (typ , obj , c ):
131+ """
132+ Convert a ProgressBar to it's native representation (proxy object)
133+ """
134+ hook_obj = c .pyapi .object_getattr_string (obj , 'hook' )
135+ progress_bar = cgutils .create_struct_proxy (typ )(c .context , c .builder )
136+ progress_bar .hook = unbox_array (types .Array (types .uint64 , 1 , 'C' ), hook_obj , c ).value
137+ c .pyapi .decref (hook_obj )
138+ is_error = cgutils .is_not_null (c .builder , c .pyapi .err_occurred ())
139+ return NativeValue (progress_bar ._getvalue (), is_error = is_error )
140+
141+
142+ @box (ProgressBarType )
143+ def box_progressbar (typ , val , c ):
144+ raise TypeError ("Native representation of ProgressBar cannot be converted back to a python object "
145+ "as it contains internal python state." )
146+
147+
148+ @overload_method (ProgressBarType , "update" , jit_options = {"nogil" : True })
149+ def _ol_update (self , n = 1 ):
150+ """
151+ Numpy implementation of the update method.
152+ """
153+ if isinstance (self , ProgressBarType ):
154+ def _update_impl (self , n = 1 ):
155+ atomic_add (self .hook , 0 , n )
156+ return _update_impl
157+
158+
0 commit comments