Skip to content

Commit eba2182

Browse files
committed
structural control flow transformations supports specify child_objs
1 parent 0658244 commit eba2182

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

brainpy/math/object_transform/controls.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from jax.core import UnexpectedTracerError
1414

1515
from brainpy import errors, tools, check
16-
from .base_object import get_unique_name
16+
from .base_object import get_unique_name, BrainPyObject
1717
from .collector import ArrayCollector
1818
from ..ndarray import (Array, Variable,
1919
add_context,
@@ -445,6 +445,7 @@ def cond(
445445
false_fun: Union[Callable, jnp.ndarray, Array, float, int, bool],
446446
operands: Any,
447447
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
448+
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
448449
):
449450
"""Simple conditional statement (if-else) with instance of :py:class:`~.Variable`.
450451
@@ -477,6 +478,10 @@ def cond(
477478
can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
478479
dyn_vars: optional, Variable, sequence of Variable, dict
479480
The dynamically changed variables.
481+
child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
482+
The children objects used in the target function.
483+
484+
.. versionadded:: 2.3.1
480485
481486
Returns
482487
-------
@@ -487,8 +492,11 @@ def cond(
487492
true_fun = _check_f(true_fun)
488493
false_fun = _check_f(false_fun)
489494
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
495+
dyn_vars = ArrayCollector(dyn_vars)
490496
dyn_vars.update(infer_dyn_vars(true_fun))
491497
dyn_vars.update(infer_dyn_vars(false_fun))
498+
for obj in check.is_all_objs(child_objs, out_as='tuple'):
499+
dyn_vars.update(obj.vars().unique())
492500
dyn_vars = list(ArrayCollector(dyn_vars).unique().values())
493501

494502
name = get_unique_name('_brainpy_object_oriented_cond_')
@@ -539,6 +547,7 @@ def ifelse(
539547
branches: Sequence[Callable],
540548
operands: Any = None,
541549
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
550+
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
542551
show_code: bool = False,
543552
):
544553
"""``If-else`` control flows looks like native Pythonic programming.
@@ -578,6 +587,10 @@ def ifelse(
578587
The dynamically changed variables.
579588
show_code: bool
580589
Whether show the formatted code.
590+
child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
591+
The children objects used in the target function.
592+
593+
.. versionadded:: 2.3.1
581594
582595
Returns
583596
-------
@@ -602,7 +615,9 @@ def ifelse(
602615
dyn_vars = ArrayCollector(dyn_vars)
603616
for f in branches:
604617
dyn_vars += infer_dyn_vars(f)
605-
dyn_vars = tuple(dyn_vars.values())
618+
for obj in check.is_all_objs(child_objs, out_as='tuple'):
619+
dyn_vars.update(obj.vars().unique())
620+
dyn_vars = tuple(dyn_vars.unique().values())
606621

607622
# format new codes
608623
if len(conditions) == 1:
@@ -647,6 +662,7 @@ def for_loop(
647662
operands: Any,
648663
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
649664
out_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
665+
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
650666
reverse: bool = False,
651667
unroll: int = 1,
652668
):
@@ -727,14 +743,21 @@ def for_loop(
727743
Optional positive int specifying, in the underlying operation of the
728744
scan primitive, how many scan iterations to unroll within a single
729745
iteration of a loop.
746+
child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
747+
The children objects used in the target function.
748+
749+
.. versionadded:: 2.3.1
730750
731751
Returns
732752
-------
733753
outs: Any
734754
The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
735755
"""
736756
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
757+
dyn_vars = ArrayCollector(dyn_vars)
737758
dyn_vars.update(infer_dyn_vars(body_fun))
759+
for obj in check.is_all_objs(child_objs, out_as='tuple'):
760+
dyn_vars.update(obj.vars().unique())
738761
dyn_vars = list(ArrayCollector(dyn_vars).unique().values())
739762
outs, _ = tree_flatten(out_vars, lambda s: isinstance(s, Variable))
740763
for v in outs:
@@ -785,6 +808,7 @@ def while_loop(
785808
cond_fun: Callable,
786809
operands: Any,
787810
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
811+
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
788812
):
789813
"""``while-loop`` control flow with :py:class:`~.Variable`.
790814
@@ -831,13 +855,18 @@ def while_loop(
831855
The dynamically changed variables.
832856
operands: Any
833857
The operands for ``body_fun`` and ``cond_fun`` functions.
858+
child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
859+
The children objects used in the target function.
834860
861+
.. versionadded:: 2.3.1
835862
"""
836863
# iterable variables
837864
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
838865
dyn_vars = ArrayCollector(dyn_vars)
839866
dyn_vars.update(infer_dyn_vars(body_fun))
840867
dyn_vars.update(infer_dyn_vars(cond_fun))
868+
for obj in check.is_all_objs(child_objs, out_as='tuple'):
869+
dyn_vars.update(obj.vars().unique())
841870
dyn_vars = tuple(dyn_vars.values())
842871
if not isinstance(operands, (list, tuple)):
843872
operands = (operands,)

0 commit comments

Comments
 (0)