11# -*- coding: utf-8 -*-
22
33
4+ from typing import Union , Sequence , Any , Dict
5+
46from jax import lax
57from jax .tree_util import tree_flatten , tree_unflatten
8+
69try :
710 from jax .errors import UnexpectedTracerError
811except ImportError :
912 from jax .core import UnexpectedTracerError
1013
1114from brainpy import errors
12- from brainpy .math .jaxarray import JaxArray , turn_on_global_jit , turn_off_global_jit
15+ from brainpy .math .jaxarray import (JaxArray , Variable ,
16+ turn_on_global_jit ,
17+ turn_off_global_jit )
1318from brainpy .math .numpy_ops import as_device_array
1419
1520__all__ = [
1621 'make_loop' ,
1722 'make_while' ,
1823 'make_cond' ,
24+ 'ifelse' ,
1925]
2026
2127
@@ -85,44 +91,44 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
8591 >>> def f(x): a.value += 1.
8692 >>> loop = bm.make_loop(f, dyn_vars=[a], out_vars=a)
8793 >>> loop(length=10)
88- JaxArray(DeviceArray( [[ 1.],
89- [ 2.],
90- [ 3.],
91- [ 4.],
92- [ 5.],
93- [ 6.],
94- [ 7.],
95- [ 8.],
96- [ 9.],
97- [10.]], dtype=float32) )
94+ JaxArray([[ 1.],
95+ [ 2.],
96+ [ 3.],
97+ [ 4.],
98+ [ 5.],
99+ [ 6.],
100+ [ 7.],
101+ [ 8.],
102+ [ 9.],
103+ [10.]], dtype=float32)
98104 >>> b = bm.zeros(1)
99105 >>> def f(x):
100106 >>> b.value += 1
101107 >>> return b + 1
102108 >>> loop = bm.make_loop(f, dyn_vars=[b], out_vars=b, has_return=True)
103109 >>> hist_b, hist_b_plus = loop(length=10)
104110 >>> hist_b
105- JaxArray(DeviceArray( [[ 1.],
106- [ 2.],
107- [ 3.],
108- [ 4.],
109- [ 5.],
110- [ 6.],
111- [ 7.],
112- [ 8.],
113- [ 9.],
114- [10.]], dtype=float32) )
111+ JaxArray([[ 1.],
112+ [ 2.],
113+ [ 3.],
114+ [ 4.],
115+ [ 5.],
116+ [ 6.],
117+ [ 7.],
118+ [ 8.],
119+ [ 9.],
120+ [10.]], dtype=float32)
115121 >>> hist_b_plus
116- JaxArray(DeviceArray( [[ 2.],
117- [ 3.],
118- [ 4.],
119- [ 5.],
120- [ 6.],
121- [ 7.],
122- [ 8.],
123- [ 9.],
124- [10.],
125- [11.]], dtype=float32) )
122+ JaxArray([[ 2.],
123+ [ 3.],
124+ [ 4.],
125+ [ 5.],
126+ [ 6.],
127+ [ 7.],
128+ [ 8.],
129+ [ 9.],
130+ [10.],
131+ [11.]], dtype=float32)
126132
127133 Parameters
128134 ----------
@@ -201,7 +207,7 @@ def make_while(cond_fun, body_fun, dyn_vars):
201207 >>> loop = bm.make_while(cond_f, body_f, dyn_vars=[a])
202208 >>> loop()
203209 >>> a
204- JaxArray(DeviceArray( [10.], dtype=float32) )
210+ JaxArray([10.], dtype=float32)
205211
206212 Parameters
207213 ----------
@@ -223,12 +229,11 @@ def make_while(cond_fun, body_fun, dyn_vars):
223229 elif isinstance (dyn_vars , (tuple , list )):
224230 dyn_vars = tuple (dyn_vars )
225231 else :
226- raise ValueError (
227- f'"dyn_vars" does not support { type (dyn_vars )} , '
228- f'only support dict/list/tuple of { JaxArray .__name__ } ' )
232+ raise ValueError (f'"dyn_vars" does not support { type (dyn_vars )} , '
233+ f'only support dict/list/tuple of { JaxArray .__name__ } ' )
229234 for v in dyn_vars :
230235 if not isinstance (v , JaxArray ):
231- raise ValueError (f'brainpy.math.jax.loops only support { JaxArray .__name__ } , but got { type (v )} ' )
236+ raise ValueError (f'Only support { JaxArray .__name__ } , but got { type (v )} ' )
232237
233238 def _body_fun (op ):
234239 dyn_values , static_values = op
@@ -274,12 +279,12 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
274279 >>> cond = bm.make_cond(true_f, false_f, dyn_vars=[a, b])
275280 >>> cond(True)
276281 >>> a, b
277- (JaxArray(DeviceArray( [1., 1.], dtype=float32) ),
278- JaxArray(DeviceArray( [1., 1.], dtype=float32) ))
282+ (JaxArray([1., 1.], dtype=float32),
283+ JaxArray([1., 1.], dtype=float32))
279284 >>> cond(False)
280285 >>> a, b
281- (JaxArray(DeviceArray( [1., 1.], dtype=float32) ),
282- JaxArray(DeviceArray( [0., 0.], dtype=float32) ))
286+ (JaxArray([1., 1.], dtype=float32),
287+ JaxArray([0., 0.], dtype=float32))
283288
284289 Parameters
285290 ----------
@@ -300,20 +305,17 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
300305 if dyn_vars is None :
301306 dyn_vars = []
302307 if isinstance (dyn_vars , JaxArray ):
303- dyn_vars = (dyn_vars , )
308+ dyn_vars = (dyn_vars ,)
304309 elif isinstance (dyn_vars , dict ):
305310 dyn_vars = tuple (dyn_vars .values ())
306311 elif isinstance (dyn_vars , (tuple , list )):
307312 dyn_vars = tuple (dyn_vars )
308313 else :
309- raise ValueError (
310- f'"dyn_vars" does not support { type (dyn_vars )} , '
311- f'only support dict/list/tuple of { JaxArray .__name__ } ' )
314+ raise ValueError (f'"dyn_vars" does not support { type (dyn_vars )} , '
315+ f'only support dict/list/tuple of { JaxArray .__name__ } ' )
312316 for v in dyn_vars :
313317 if not isinstance (v , JaxArray ):
314- raise ValueError (
315- f'brainpy.math.jax.loops only support '
316- f'{ JaxArray .__name__ } , but got { type (v )} ' )
318+ raise ValueError (f'Only support { JaxArray .__name__ } , but got { type (v )} ' )
317319
318320 def _true_fun (op ):
319321 dyn_vals , static_vals = op
@@ -346,3 +348,166 @@ def call(pred, x=None):
346348 return res
347349
348350 return call
351+
352+
353+ def _cond_with_dyn_vars (pred , true_fun , false_fun , operands , dyn_vars ):
354+ # iterable variables
355+ if isinstance (dyn_vars , JaxArray ):
356+ dyn_vars = (dyn_vars ,)
357+ elif isinstance (dyn_vars , dict ):
358+ dyn_vars = tuple (dyn_vars .values ())
359+ elif isinstance (dyn_vars , (tuple , list )):
360+ dyn_vars = tuple (dyn_vars )
361+ else :
362+ raise ValueError (f'"dyn_vars" does not support { type (dyn_vars )} , '
363+ f'only support dict/list/tuple of { JaxArray .__name__ } ' )
364+ for v in dyn_vars :
365+ if not isinstance (v , JaxArray ):
366+ raise ValueError (f'Only support { JaxArray .__name__ } , but got { type (v )} ' )
367+
368+ def _true_fun (op ):
369+ dyn_vals , static_vals = op
370+ for v , d in zip (dyn_vars , dyn_vals ): v .value = d
371+ res = true_fun (static_vals )
372+ dyn_vals = [v .value for v in dyn_vars ]
373+ return dyn_vals , res
374+
375+ def _false_fun (op ):
376+ dyn_vals , static_vals = op
377+ for v , d in zip (dyn_vars , dyn_vals ): v .value = d
378+ res = false_fun (static_vals )
379+ dyn_vals = [v .value for v in dyn_vars ]
380+ return dyn_vals , res
381+
382+ # calling the model
383+ old_values = [v .value for v in dyn_vars ]
384+ try :
385+ turn_on_global_jit ()
386+ dyn_values , res = lax .cond (pred = pred ,
387+ true_fun = _true_fun ,
388+ false_fun = _false_fun ,
389+ operand = (old_values , operands ))
390+ turn_off_global_jit ()
391+ except UnexpectedTracerError as e :
392+ turn_off_global_jit ()
393+ for v , d in zip (dyn_vars , old_values ): v .value = d
394+ raise errors .JaxTracerError (variables = dyn_vars ) from e
395+ for v , d in zip (dyn_vars , dyn_values ): v .value = d
396+ return res
397+
398+
399+ def _check_f (f ):
400+ if callable (f ):
401+ return f
402+ else :
403+ return (lambda _ : f )
404+
405+
406+ def ifelse (
407+ conditions : Union [bool , Sequence [bool ]],
408+ branches : Sequence ,
409+ operands : Any = None ,
410+ dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
411+ show_code : bool = False ,
412+ ):
413+ """If-else control flows like native Pythonic programming.
414+
415+ Examples
416+ --------
417+
418+ >>> import brainpy.math as bm
419+ >>> def f(a):
420+ >>> return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
421+ >>> branches=[lambda _: 1,
422+ >>> lambda _: 2,
423+ >>> lambda _: 3,
424+ >>> lambda _: 4,
425+ >>> lambda _: 5])
426+ >>> f(1)
427+ 4
428+ >>> # or, it can be expressed as:
429+ >>> def f(a):
430+ >>> return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
431+ >>> branches=[1, 2, 3, 4, 5])
432+
433+
434+ Parameters
435+ ----------
436+ conditions: bool, sequence of bool
437+ The boolean conditions.
438+ branches: Sequence
439+ The branches, at least has two elements. Elements can be functions,
440+ arrays, or numbers. The number of ``branches`` and ``conditions`` has
441+ the relationship of `len(branches) == len(conditions) + 1`.
442+ operands: optional, Any
443+ The operands for each branch.
444+ dyn_vars: Variable, sequence of Variable, dict
445+ The dynamically changed variables.
446+ show_code: bool
447+ Whether show the formatted code.
448+
449+ Returns
450+ -------
451+ res: Any
452+ The results of the control flow.
453+ """
454+ # checking
455+ if not isinstance (conditions , (tuple , list )):
456+ conditions = [conditions ]
457+ if not isinstance (conditions , (tuple , list )):
458+ raise ValueError (f'"conditions" must be a tuple/list of boolean values. '
459+ f'But we got { type (conditions )} : { conditions } ' )
460+ if not isinstance (branches , (tuple , list )):
461+ raise ValueError (f'"branches" must be a tuple/list. '
462+ f'But we got { type (branches )} .' )
463+ branches = [_check_f (b ) for b in branches ]
464+ if len (branches ) != len (conditions ) + 1 :
465+ raise ValueError (f'The numbers of branches and conditions do not match. '
466+ f'Got len(conditions)={ len (conditions )} and len(branches)={ len (branches )} . '
467+ f'We expect len(conditions) + 1 == len(branches). ' )
468+ if dyn_vars is None :
469+ dyn_vars = []
470+ if isinstance (dyn_vars , Variable ):
471+ dyn_vars = (dyn_vars ,)
472+ elif isinstance (dyn_vars , dict ):
473+ dyn_vars = tuple (dyn_vars .values ())
474+ elif isinstance (dyn_vars , (tuple , list )):
475+ dyn_vars = tuple (dyn_vars )
476+ else :
477+ raise ValueError (f'"dyn_vars" does not support { type (dyn_vars )} , only '
478+ f'support dict/list/tuple of brainpy.math.Variable' )
479+ for v in dyn_vars :
480+ if not isinstance (v , Variable ):
481+ raise ValueError (f'Only support brainpy.math.Variable, but we got { type (v )} ' )
482+
483+ # format new codes
484+ code_scope = {'conditions' : conditions , 'branches' : branches }
485+ codes = ['def f(operands):' , f' f0 = branches[{ len (conditions )} ]' ]
486+ num_cond = len (conditions ) - 1
487+ if len (dyn_vars ) > 0 :
488+ code_scope ['_cond' ] = _cond_with_dyn_vars
489+ code_scope ['dyn_vars' ] = dyn_vars
490+ for i in range (len (conditions ) - 1 ):
491+ codes .append (f' f{ i + 1 } = lambda r: '
492+ f'_cond(conditions[{ num_cond - i } ], '
493+ f'branches[{ num_cond - i } ], f{ i } , r, dyn_vars)' )
494+ codes .append (f' return _cond(conditions[0], '
495+ f'branches[0], '
496+ f'f{ len (conditions ) - 1 } , '
497+ f'operands, dyn_vars)' )
498+ else :
499+ code_scope ['_cond' ] = lax .cond
500+ for i in range (len (conditions ) - 1 ):
501+ codes .append (f' f{ i + 1 } = lambda r: '
502+ f'_cond(conditions[{ num_cond - i } ], '
503+ f'branches[{ num_cond - i } ], f{ i } , r)' )
504+ codes .append (f' return _cond(conditions[0], '
505+ f'branches[0], '
506+ f'f{ len (conditions ) - 1 } , '
507+ f'operands)' )
508+ codes = '\n ' .join (codes )
509+ if show_code :
510+ print (codes )
511+ exec (compile (codes .strip (), '' , 'exec' ), code_scope )
512+ f = code_scope ['f' ]
513+ return f (operands )
0 commit comments