Skip to content

Commit 2336904

Browse files
committed
Add parameter management module with bijective transforms and state containers
1 parent cf67678 commit 2336904

18 files changed

+2577
-648
lines changed

braintools/init/_init_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def __init__(
444444
self.unit = unit
445445

446446
def __call__(self, size, **kwargs):
447-
rng = kwargs.get('rng', brainstate.random)
447+
rng = kwargs.get('rng', np.random)
448448
mean, unit = u.split_mantissa_unit(self.mean)
449449
std = u.Quantity(self.std).to(unit).mantissa
450450

braintools/optim/_optax_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3553,7 +3553,7 @@ class Rprop(OptaxOptimizer):
35533553
>>> # Setup for batch learning
35543554
>>> model = brainstate.nn.Sequential(
35553555
... brainstate.nn.Linear(100, 50),
3556-
... brainstate.nn.Tanh(),
3556+
... brainstate.nn.TanhT(),
35573557
... brainstate.nn.Linear(50, 10)
35583558
... )
35593559
>>>

braintools/param/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,21 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from ._base import *
17-
from ._base import __all__ as base_all
16+
from ._data import *
17+
from ._data import __all__ as data_all
18+
from ._module import *
19+
from ._module import __all__ as module_all
20+
from ._regularization import *
21+
from ._regularization import __all__ as reg_all
1822
from ._state import *
1923
from ._state import __all__ as state_all
2024
from ._transform import *
2125
from ._transform import __all__ as transform_all
2226

23-
__all__ = state_all + transform_all + base_all
27+
__all__ = state_all + transform_all + reg_all + module_all + data_all
2428

25-
del transform_all, base_all, state_all
29+
del transform_all
30+
del state_all
31+
del reg_all
32+
del module_all
33+
del data_all

braintools/param/_base.py

Lines changed: 0 additions & 201 deletions
This file was deleted.

0 commit comments

Comments
 (0)