1313 from jax .core import UnexpectedTracerError
1414
1515from brainpy import errors
16+ from brainpy .base .naming import get_unique_name
1617from brainpy .math .jaxarray import (JaxArray , Variable ,
17- turn_on_global_jit ,
18- turn_off_global_jit )
18+ add_context ,
19+ del_context )
1920from brainpy .math .numpy_ops import as_device_array
2021
2122__all__ = [
@@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
158159 out_vars = out_vars ,
159160 has_return = has_return )
160161
162+ name = get_unique_name ('_brainpy_object_oriented_make_loop_' )
163+
161164 # functions
162165 if has_return :
163166 def call (xs = None , length = None ):
164167 init_values = [v .value for v in dyn_vars ]
165168 try :
166- turn_on_global_jit ( )
169+ add_context ( name )
167170 dyn_values , (out_values , results ) = lax .scan (
168171 f = fun2scan , init = init_values , xs = xs , length = length )
169- turn_off_global_jit ( )
172+ del_context ( name )
170173 except UnexpectedTracerError as e :
171- turn_off_global_jit ( )
174+ del_context ( name )
172175 for v , d in zip (dyn_vars , init_values ): v ._value = d
173176 raise errors .JaxTracerError (variables = dyn_vars ) from e
174177 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
@@ -178,15 +181,15 @@ def call(xs=None, length=None):
178181 def call (xs ):
179182 init_values = [v .value for v in dyn_vars ]
180183 try :
181- turn_on_global_jit ( )
184+ add_context ( name )
182185 dyn_values , out_values = lax .scan (f = fun2scan , init = init_values , xs = xs )
183- turn_off_global_jit ( )
186+ del_context ( name )
184187 except UnexpectedTracerError as e :
185- turn_off_global_jit ( )
188+ del_context ( name )
186189 for v , d in zip (dyn_vars , init_values ): v ._value = d
187190 raise errors .JaxTracerError (variables = dyn_vars ) from e
188191 except Exception as e :
189- turn_off_global_jit ( )
192+ del_context ( name )
190193 for v , d in zip (dyn_vars , init_values ): v ._value = d
191194 raise e
192195 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
@@ -255,20 +258,22 @@ def _cond_fun(op):
255258 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
256259 return as_device_array (cond_fun (static_values ))
257260
261+ name = get_unique_name ('_brainpy_object_oriented_make_while_' )
262+
258263 def call (x = None ):
259264 dyn_init = [v .value for v in dyn_vars ]
260265 try :
261- turn_on_global_jit ( )
266+ add_context ( name )
262267 dyn_values , _ = lax .while_loop (cond_fun = _cond_fun ,
263268 body_fun = _body_fun ,
264269 init_val = (dyn_init , x ))
265- turn_off_global_jit ( )
270+ del_context ( name )
266271 except UnexpectedTracerError as e :
267- turn_off_global_jit ( )
272+ del_context ( name )
268273 for v , d in zip (dyn_vars , dyn_init ): v ._value = d
269274 raise errors .JaxTracerError (variables = dyn_vars ) from e
270275 except Exception as e :
271- turn_off_global_jit ( )
276+ del_context ( name )
272277 for v , d in zip (dyn_vars , dyn_init ): v ._value = d
273278 raise e
274279 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
@@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
330335 if not isinstance (v , JaxArray ):
331336 raise ValueError (f'Only support { JaxArray .__name__ } , but got { type (v )} ' )
332337
338+ name = get_unique_name ('_brainpy_object_oriented_make_cond_' )
339+
333340 if len (dyn_vars ) > 0 :
334341 def _true_fun (op ):
335342 dyn_vals , static_vals = op
@@ -348,25 +355,25 @@ def _false_fun(op):
348355 def call (pred , x = None ):
349356 old_values = [v .value for v in dyn_vars ]
350357 try :
351- turn_on_global_jit ( )
358+ add_context ( name )
352359 dyn_values , res = lax .cond (pred , _true_fun , _false_fun , (old_values , x ))
353- turn_off_global_jit ( )
360+ del_context ( name )
354361 except UnexpectedTracerError as e :
355- turn_off_global_jit ( )
362+ del_context ( name )
356363 for v , d in zip (dyn_vars , old_values ): v ._value = d
357364 raise errors .JaxTracerError (variables = dyn_vars ) from e
358365 except Exception as e :
359- turn_off_global_jit ( )
366+ del_context ( name )
360367 for v , d in zip (dyn_vars , old_values ): v ._value = d
361368 raise e
362369 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
363370 return res
364371
365372 else :
366373 def call (pred , x = None ):
367- turn_on_global_jit ( )
374+ add_context ( name )
368375 res = lax .cond (pred , true_fun , false_fun , x )
369- turn_off_global_jit ( )
376+ del_context ( name )
370377 return res
371378
372379 return call
@@ -445,6 +452,8 @@ def cond(
445452 if not isinstance (v , Variable ):
446453 raise ValueError (f'Only support { Variable .__name__ } , but got { type (v )} ' )
447454
455+ name = get_unique_name ('_brainpy_object_oriented_cond_' )
456+
448457 # calling the model
449458 if len (dyn_vars ) > 0 :
450459 def _true_fun (op ):
@@ -463,25 +472,25 @@ def _false_fun(op):
463472
464473 old_values = [v .value for v in dyn_vars ]
465474 try :
466- turn_on_global_jit ( )
475+ add_context ( name )
467476 dyn_values , res = lax .cond (pred = pred ,
468477 true_fun = _true_fun ,
469478 false_fun = _false_fun ,
470479 operand = (old_values , operands ))
471- turn_off_global_jit ( )
480+ del_context ( name )
472481 except UnexpectedTracerError as e :
473- turn_off_global_jit ( )
482+ del_context ( name )
474483 for v , d in zip (dyn_vars , old_values ): v ._value = d
475484 raise errors .JaxTracerError (variables = dyn_vars ) from e
476485 except Exception as e :
477- turn_off_global_jit ( )
486+ del_context ( name )
478487 for v , d in zip (dyn_vars , old_values ): v ._value = d
479488 raise e
480489 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
481490 else :
482- turn_on_global_jit ( )
491+ add_context ( name )
483492 res = lax .cond (pred , true_fun , false_fun , operands )
484- turn_off_global_jit ( )
493+ del_context ( name )
485494 return res
486495
487496
@@ -591,7 +600,11 @@ def ifelse(
591600 if show_code : print (codes )
592601 exec (compile (codes .strip (), '' , 'exec' ), code_scope )
593602 f = code_scope ['f' ]
594- return f (operands )
603+ name = get_unique_name ('_brainpy_object_oriented_ifelse_' )
604+ add_context (name )
605+ r = f (operands )
606+ del_context (name )
607+ return r
595608
596609
597610def for_loop (body_fun : Callable ,
@@ -694,22 +707,24 @@ def fun2scan(dyn_vals, x):
694707 results = body_fun (* x )
695708 return [v .value for v in dyn_vars ], results
696709
710+ name = get_unique_name ('_brainpy_object_oriented_for_loop_' )
711+
697712 # functions
698713 init_vals = [v .value for v in dyn_vars ]
699714 try :
700- turn_on_global_jit ( )
715+ add_context ( name )
701716 dyn_vals , out_vals = lax .scan (f = fun2scan ,
702717 init = init_vals ,
703718 xs = operands ,
704719 reverse = reverse ,
705720 unroll = unroll )
706- turn_off_global_jit ( )
721+ del_context ( name )
707722 except UnexpectedTracerError as e :
708- turn_off_global_jit ( )
723+ del_context ( name )
709724 for v , d in zip (dyn_vars , init_vals ): v ._value = d
710725 raise errors .JaxTracerError (variables = dyn_vars ) from e
711726 except Exception as e :
712- turn_off_global_jit ( )
727+ del_context ( name )
713728 for v , d in zip (dyn_vars , init_vals ): v ._value = d
714729 raise e
715730 for v , d in zip (dyn_vars , dyn_vals ): v ._value = d
@@ -797,19 +812,20 @@ def _cond_fun(op):
797812 r = cond_fun (* static_vals )
798813 return r if isinstance (r , JaxArray ) else r
799814
815+ name = get_unique_name ('_brainpy_object_oriented_while_loop_' )
800816 dyn_init = [v .value for v in dyn_vars ]
801817 try :
802- turn_on_global_jit ( )
818+ add_context ( name )
803819 dyn_values , out = lax .while_loop (cond_fun = _cond_fun ,
804820 body_fun = _body_fun ,
805821 init_val = (dyn_init , operands ))
806- turn_off_global_jit ( )
822+ del_context ( name )
807823 except UnexpectedTracerError as e :
808- turn_off_global_jit ( )
824+ del_context ( name )
809825 for v , d in zip (dyn_vars , dyn_init ): v ._value = d
810826 raise errors .JaxTracerError (variables = dyn_vars ) from e
811827 except Exception as e :
812- turn_off_global_jit ( )
828+ del_context ( name )
813829 for v , d in zip (dyn_vars , dyn_init ): v ._value = d
814830 raise e
815831 for v , d in zip (dyn_vars , dyn_values ): v ._value = d
0 commit comments