@@ -217,16 +217,12 @@ def __init__(
217217
218218 # build the update step
219219 if self .jit ['predict' ]:
220- _loop_func = bm .make_loop (
221- self ._step ,
222- dyn_vars = self .dyn_vars ,
223- out_vars = {k : self .variables [k ] for k in self .monitors .keys ()},
224- has_return = True
225- )
220+ def _loop_func (times ):
221+ return bm .for_loop (self ._step , self .dyn_vars , times )
226222 else :
227223 def _loop_func (times ):
228- out_vars = {k : [] for k in self .monitors .keys ()}
229224 returns = {k : [] for k in self .fun_monitors .keys ()}
225+ returns .update ({k : [] for k in self .monitors .keys ()})
230226 for i in range (len (times )):
231227 _t = times [i ]
232228 _dt = self .dt
@@ -237,9 +233,9 @@ def _loop_func(times):
237233 self ._step (_t )
238234 # variable monitors
239235 for k in self .monitors .keys ():
240- out_vars [k ].append (bm .as_device_array (self .variables [k ]))
241- out_vars = {k : bm .asarray (out_vars [k ]) for k in self . monitors .keys ()}
242- return out_vars , returns
236+ returns [k ].append (bm .as_device_array (self .variables [k ]))
237+ returns = {k : bm .asarray (returns [k ]) for k in returns .keys ()}
238+ return returns
243239 self .step_func = _loop_func
244240
245241 def _step (self , t ):
@@ -252,11 +248,6 @@ def _step(self, t):
252248 kwargs .update ({k : v [self .idx .value ] for k , v in self ._dyn_args .items ()})
253249 self .idx += 1
254250
255- # return of function monitors
256- returns = dict ()
257- for key , func in self .fun_monitors .items ():
258- returns [key ] = func (t , self .dt )
259-
260251 # call integrator function
261252 update_values = self .target (** kwargs )
262253 if len (self .target .variables ) == 1 :
@@ -268,6 +259,13 @@ def _step(self, t):
268259 # progress bar
269260 if self .progress_bar :
270261 id_tap (lambda * args : self ._pbar .update (), ())
262+
263+ # return of function monitors
264+ returns = dict ()
265+ for key , func in self .fun_monitors .items ():
266+ returns [key ] = func (t , self .dt )
267+ for k in self .monitors .keys ():
268+ returns [k ] = self .variables [k ].value
271269 return returns
272270
273271 def run (self , duration , start_t = None , eval_time = False ):
@@ -302,14 +300,13 @@ def run(self, duration, start_t=None, eval_time=False):
302300 refresh = True )
303301 if eval_time :
304302 t0 = time .time ()
305- hists , returns = self .step_func (times )
303+ hists = self .step_func (times )
306304 if eval_time :
307305 running_time = time .time () - t0
308306 if self .progress_bar :
309307 self ._pbar .close ()
310308
311309 # post-running
312- hists .update (returns )
313310 times += self .dt
314311 if self .numpy_mon_after_run :
315312 times = np .asarray (times )
0 commit comments