Skip to content

Commit 3122d03

Browse files
authored
feat: fix io for brainpy.Base (#211)
feat: fix `io` for brainpy.Base
2 parents c6cfe3f + 6630bd8 commit 3122d03

File tree

7 files changed

+562
-126
lines changed

7 files changed

+562
-126
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ publishment.md
33
.vscode
44

55

6+
brainpy/base/tests/io_test_tmp*
7+
68
development
79

810
examples/simulation/data
@@ -53,7 +55,6 @@ develop/benchmark/CUBA/annarchy*
5355
develop/benchmark/CUBA/brian2*
5456

5557

56-
5758
*~
5859
\#*\#
5960
*.pyc

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.1.11"
3+
__version__ = "2.1.12"
44

55

66
try:

brainpy/base/base.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -208,49 +208,50 @@ def unique_name(self, name=None, type_=None):
208208
naming.check_name_uniqueness(name=name, obj=self)
209209
return name
210210

211-
def load_states(self, filename, verbose=False, check_missing=False):
211+
def load_states(self, filename, verbose=False):
212212
"""Load the model states.
213213
214214
Parameters
215215
----------
216216
filename : str
217217
The filename which stores the model states.
218218
verbose: bool
219-
check_missing: bool
219+
Whether report the load progress.
220220
"""
221221
if not os.path.exists(filename):
222222
raise errors.BrainPyError(f'Cannot find the file path: {filename}')
223223
elif filename.endswith('.hdf5') or filename.endswith('.h5'):
224-
io.load_h5(filename, target=self, verbose=verbose, check=check_missing)
224+
io.load_by_h5(filename, target=self, verbose=verbose)
225225
elif filename.endswith('.pkl'):
226-
io.load_pkl(filename, target=self, verbose=verbose, check=check_missing)
226+
io.load_by_pkl(filename, target=self, verbose=verbose)
227227
elif filename.endswith('.npz'):
228-
io.load_npz(filename, target=self, verbose=verbose, check=check_missing)
228+
io.load_by_npz(filename, target=self, verbose=verbose)
229229
elif filename.endswith('.mat'):
230-
io.load_mat(filename, target=self, verbose=verbose, check=check_missing)
230+
io.load_by_mat(filename, target=self, verbose=verbose)
231231
else:
232232
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
233233

234-
def save_states(self, filename, all_vars=None, **setting):
234+
def save_states(self, filename, variables=None, **setting):
235235
"""Save the model states.
236236
237237
Parameters
238238
----------
239239
filename : str
240240
The file name which to store the model states.
241-
all_vars: optional, dict, TensorCollector
241+
variables: optional, dict, TensorCollector
242+
The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
242243
"""
243-
if all_vars is None:
244-
all_vars = self.vars(method='relative').unique()
244+
if variables is None:
245+
variables = self.vars(method='absolute', level=-1)
245246

246247
if filename.endswith('.hdf5') or filename.endswith('.h5'):
247-
io.save_h5(filename, all_vars=all_vars)
248-
elif filename.endswith('.pkl'):
249-
io.save_pkl(filename, all_vars=all_vars)
248+
io.save_as_h5(filename, variables=variables)
249+
elif filename.endswith('.pkl') or filename.endswith('.pickle'):
250+
io.save_as_pkl(filename, variables=variables)
250251
elif filename.endswith('.npz'):
251-
io.save_npz(filename, all_vars=all_vars, **setting)
252+
io.save_as_npz(filename, variables=variables, **setting)
252253
elif filename.endswith('.mat'):
253-
io.save_mat(filename, all_vars=all_vars)
254+
io.save_as_mat(filename, variables=variables)
254255
else:
255256
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
256257

brainpy/base/collector.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,35 @@ def update(self, other, **kwargs):
3939
self[key] = value
4040

4141
def __add__(self, other):
42+
"""Merging two dicts.
43+
44+
Parameters
45+
----------
46+
other: dict
47+
The other dict instance.
48+
49+
Returns
50+
-------
51+
gather: Collector
52+
The new collector.
53+
"""
4254
gather = type(self)(self)
4355
gather.update(other)
4456
return gather
4557

4658
def __sub__(self, other):
59+
"""Remove other item in the collector.
60+
61+
Parameters
62+
----------
63+
other: dict
64+
The items to remove.
65+
66+
Returns
67+
-------
68+
gather: Collector
69+
The new collector.
70+
"""
4771
if not isinstance(other, dict):
4872
raise ValueError(f'Only support dict, but we got {type(other)}.')
4973
gather = type(self)()

0 commit comments

Comments
 (0)