Skip to content

Commit 5b68b71

Browse files
authored
Merge branch 'brainpy:master' into master
2 parents d8634ec + 33cff62 commit 5b68b71

File tree

87 files changed

+3296
-1702
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+3296
-1702
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
inputs, # methods for generating input currents
3434
algorithms, # online or offline training algorithms
3535
encoding, # encoding schema
36+
checkpoints, # checkpoints
37+
check, # error checking
3638
)
3739

3840
# numerical integrators

brainpy/base/base.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
import logging
44
import os.path
5+
import warnings
6+
from collections import namedtuple
7+
from typing import Dict, Any, Tuple
58

69
from brainpy import errors
710
from brainpy.base import io, naming
811
from brainpy.base.collector import Collector, ArrayCollector
912

1013
math = None
14+
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
1115

1216
__all__ = [
1317
'BrainPyObject',
@@ -57,25 +61,27 @@ def name(self, name: str = None):
5761
naming.check_name_uniqueness(name=self._name, obj=self)
5862

5963
def register_implicit_vars(self, *variables, **named_variables):
60-
from brainpy.math import Variable
64+
global math
65+
if math is None: from brainpy import math
66+
6167
for variable in variables:
62-
if isinstance(variable, Variable):
68+
if isinstance(variable, math.Variable):
6369
self.implicit_vars[f'var{id(variable)}'] = variable
6470
elif isinstance(variable, (tuple, list)):
6571
for v in variable:
66-
if not isinstance(v, Variable):
67-
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
72+
if not isinstance(v, math.Variable):
73+
raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(v)}')
6874
self.implicit_vars[f'var{id(v)}'] = v
6975
elif isinstance(variable, dict):
7076
for k, v in variable.items():
71-
if not isinstance(v, Variable):
72-
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
77+
if not isinstance(v, math.Variable):
78+
raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(v)}')
7379
self.implicit_vars[k] = v
7480
else:
7581
raise ValueError(f'Unknown type: {type(variable)}')
7682
for key, variable in named_variables.items():
77-
if not isinstance(variable, Variable):
78-
raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(variable)}')
83+
if not isinstance(variable, math.Variable):
84+
raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(variable)}')
7985
self.implicit_vars[key] = variable
8086

8187
def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
@@ -101,7 +107,11 @@ def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
101107
raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(node)}')
102108
self.implicit_nodes[key] = node
103109

104-
def vars(self, method='absolute', level=-1, include_self=True):
110+
def vars(self,
111+
method: str = 'absolute',
112+
level: int = -1,
113+
include_self: bool = True,
114+
exclude_types: Tuple[type, ...] = None):
105115
"""Collect all variables in this node and the children nodes.
106116
107117
Parameters
@@ -112,6 +122,8 @@ def vars(self, method='absolute', level=-1, include_self=True):
112122
The hierarchy level to find variables.
113123
include_self: bool
114124
Whether include the variables in the self.
125+
exclude_types: tuple of type
126+
The type to exclude.
115127
116128
Returns
117129
-------
@@ -121,12 +133,19 @@ def vars(self, method='absolute', level=-1, include_self=True):
121133
global math
122134
if math is None: from brainpy import math
123135

136+
if exclude_types is None:
137+
exclude_types = (math.VariableView, )
124138
nodes = self.nodes(method=method, level=level, include_self=include_self)
125139
gather = ArrayCollector()
126140
for node_path, node in nodes.items():
127141
for k in dir(node):
128142
v = getattr(node, k)
143+
include = False
129144
if isinstance(v, math.Variable):
145+
include = True
146+
if len(exclude_types) and isinstance(v, exclude_types):
147+
include = False
148+
if include:
130149
if k not in node._excluded_vars:
131150
gather[f'{node_path}.{k}' if node_path else k] = v
132151
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
@@ -306,6 +325,49 @@ def save_states(self, filename, variables=None, **setting):
306325
else:
307326
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
308327

328+
def state_dict(self):
329+
"""Returns a dictionary containing a whole state of the module.
330+
331+
Returns
332+
-------
333+
out: dict
334+
A dictionary containing a whole state of the module.
335+
"""
336+
return self.vars().unique().dict()
337+
338+
def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True):
339+
"""Copy parameters and buffers from :attr:`state_dict` into
340+
this module and its descendants.
341+
342+
Parameters
343+
----------
344+
state_dict: dict
345+
A dict containing parameters and persistent buffers.
346+
warn: bool
347+
Warnings when there are missing keys or unexpected keys in the external ``state_dict``.
348+
349+
Returns
350+
-------
351+
out: StateLoadResult
352+
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
353+
354+
* **missing_keys** is a list of str containing the missing keys
355+
* **unexpected_keys** is a list of str containing the unexpected keys
356+
"""
357+
variables = self.vars().unique()
358+
keys1 = set(state_dict.keys())
359+
keys2 = set(variables.keys())
360+
unexpected_keys = list(keys1 - keys2)
361+
missing_keys = list(keys2 - keys1)
362+
for key in keys2.intersection(keys1):
363+
variables[key].value = state_dict[key]
364+
if warn:
365+
if len(unexpected_keys):
366+
warnings.warn(f'Unexpected keys in state_dict: {unexpected_keys}', UserWarning)
367+
if len(missing_keys):
368+
warnings.warn(f'Missing keys in state_dict: {missing_keys}', UserWarning)
369+
return StateLoadResult(missing_keys, unexpected_keys)
370+
309371
# def to(self, devices):
310372
# global math
311373
# if math is None: from brainpy import math
@@ -317,7 +379,6 @@ def save_states(self, filename, variables=None, **setting):
317379
# all_vars = self.vars().unique()
318380
# for data in all_vars.values():
319381
# data[:] = math.asarray(data.value)
320-
# # TODO
321382
#
322383
# def cuda(self):
323384
# global math

brainpy/base/collector.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from typing import Dict, Sequence, Union
55

6+
from jax.tree_util import register_pytree_node
7+
from jax.util import safe_zip
8+
69
math = None
710

811
__all__ = [
@@ -37,11 +40,12 @@ def update(self, other, **kwargs):
3740
elif isinstance(other, (tuple, list)):
3841
num = len(self)
3942
for i, value in enumerate(other):
40-
self[f'_var{i+num}'] = value
43+
self[f'_var{i + num}'] = value
4144
else:
4245
raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}')
4346
for key, value in kwargs.items():
4447
self[key] = value
48+
return self
4549

4650
def __add__(self, other):
4751
"""Merging two dicts.
@@ -73,8 +77,8 @@ def __sub__(self, other: Union[Dict, Sequence]):
7377
gather: Collector
7478
The new collector.
7579
"""
76-
if not isinstance(other, dict):
77-
raise ValueError(f'Only support dict, but we got {type(other)}.')
80+
if not isinstance(other, (dict, tuple, list)):
81+
raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.')
7882
gather = type(self)(self)
7983
if isinstance(other, dict):
8084
for key, val in other.items():
@@ -87,7 +91,21 @@ def __sub__(self, other: Union[Dict, Sequence]):
8791
raise ValueError(f'Cannot remove {key}, because we do not find it '
8892
f'in {self.keys()}.')
8993
elif isinstance(other, (list, tuple)):
94+
id_to_keys = {}
95+
for k, v in self.items():
96+
id_ = id(v)
97+
if id_ not in id_to_keys:
98+
id_to_keys[id_] = []
99+
id_to_keys[id_].append(k)
100+
101+
keys_to_remove = []
90102
for key in other:
103+
if isinstance(key, str):
104+
keys_to_remove.append(key)
105+
else:
106+
keys_to_remove.extend(id_to_keys[id(key)])
107+
108+
for key in set(keys_to_remove):
91109
if key in gather:
92110
gather.pop(key)
93111
else:
@@ -199,3 +217,9 @@ def from_other(cls, other: Union[Sequence, Dict]):
199217

200218

201219
TensorCollector = ArrayCollector
220+
221+
register_pytree_node(
222+
ArrayCollector,
223+
lambda x: (x.values(), x.keys()),
224+
lambda keys, values: ArrayCollector(safe_zip(keys, values))
225+
)

brainpy/base/io.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _check_target(target):
7272

7373
not_found_msg = ('"{key}" is stored in {filename}. But we does '
7474
'not find it is defined as variable in {target}.')
75-
id_dismatch_msg = ('{key1} and {key2} is the same data in {filename}. '
75+
id_mismatch_msg = ('{key1} and {key2} is the same data in {filename}. '
7676
'But we found they are different in {target}.')
7777

7878
DUPLICATE_KEY = 'duplicate_keys'
@@ -92,7 +92,7 @@ def _load(
9292
# get variables
9393
_check_target(target)
9494
variables = target.vars(method='absolute', level=-1)
95-
all_names = list(variables.keys())
95+
var_names_in_obj = list(variables.keys())
9696

9797
# read data from file
9898
for key in load_vars.keys():
@@ -105,22 +105,22 @@ def _load(
105105
else:
106106
value = load_vars[key]
107107
variables[key].value = bm.asarray(value)
108-
all_names.remove(key)
108+
var_names_in_obj.remove(key)
109109

110110
# check duplicate names
111111
duplicate_keys = duplicates[0]
112112
duplicate_targets = duplicates[1]
113113
for key1, key2 in zip(duplicate_keys, duplicate_targets):
114-
if key1 not in all_names:
114+
if key1 not in var_names_in_obj:
115115
raise KeyError(not_found_msg.format(key=key1, target=target.name, filename=filename))
116116
if id(variables[key1]) != id(variables[key2]):
117-
raise ValueError(id_dismatch_msg.format(key1=key1, key2=target, filename=filename, target=target.name))
118-
all_names.remove(key1)
117+
raise ValueError(id_mismatch_msg.format(key1=key1, key2=target, filename=filename, target=target.name))
118+
var_names_in_obj.remove(key1)
119119

120120
# check missing names
121-
if len(all_names):
121+
if len(var_names_in_obj):
122122
logger.warning(f'There are variable states missed in {filename}. '
123-
f'The missed variables are: {all_names}.')
123+
f'The missed variables are: {var_names_in_obj}.')
124124

125125

126126
def _unique_and_duplicate(collector: dict):

0 commit comments

Comments
 (0)