Skip to content

Commit fee5d2d

Browse files
authored
feat: easier control flows with brainpy.math.ifelse (#189)
feat: easier control flows with `brainpy.math.ifelse`
2 parents af29fa5 + 1a1bb29 commit fee5d2d

File tree

3 files changed

+290
-48
lines changed

3 files changed

+290
-48
lines changed

brainpy/math/controls.py

Lines changed: 212 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
# -*- coding: utf-8 -*-
22

33

4+
from typing import Union, Sequence, Any, Dict
5+
46
from jax import lax
57
from jax.tree_util import tree_flatten, tree_unflatten
8+
69
try:
710
from jax.errors import UnexpectedTracerError
811
except ImportError:
912
from jax.core import UnexpectedTracerError
1013

1114
from brainpy import errors
12-
from brainpy.math.jaxarray import JaxArray, turn_on_global_jit, turn_off_global_jit
15+
from brainpy.math.jaxarray import (JaxArray, Variable,
16+
turn_on_global_jit,
17+
turn_off_global_jit)
1318
from brainpy.math.numpy_ops import as_device_array
1419

1520
__all__ = [
1621
'make_loop',
1722
'make_while',
1823
'make_cond',
24+
'ifelse',
1925
]
2026

2127

@@ -85,44 +91,44 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
8591
>>> def f(x): a.value += 1.
8692
>>> loop = bm.make_loop(f, dyn_vars=[a], out_vars=a)
8793
>>> loop(length=10)
88-
JaxArray(DeviceArray([[ 1.],
89-
[ 2.],
90-
[ 3.],
91-
[ 4.],
92-
[ 5.],
93-
[ 6.],
94-
[ 7.],
95-
[ 8.],
96-
[ 9.],
97-
[10.]], dtype=float32))
94+
JaxArray([[ 1.],
95+
[ 2.],
96+
[ 3.],
97+
[ 4.],
98+
[ 5.],
99+
[ 6.],
100+
[ 7.],
101+
[ 8.],
102+
[ 9.],
103+
[10.]], dtype=float32)
98104
>>> b = bm.zeros(1)
99105
>>> def f(x):
100106
>>> b.value += 1
101107
>>> return b + 1
102108
>>> loop = bm.make_loop(f, dyn_vars=[b], out_vars=b, has_return=True)
103109
>>> hist_b, hist_b_plus = loop(length=10)
104110
>>> hist_b
105-
JaxArray(DeviceArray([[ 1.],
106-
[ 2.],
107-
[ 3.],
108-
[ 4.],
109-
[ 5.],
110-
[ 6.],
111-
[ 7.],
112-
[ 8.],
113-
[ 9.],
114-
[10.]], dtype=float32))
111+
JaxArray([[ 1.],
112+
[ 2.],
113+
[ 3.],
114+
[ 4.],
115+
[ 5.],
116+
[ 6.],
117+
[ 7.],
118+
[ 8.],
119+
[ 9.],
120+
[10.]], dtype=float32)
115121
>>> hist_b_plus
116-
JaxArray(DeviceArray([[ 2.],
117-
[ 3.],
118-
[ 4.],
119-
[ 5.],
120-
[ 6.],
121-
[ 7.],
122-
[ 8.],
123-
[ 9.],
124-
[10.],
125-
[11.]], dtype=float32))
122+
JaxArray([[ 2.],
123+
[ 3.],
124+
[ 4.],
125+
[ 5.],
126+
[ 6.],
127+
[ 7.],
128+
[ 8.],
129+
[ 9.],
130+
[10.],
131+
[11.]], dtype=float32)
126132
127133
Parameters
128134
----------
@@ -201,7 +207,7 @@ def make_while(cond_fun, body_fun, dyn_vars):
201207
>>> loop = bm.make_while(cond_f, body_f, dyn_vars=[a])
202208
>>> loop()
203209
>>> a
204-
JaxArray(DeviceArray([10.], dtype=float32))
210+
JaxArray([10.], dtype=float32)
205211
206212
Parameters
207213
----------
@@ -223,12 +229,11 @@ def make_while(cond_fun, body_fun, dyn_vars):
223229
elif isinstance(dyn_vars, (tuple, list)):
224230
dyn_vars = tuple(dyn_vars)
225231
else:
226-
raise ValueError(
227-
f'"dyn_vars" does not support {type(dyn_vars)}, '
228-
f'only support dict/list/tuple of {JaxArray.__name__}')
232+
raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, '
233+
f'only support dict/list/tuple of {JaxArray.__name__}')
229234
for v in dyn_vars:
230235
if not isinstance(v, JaxArray):
231-
raise ValueError(f'brainpy.math.jax.loops only support {JaxArray.__name__}, but got {type(v)}')
236+
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')
232237

233238
def _body_fun(op):
234239
dyn_values, static_values = op
@@ -274,12 +279,12 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
274279
>>> cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
275280
>>> cond(True)
276281
>>> a, b
277-
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
278-
JaxArray(DeviceArray([1., 1.], dtype=float32)))
282+
(JaxArray([1., 1.], dtype=float32),
283+
JaxArray([1., 1.], dtype=float32))
279284
>>> cond(False)
280285
>>> a, b
281-
(JaxArray(DeviceArray([1., 1.], dtype=float32)),
282-
JaxArray(DeviceArray([0., 0.], dtype=float32)))
286+
(JaxArray([1., 1.], dtype=float32),
287+
JaxArray([0., 0.], dtype=float32))
283288
284289
Parameters
285290
----------
@@ -300,20 +305,17 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
300305
if dyn_vars is None:
301306
dyn_vars = []
302307
if isinstance(dyn_vars, JaxArray):
303-
dyn_vars = (dyn_vars, )
308+
dyn_vars = (dyn_vars,)
304309
elif isinstance(dyn_vars, dict):
305310
dyn_vars = tuple(dyn_vars.values())
306311
elif isinstance(dyn_vars, (tuple, list)):
307312
dyn_vars = tuple(dyn_vars)
308313
else:
309-
raise ValueError(
310-
f'"dyn_vars" does not support {type(dyn_vars)}, '
311-
f'only support dict/list/tuple of {JaxArray.__name__}')
314+
raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, '
315+
f'only support dict/list/tuple of {JaxArray.__name__}')
312316
for v in dyn_vars:
313317
if not isinstance(v, JaxArray):
314-
raise ValueError(
315-
f'brainpy.math.jax.loops only support '
316-
f'{JaxArray.__name__}, but got {type(v)}')
318+
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')
317319

318320
def _true_fun(op):
319321
dyn_vals, static_vals = op
@@ -346,3 +348,166 @@ def call(pred, x=None):
346348
return res
347349

348350
return call
351+
352+
353+
def _cond_with_dyn_vars(pred, true_fun, false_fun, operands, dyn_vars):
354+
# iterable variables
355+
if isinstance(dyn_vars, JaxArray):
356+
dyn_vars = (dyn_vars,)
357+
elif isinstance(dyn_vars, dict):
358+
dyn_vars = tuple(dyn_vars.values())
359+
elif isinstance(dyn_vars, (tuple, list)):
360+
dyn_vars = tuple(dyn_vars)
361+
else:
362+
raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, '
363+
f'only support dict/list/tuple of {JaxArray.__name__}')
364+
for v in dyn_vars:
365+
if not isinstance(v, JaxArray):
366+
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')
367+
368+
def _true_fun(op):
369+
dyn_vals, static_vals = op
370+
for v, d in zip(dyn_vars, dyn_vals): v.value = d
371+
res = true_fun(static_vals)
372+
dyn_vals = [v.value for v in dyn_vars]
373+
return dyn_vals, res
374+
375+
def _false_fun(op):
376+
dyn_vals, static_vals = op
377+
for v, d in zip(dyn_vars, dyn_vals): v.value = d
378+
res = false_fun(static_vals)
379+
dyn_vals = [v.value for v in dyn_vars]
380+
return dyn_vals, res
381+
382+
# calling the model
383+
old_values = [v.value for v in dyn_vars]
384+
try:
385+
turn_on_global_jit()
386+
dyn_values, res = lax.cond(pred=pred,
387+
true_fun=_true_fun,
388+
false_fun=_false_fun,
389+
operand=(old_values, operands))
390+
turn_off_global_jit()
391+
except UnexpectedTracerError as e:
392+
turn_off_global_jit()
393+
for v, d in zip(dyn_vars, old_values): v.value = d
394+
raise errors.JaxTracerError(variables=dyn_vars) from e
395+
for v, d in zip(dyn_vars, dyn_values): v.value = d
396+
return res
397+
398+
399+
def _check_f(f):
400+
if callable(f):
401+
return f
402+
else:
403+
return (lambda _: f)
404+
405+
406+
def ifelse(
407+
conditions: Union[bool, Sequence[bool]],
408+
branches: Sequence,
409+
operands: Any = None,
410+
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
411+
show_code: bool = False,
412+
):
413+
"""If-else control flows like native Pythonic programming.
414+
415+
Examples
416+
--------
417+
418+
>>> import brainpy.math as bm
419+
>>> def f(a):
420+
>>> return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
421+
>>> branches=[lambda _: 1,
422+
>>> lambda _: 2,
423+
>>> lambda _: 3,
424+
>>> lambda _: 4,
425+
>>> lambda _: 5])
426+
>>> f(1)
427+
4
428+
>>> # or, it can be expressed as:
429+
>>> def f(a):
430+
>>> return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
431+
>>> branches=[1, 2, 3, 4, 5])
432+
433+
434+
Parameters
435+
----------
436+
conditions: bool, sequence of bool
437+
The boolean conditions.
438+
branches: Sequence
439+
The branches, at least has two elements. Elements can be functions,
440+
arrays, or numbers. The number of ``branches`` and ``conditions`` has
441+
the relationship of `len(branches) == len(conditions) + 1`.
442+
operands: optional, Any
443+
The operands for each branch.
444+
dyn_vars: Variable, sequence of Variable, dict
445+
The dynamically changed variables.
446+
show_code: bool
447+
Whether show the formatted code.
448+
449+
Returns
450+
-------
451+
res: Any
452+
The results of the control flow.
453+
"""
454+
# checking
455+
if not isinstance(conditions, (tuple, list)):
456+
conditions = [conditions]
457+
if not isinstance(conditions, (tuple, list)):
458+
raise ValueError(f'"conditions" must be a tuple/list of boolean values. '
459+
f'But we got {type(conditions)}: {conditions}')
460+
if not isinstance(branches, (tuple, list)):
461+
raise ValueError(f'"branches" must be a tuple/list. '
462+
f'But we got {type(branches)}.')
463+
branches = [_check_f(b) for b in branches]
464+
if len(branches) != len(conditions) + 1:
465+
raise ValueError(f'The numbers of branches and conditions do not match. '
466+
f'Got len(conditions)={len(conditions)} and len(branches)={len(branches)}. '
467+
f'We expect len(conditions) + 1 == len(branches). ')
468+
if dyn_vars is None:
469+
dyn_vars = []
470+
if isinstance(dyn_vars, Variable):
471+
dyn_vars = (dyn_vars,)
472+
elif isinstance(dyn_vars, dict):
473+
dyn_vars = tuple(dyn_vars.values())
474+
elif isinstance(dyn_vars, (tuple, list)):
475+
dyn_vars = tuple(dyn_vars)
476+
else:
477+
raise ValueError(f'"dyn_vars" does not support {type(dyn_vars)}, only '
478+
f'support dict/list/tuple of brainpy.math.Variable')
479+
for v in dyn_vars:
480+
if not isinstance(v, Variable):
481+
raise ValueError(f'Only support brainpy.math.Variable, but we got {type(v)}')
482+
483+
# format new codes
484+
code_scope = {'conditions': conditions, 'branches': branches}
485+
codes = ['def f(operands):', f' f0 = branches[{len(conditions)}]']
486+
num_cond = len(conditions) - 1
487+
if len(dyn_vars) > 0:
488+
code_scope['_cond'] = _cond_with_dyn_vars
489+
code_scope['dyn_vars'] = dyn_vars
490+
for i in range(len(conditions) - 1):
491+
codes.append(f' f{i+1} = lambda r: '
492+
f'_cond(conditions[{num_cond - i}], '
493+
f'branches[{num_cond - i}], f{i}, r, dyn_vars)')
494+
codes.append(f' return _cond(conditions[0], '
495+
f'branches[0], '
496+
f'f{len(conditions) - 1}, '
497+
f'operands, dyn_vars)')
498+
else:
499+
code_scope['_cond'] = lax.cond
500+
for i in range(len(conditions) - 1):
501+
codes.append(f' f{i+1} = lambda r: '
502+
f'_cond(conditions[{num_cond - i}], '
503+
f'branches[{num_cond - i}], f{i}, r)')
504+
codes.append(f' return _cond(conditions[0], '
505+
f'branches[0], '
506+
f'f{len(conditions) - 1}, '
507+
f'operands)')
508+
codes = '\n'.join(codes)
509+
if show_code:
510+
print(codes)
511+
exec(compile(codes.strip(), '', 'exec'), code_scope)
512+
f = code_scope['f']
513+
return f(operands)

brainpy/math/jaxarray.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,20 @@
3333

3434

3535
def turn_on_global_jit():
36+
"""Turn on the global JIT mode to declare
37+
all instantiated JaxArray cannot be updated."""
3638
global _global_jit_mode
3739
_global_jit_mode = True
3840

3941

4042
def turn_off_global_jit():
43+
"""Turn off the global JIT mode."""
4144
global _global_jit_mode
4245
_global_jit_mode = False
4346

4447

4548
class JaxArray(object):
46-
"""Multiple-dimensional array for JAX backend.
49+
"""Multiple-dimensional array in JAX backend.
4750
"""
4851
__slots__ = ("_value", "_outside_global_jit")
4952

0 commit comments

Comments
 (0)