Skip to content

Commit fc487d1

Browse files
authored
Refine JIT wrappers for new JAX for comaptiblity with jax>=0.8.2 (#809)
* refactor: clean up imports and remove unnecessary newlines * refactor: remove unused imports and clean up main execution block * feat: add brainpy_state module and update dependencies
1 parent f690fb4 commit fc487d1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+224
-241
lines changed

brainpy/__init__.py

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717
__version__ = "2.7.5"
1818
__version_info__ = tuple(map(int, __version__.split(".")))
1919

20-
2120
from brainpy import _errors as errors
22-
from brainpy import mixin
2321
# fundamental supporting modules
2422
from brainpy import check, tools
2523
# Part: Math Foundation #
2624
# ----------------------- #
2725
# math foundation
2826
from brainpy import math
27+
from brainpy import mixin
2928
# Part: Toolbox #
3029
# --------------- #
3130
# modules of toolbox
@@ -109,11 +108,11 @@
109108
# ---------------- #
110109
from brainpy.train.base import (DSTrainer as DSTrainer, )
111110
from brainpy.train.back_propagation import (BPTT as BPTT,
112-
BPFF as BPFF, )
111+
BPFF as BPFF, )
113112
from brainpy.train.online import (OnlineTrainer as OnlineTrainer,
114-
ForceTrainer as ForceTrainer, )
113+
ForceTrainer as ForceTrainer, )
115114
from brainpy.train.offline import (OfflineTrainer as OfflineTrainer,
116-
RidgeTrainer as RidgeTrainer, )
115+
RidgeTrainer as RidgeTrainer, )
117116

118117
# Part: Analysis #
119118
# ---------------- #
@@ -147,42 +146,7 @@
147146

148147
optimizers = optim
149148

150-
try:
151-
import brainpy.state as state
152-
except:
153-
pass
154-
155149

150+
# New package
151+
from brainpy import state
156152

157-
if __name__ == '__main__':
158-
connect
159-
initialize, # weight initialization
160-
optim, # gradient descent optimizers
161-
losses, # loss functions
162-
measure, # methods for data analysis
163-
inputs, # methods for generating input currents
164-
encoding, # encoding schema
165-
checkpoints, # checkpoints
166-
check, # error checking
167-
mixin, # mixin classes
168-
algorithms, # online or offline training algorithms
169-
check, tools, errors, math
170-
BrainPyObject,
171-
integrators, ode, sde, fde
172-
Integrator, JointEq, IntegratorRunner, odeint, sdeint, fdeint
173-
DynamicalSystem, DynSysGroup, Sequential, Dynamic, Projection
174-
receive_update_input, receive_update_output, not_receive_update_input, not_receive_update_output
175-
VarDelay
176-
dnn, layers, dyn
177-
NeuGroup, NeuGroupNS
178-
share
179-
reset_level, reset_state, save_state, load_state, clear_input
180-
DSRunner, LoopOverTime, running
181-
DSTrainer, BPTT, BPFF, OnlineTrainer, ForceTrainer,
182-
OfflineTrainer, RidgeTrainer
183-
analysis
184-
visualize
185-
train
186-
channels, neurons, synapses, rates, synouts, synplast
187-
Base
188-
ArrayCollector, Collector, errors

brainpy/analysis/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,3 @@
3434
from .highdim.slow_points import *
3535
from .lowdim.lowdim_bifurcation import *
3636
from .lowdim.lowdim_phase_plane import *
37-

brainpy/analysis/highdim/slow_points.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from jax.scipy.optimize import minimize
2626

2727
import brainpy.math as bm
28-
from brainpy._errors import AnalyzerError, UnsupportedError
2928
from brainpy import optim, losses
29+
from brainpy._errors import AnalyzerError, UnsupportedError
3030
from brainpy.analysis import utils, base, constants
3131
from brainpy.context import share
3232
from brainpy.deprecations import _input_deprecate_msg

brainpy/channels.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dyn`` module instead.
1818
"""
1919

20-
2120
from .dyn.channels import *
2221
from .dyn.ions import *
2322

24-
2523
if __name__ == '__main__':
2624
IL
2725
Potassium

brainpy/check.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,4 +647,3 @@ def true_err_fun(arg, transforms):
647647
cond(remove_vmap(as_jax(pred)),
648648
lambda: jax.pure_callback(true_err_fun, None),
649649
lambda: None)
650-

brainpy/connect/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import jax.numpy as jnp
2121
import numpy as onp
2222

23-
from brainpy._errors import ConnectorError
2423
from brainpy import tools, math as bm
24+
from brainpy._errors import ConnectorError
2525

2626
__all__ = [
2727
# the connection types

brainpy/connect/custom_conn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import jax.numpy as jnp
1818
import numpy as np
1919

20-
from brainpy._errors import ConnectorError
2120
from brainpy import math as bm, tools
21+
from brainpy._errors import ConnectorError
2222
from .base import *
2323

2424
__all__ = [

brainpy/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Any, Union
2222

2323
import brainstate
24+
2425
from brainpy.math.defaults import env
2526
from brainpy.tools.dicts import DotDict
2627

brainpy/delay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import jax.numpy as jnp
2525
import numpy as np
2626

27-
from brainpy.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay
2827
from brainpy import check, math as bm
2928
from brainpy.check import jit_error
3029
from brainpy.context import share
3130
from brainpy.dynsys import DynamicalSystem
3231
from brainpy.initialize import variable_
3332
from brainpy.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
33+
from brainpy.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay
3434

3535
__all__ = [
3636
'Delay',

brainpy/dnn/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,3 @@
2222
from .linear import *
2323
from .normalization import *
2424
from .pooling import *
25-

0 commit comments

Comments
 (0)