Skip to content

Commit 3f7472a

Browse files
committed
fix bugs
1 parent f41654c commit 3f7472a

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

brainpy/math/object_transform/base_object.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3+
import os
34
import logging
45
import warnings
56
from collections import namedtuple
@@ -426,7 +427,19 @@ def load_states(self, filename, verbose=False):
426427
verbose: bool
427428
Whether report the load progress.
428429
"""
429-
raise errors.NoLongerSupportError('Use brainpy.checkpoints.load().')
430+
from brainpy.checkpoints import io
431+
if not os.path.exists(filename):
432+
raise errors.BrainPyError(f'Cannot find the file path: {filename}')
433+
elif filename.endswith('.hdf5') or filename.endswith('.h5'):
434+
io.load_by_h5(filename, target=self, verbose=verbose)
435+
elif filename.endswith('.pkl'):
436+
io.load_by_pkl(filename, target=self, verbose=verbose)
437+
elif filename.endswith('.npz'):
438+
io.load_by_npz(filename, target=self, verbose=verbose)
439+
elif filename.endswith('.mat'):
440+
io.load_by_mat(filename, target=self, verbose=verbose)
441+
else:
442+
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
430443

431444
def save_states(self, filename, variables=None, **setting):
432445
"""Save the model states.
@@ -438,7 +451,20 @@ def save_states(self, filename, variables=None, **setting):
438451
variables: optional, dict, ArrayCollector
439452
The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
440453
"""
441-
raise errors.NoLongerSupportError('Use brainpy.checkpoints.save().')
454+
if variables is None:
455+
variables = self.vars(method='absolute', level=-1)
456+
457+
from brainpy.checkpoints import io
458+
if filename.endswith('.hdf5') or filename.endswith('.h5'):
459+
io.save_as_h5(filename, variables=variables)
460+
elif filename.endswith('.pkl') or filename.endswith('.pickle'):
461+
io.save_as_pkl(filename, variables=variables)
462+
elif filename.endswith('.npz'):
463+
io.save_as_npz(filename, variables=variables, **setting)
464+
elif filename.endswith('.mat'):
465+
io.save_as_mat(filename, variables=variables)
466+
else:
467+
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
442468

443469
# def to(self, devices):
444470
# global math

brainpy/math/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,9 @@ def __init__(self,
429429
'seed will be removed since 2.4.0', UserWarning)
430430

431431
if seed_or_key is None:
432-
key = DEFAULT.split_key()
433-
elif isinstance(seed_or_key, int):
432+
# key = DEFAULT.split_key()
433+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
434+
if isinstance(seed_or_key, int):
434435
key = jr.PRNGKey(seed_or_key)
435436
else:
436437
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:

0 commit comments

Comments
 (0)