Skip to content

Commit bd1b2c1

Browse files
committed
[interoperation] add apis and docs for brainpy.layers.FromFlax and brainpy.layer.ToFlaxRNNCell
1 parent 237607d commit bd1b2c1

File tree

14 files changed

+1812
-177
lines changed

14 files changed

+1812
-177
lines changed

brainpy/_src/initialize/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def variable(
176176
new_shape = size[:batch_axis] + (int(batch_size_or_mode),) + size[batch_axis:]
177177
return bm.Variable(init(new_shape), batch_axis=batch_axis)
178178
else:
179-
raise ValueError('Unknown batch_size_or_mode.')
179+
raise ValueError(f'Unknown batch_size_or_mode: {batch_size_or_mode}')
180180

181181
else:
182182
if size is not None:

brainpy/_src/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from .normalization import *
1010
from .pooling import *
1111
from .function import *
12+
from .interoperation_flax import *

brainpy/_src/layers/interoperation_flax.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11

2+
import jax
3+
import dataclasses
4+
from typing import Dict
25
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
36

47
from brainpy import math as bm
5-
from brainpy._src.dynsys import DynamicalSystemNS
8+
from brainpy._src.dynsys import DynamicalSystemNS, DynamicalSystem
9+
from brainpy._src.context import share
610

711
try:
812
import flax # noqa
13+
from flax.linen.recurrent import RNNCellBase
914
except:
1015
flax = None
16+
RNNCellBase = object
1117

1218

1319
__all__ = [
1420
'FromFlax',
21+
'ToFlaxRNNCell',
1522
'ToFlax',
1623
]
1724

@@ -28,6 +35,18 @@ def _is_bp(a):
2835

2936

3037
class FromFlax(DynamicalSystemNS):
38+
"""
39+
Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`.
40+
41+
Parameters
42+
----------
43+
flax_module: Any
44+
The flax Module.
45+
module_args: Any
46+
The module arguments, used to initialize model parameters.
47+
module_kwargs: Any
48+
The module arguments, used to initialize model parameters.
49+
"""
3150
def __init__(self, flax_module, *module_args, **module_kwargs):
3251
super().__init__()
3352
self.flax_module = flax_module
@@ -47,14 +66,79 @@ def reset_state(self, *args, **kwargs):
4766
pass
4867

4968

69+
to_flax_doc = """Transform a BrainPy :py:class:`~.DynamicalSystem` into a Flax recurrent module."""
70+
71+
5072
if flax is not None:
51-
class ToFlax(flax.linen.Module):
52-
pass
73+
class ToFlaxRNNCell(RNNCellBase):
74+
__doc__ = to_flax_doc
75+
76+
model: DynamicalSystem
77+
train_params: Dict[str, jax.Array] = dataclasses.field(init=False)
78+
79+
def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
80+
if len(batch_dims) == 0:
81+
batch_dims = 1
82+
elif len(batch_dims) == 1:
83+
batch_dims = batch_dims[0]
84+
else:
85+
raise NotImplementedError
86+
87+
_state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
88+
self.model.reset_state(batch_size=batch_dims)
89+
return [_state_vars.dict(), 0, 0.]
90+
91+
def setup(self):
92+
_vars = self.model.vars().unique()
93+
_train_vars = _vars.subset(bm.TrainVar)
94+
self.train_params = self.param(self.model.name, lambda rng, a: a.dict(), _train_vars)
95+
96+
def __call__(self, carry, *inputs):
97+
"""A recurrent cell that transformed from a BrainPy :py:class:`~.DynamicalSystem`.
98+
99+
Args:
100+
carry: the hidden state of the transformed recurrent cell, initialized using
101+
`.initialize_carry()` function in which the original `.reset_state()` is called.
102+
inputs: an ndarray with the input for the current time step. All
103+
dimensions except the final are considered batch dimensions.
104+
105+
Returns:
106+
A tuple with the new carry and the output.
107+
"""
108+
# shared arguments
109+
i, t = carry[1], carry[2]
110+
old_i = share.load('i', i)
111+
old_t = share.load('t', t)
112+
share.save(i=i, t=t)
113+
114+
# carry
115+
_vars = self.model.vars().unique()
116+
_state_vars = _vars.not_subset(bm.TrainVar)
117+
for k, v in carry[0].items():
118+
_state_vars[k].value = v
119+
120+
# train parameters
121+
_train_vars = _vars.subset(bm.TrainVar)
122+
for k, v in self.train_params.items():
123+
_train_vars[k].value = v
124+
125+
# recurrent cell
126+
out = self.model(*inputs)
127+
128+
# shared arguments
129+
share.save(i=old_i, t=old_t)
130+
# carray and output
131+
return [_state_vars.dict(), i + 1, t + share.dt], out
53132

54133

55134
else:
56-
class ToFlax(object):
135+
class ToFlaxRNNCell(object):
136+
__doc__ = to_flax_doc
137+
57138
def __init__(self, *args, **kwargs):
58139
raise ModuleNotFoundError('"flax" is not installed, or importing "flax" has errors. Please check.')
59140

60141

142+
ToFlax = ToFlaxRNNCell
143+
144+

brainpy/_src/math/object_transform/collectors.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,20 @@ def not_subset(self, var_type):
148148
gather[key] = value
149149
return gather
150150

151+
def include(self, *types):
152+
gather = type(self)()
153+
for key, value in self.items():
154+
if value.__class__ in types:
155+
gather[key] = value
156+
return gather
157+
158+
def exclude(self, *types):
159+
gather = type(self)()
160+
for key, value in self.items():
161+
if value.__class__ not in types:
162+
gather[key] = value
163+
return gather
164+
151165
def unique(self):
152166
"""Get a new type of collector with unique values.
153167

brainpy/layers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22

33

4-
54
from brainpy._src.layers.base import (
65
Layer as Layer,
76
)
@@ -85,3 +84,8 @@
8584
Conv3dLSTMCell as Conv3dLSTMCell,
8685
)
8786

87+
from brainpy._src.layers.interoperation_flax import (
88+
FromFlax,
89+
ToFlaxRNNCell, ToFlax,
90+
)
91+

docs/auto_generater.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def generate_inputs_docs():
381381

382382
def generate_layers_docs():
383383
_write_subsections_v2(
384-
'brainpy._src.dyn.layers',
384+
'brainpy._src.layers',
385385
'brainpy.layers',
386386
'apis/auto/layers.rst',
387387
subsections={
@@ -395,6 +395,7 @@ def generate_layers_docs():
395395
'pooling': 'Pooling Layers',
396396
'reservoir': 'Reservoir Layers',
397397
'rnncells': 'Artificial Recurrent Layers',
398+
'interoperation_flax': 'Interoperation with Flax',
398399
}
399400
)
400401

docs/index.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
5757
:caption: Advanced Tutorials
5858

5959
tutorial_advanced/adavanced_lowdim_analysis.ipynb
60-
tutorial_advanced/interoperation.ipynb
60+
tutorial_advanced/differentiation.ipynb
61+
tutorial_advanced/integrate_flax_into_brainpy.ipynb
62+
tutorial_advanced/integrate_bp_lif_into_flax.ipynb
63+
tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
6164

6265

6366

0 commit comments

Comments
 (0)