1313 from jax .core import UnexpectedTracerError
1414
1515from brainpy import errors , tools , check
16- from .base_object import get_unique_name
16+ from .base_object import get_unique_name , BrainPyObject
1717from .collector import ArrayCollector
1818from ..ndarray import (Array , Variable ,
1919 add_context ,
@@ -445,6 +445,7 @@ def cond(
445445 false_fun : Union [Callable , jnp .ndarray , Array , float , int , bool ],
446446 operands : Any ,
447447 dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
448+ child_objs : Optional [Union [BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]]] = None ,
448449):
449450 """Simple conditional statement (if-else) with instance of :py:class:`~.Variable`.
450451
@@ -477,6 +478,10 @@ def cond(
477478 can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
478479 dyn_vars: optional, Variable, sequence of Variable, dict
479480 The dynamically changed variables.
481+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
482+ The children objects used in the target function.
483+
484+ .. versionadded:: 2.3.1
480485
481486 Returns
482487 -------
@@ -487,8 +492,11 @@ def cond(
487492 true_fun = _check_f (true_fun )
488493 false_fun = _check_f (false_fun )
489494 dyn_vars = check .is_all_vars (dyn_vars , out_as = 'dict' )
495+ dyn_vars = ArrayCollector (dyn_vars )
490496 dyn_vars .update (infer_dyn_vars (true_fun ))
491497 dyn_vars .update (infer_dyn_vars (false_fun ))
498+ for obj in check .is_all_objs (child_objs , out_as = 'tuple' ):
499+ dyn_vars .update (obj .vars ().unique ())
492500 dyn_vars = list (ArrayCollector (dyn_vars ).unique ().values ())
493501
494502 name = get_unique_name ('_brainpy_object_oriented_cond_' )
@@ -539,6 +547,7 @@ def ifelse(
539547 branches : Sequence [Callable ],
540548 operands : Any = None ,
541549 dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
550+ child_objs : Optional [Union [BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]]] = None ,
542551 show_code : bool = False ,
543552):
544553 """``If-else`` control flows looks like native Pythonic programming.
@@ -578,6 +587,10 @@ def ifelse(
578587 The dynamically changed variables.
579588 show_code: bool
580589 Whether show the formatted code.
590+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
591+ The children objects used in the target function.
592+
593+ .. versionadded:: 2.3.1
581594
582595 Returns
583596 -------
@@ -602,7 +615,9 @@ def ifelse(
602615 dyn_vars = ArrayCollector (dyn_vars )
603616 for f in branches :
604617 dyn_vars += infer_dyn_vars (f )
605- dyn_vars = tuple (dyn_vars .values ())
618+ for obj in check .is_all_objs (child_objs , out_as = 'tuple' ):
619+ dyn_vars .update (obj .vars ().unique ())
620+ dyn_vars = tuple (dyn_vars .unique ().values ())
606621
607622 # format new codes
608623 if len (conditions ) == 1 :
@@ -647,6 +662,7 @@ def for_loop(
647662 operands : Any ,
648663 dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
649664 out_vars : Optional [Union [Variable , Sequence [Variable ], Dict [str , Variable ]]] = None ,
665+ child_objs : Optional [Union [BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]]] = None ,
650666 reverse : bool = False ,
651667 unroll : int = 1 ,
652668):
@@ -727,14 +743,21 @@ def for_loop(
727743 Optional positive int specifying, in the underlying operation of the
728744 scan primitive, how many scan iterations to unroll within a single
729745 iteration of a loop.
746+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
747+ The children objects used in the target function.
748+
749+ .. versionadded:: 2.3.1
730750
731751 Returns
732752 -------
733753 outs: Any
734754 The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
735755 """
736756 dyn_vars = check .is_all_vars (dyn_vars , out_as = 'dict' )
757+ dyn_vars = ArrayCollector (dyn_vars )
737758 dyn_vars .update (infer_dyn_vars (body_fun ))
759+ for obj in check .is_all_objs (child_objs , out_as = 'tuple' ):
760+ dyn_vars .update (obj .vars ().unique ())
738761 dyn_vars = list (ArrayCollector (dyn_vars ).unique ().values ())
739762 outs , _ = tree_flatten (out_vars , lambda s : isinstance (s , Variable ))
740763 for v in outs :
@@ -785,6 +808,7 @@ def while_loop(
785808 cond_fun : Callable ,
786809 operands : Any ,
787810 dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
811+ child_objs : Optional [Union [BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]]] = None ,
788812):
789813 """``while-loop`` control flow with :py:class:`~.Variable`.
790814
@@ -831,13 +855,18 @@ def while_loop(
831855 The dynamically changed variables.
832856 operands: Any
833857 The operands for ``body_fun`` and ``cond_fun`` functions.
858+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
859+ The children objects used in the target function.
834860
861+ .. versionadded:: 2.3.1
835862 """
836863 # iterable variables
837864 dyn_vars = check .is_all_vars (dyn_vars , out_as = 'dict' )
838865 dyn_vars = ArrayCollector (dyn_vars )
839866 dyn_vars .update (infer_dyn_vars (body_fun ))
840867 dyn_vars .update (infer_dyn_vars (cond_fun ))
868+ for obj in check .is_all_objs (child_objs , out_as = 'tuple' ):
869+ dyn_vars .update (obj .vars ().unique ())
841870 dyn_vars = tuple (dyn_vars .values ())
842871 if not isinstance (operands , (list , tuple )):
843872 operands = (operands ,)
0 commit comments