11
2+ import jax
3+ import dataclasses
4+ from typing import Dict
25from jax .tree_util import tree_flatten , tree_map , tree_unflatten
36
47from brainpy import math as bm
5- from brainpy ._src .dynsys import DynamicalSystemNS
8+ from brainpy ._src .dynsys import DynamicalSystemNS , DynamicalSystem
9+ from brainpy ._src .context import share
610
711try :
812 import flax # noqa
13+ from flax .linen .recurrent import RNNCellBase
914except :
1015 flax = None
16+ RNNCellBase = object
1117
1218
1319__all__ = [
1420 'FromFlax' ,
21+ 'ToFlaxRNNCell' ,
1522 'ToFlax' ,
1623]
1724
@@ -28,6 +35,18 @@ def _is_bp(a):
2835
2936
3037class FromFlax (DynamicalSystemNS ):
38+ """
39+ Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`.
40+
41+ Parameters
42+ ----------
43+ flax_module: Any
44+ The flax Module.
45+ module_args: Any
46+ The module arguments, used to initialize model parameters.
47+ module_kwargs: Any
48+ The module arguments, used to initialize model parameters.
49+ """
3150 def __init__ (self , flax_module , * module_args , ** module_kwargs ):
3251 super ().__init__ ()
3352 self .flax_module = flax_module
@@ -47,14 +66,79 @@ def reset_state(self, *args, **kwargs):
4766 pass
4867
4968
69+ to_flax_doc = """Transform a BrainPy :py:class:`~.DynamicalSystem` into a Flax recurrent module."""
70+
71+
5072if flax is not None :
51- class ToFlax (flax .linen .Module ):
52- pass
73+ class ToFlaxRNNCell (RNNCellBase ):
74+ __doc__ = to_flax_doc
75+
76+ model : DynamicalSystem
77+ train_params : Dict [str , jax .Array ] = dataclasses .field (init = False )
78+
79+ def initialize_carry (self , rng , batch_dims , size = None , init_fn = None ):
80+ if len (batch_dims ) == 0 :
81+ batch_dims = 1
82+ elif len (batch_dims ) == 1 :
83+ batch_dims = batch_dims [0 ]
84+ else :
85+ raise NotImplementedError
86+
87+ _state_vars = self .model .vars ().unique ().not_subset (bm .TrainVar )
88+ self .model .reset_state (batch_size = batch_dims )
89+ return [_state_vars .dict (), 0 , 0. ]
90+
91+ def setup (self ):
92+ _vars = self .model .vars ().unique ()
93+ _train_vars = _vars .subset (bm .TrainVar )
94+ self .train_params = self .param (self .model .name , lambda rng , a : a .dict (), _train_vars )
95+
96+ def __call__ (self , carry , * inputs ):
97+ """A recurrent cell that transformed from a BrainPy :py:class:`~.DynamicalSystem`.
98+
99+ Args:
100+ carry: the hidden state of the transformed recurrent cell, initialized using
101+ `.initialize_carry()` function in which the original `.reset_state()` is called.
102+ inputs: an ndarray with the input for the current time step. All
103+ dimensions except the final are considered batch dimensions.
104+
105+ Returns:
106+ A tuple with the new carry and the output.
107+ """
108+ # shared arguments
109+ i , t = carry [1 ], carry [2 ]
110+ old_i = share .load ('i' , i )
111+ old_t = share .load ('t' , t )
112+ share .save (i = i , t = t )
113+
114+ # carry
115+ _vars = self .model .vars ().unique ()
116+ _state_vars = _vars .not_subset (bm .TrainVar )
117+ for k , v in carry [0 ].items ():
118+ _state_vars [k ].value = v
119+
120+ # train parameters
121+ _train_vars = _vars .subset (bm .TrainVar )
122+ for k , v in self .train_params .items ():
123+ _train_vars [k ].value = v
124+
125+ # recurrent cell
126+ out = self .model (* inputs )
127+
128+ # shared arguments
129+ share .save (i = old_i , t = old_t )
130+ # carray and output
131+ return [_state_vars .dict (), i + 1 , t + share .dt ], out
53132
54133
55134else :
56- class ToFlax (object ):
135+ class ToFlaxRNNCell (object ):
136+ __doc__ = to_flax_doc
137+
57138 def __init__ (self , * args , ** kwargs ):
58139 raise ModuleNotFoundError ('"flax" is not installed, or importing "flax" has errors. Please check.' )
59140
60141
142+ ToFlax = ToFlaxRNNCell
143+
144+
0 commit comments