Commit e6a2dbe

Merge pull request #323 from chaoming0625/master
Fix jax import error when `jax>=0.4.2`
2 parents c63447b + 9b13321 commit e6a2dbe


46 files changed (+534, -4348 lines)
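
The changes follow one compatibility pattern throughout: type hints and `isinstance` checks move from `jnp.DeviceArray`, which the commit title indicates is no longer available when `jax>=0.4.2`, to the unified `jax.Array` type, and direct `jax.numpy` calls are routed through `brainpy.math` (`bm`), which mirrors the same API. A minimal sketch of that pattern (the helper name `is_array_like` is hypothetical, for illustration only):

import jax
import brainpy.math as bm

def is_array_like(x):
  # jnp.DeviceArray is gone in recent jax; jax.Array is the unified array
  # type, so it is what isinstance checks and annotations should use.
  return isinstance(x, (bm.Array, jax.Array))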

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 1 addition & 1 deletion
@@ -462,7 +462,7 @@ def filter_loss(self, tolerance: float = 1e-5):
     else:
       num_fps = self._fixed_points.shape[0]
     ids = self._losses < tolerance
-    keep_ids = bm.as_jax(jnp.where(ids)[0])
+    keep_ids = bm.as_jax(bm.where(ids)[0])
     self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
     self._losses = self._losses[keep_ids]
     self._selected_ids = self._selected_ids[keep_ids]
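
The swap works because `brainpy.math` mirrors the `jax.numpy` API while accepting BrainPy `Array` inputs, and `bm.as_jax` converts the result back to a plain JAX array where downstream code expects one. A small illustrative sketch (the values are made up):

import brainpy.math as bm

losses = bm.asarray([1e-7, 0.3, 2e-6])
ids = losses < 1e-5                      # boolean BrainPy Array
keep_ids = bm.as_jax(bm.where(ids)[0])   # kept indices as a plain JAX array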

brainpy/_src/dyn/base.py

Lines changed: 20 additions & 21 deletions
@@ -4,6 +4,7 @@
 import gc
 from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
 
+import jax
 import jax.numpy as jnp
 import numpy as np
 
@@ -18,8 +19,6 @@
 from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 
-
-
 __all__ = [
   # general class
   'DynamicalSystem',
@@ -170,14 +169,14 @@ def register_delay(
       raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
                        f'integer, array of integers, callable function, brainpy.init.Initializer.')
     if delay_type == 'heter':
-      if delay_step.dtype not in [jnp.int32, jnp.int64]:
+      if delay_step.dtype not in [bm.int32, bm.int64]:
         raise ValueError('Only support delay steps of int32, int64. If your '
                          'provide delay time length, please divide the "dt" '
                          'then provide us the number of delay steps.')
       if delay_target.shape[0] != delay_step.shape[0]:
         raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
     if delay_type != 'none':
-      max_delay_step = int(jnp.max(delay_step))
+      max_delay_step = int(bm.max(delay_step))
 
     # delay target
     if delay_type != 'none':
@@ -207,8 +206,8 @@ def register_delay(
   def get_delay_data(
       self,
       identifier: str,
-      delay_step: Optional[Union[int, bm.Array, jnp.DeviceArray]],
-      *indices: Union[int, slice, bm.Array, jnp.DeviceArray],
+      delay_step: Optional[Union[int, bm.Array, jax.Array]],
+      *indices: Union[int, slice, bm.Array, jax.Array],
   ):
     """Get delay data according to the provided delay steps.
 
@@ -230,19 +229,19 @@ def get_delay_data(
       return self.global_delay_data[identifier][1].value
 
     if identifier in self.global_delay_data:
-      if jnp.ndim(delay_step) == 0:
+      if bm.ndim(delay_step) == 0:
        return self.global_delay_data[identifier][0](delay_step, *indices)
       else:
        if len(indices) == 0:
-          indices = (jnp.arange(delay_step.size),)
+          indices = (bm.arange(delay_step.size),)
        return self.global_delay_data[identifier][0](delay_step, *indices)
 
     elif identifier in self.local_delay_vars:
-      if jnp.ndim(delay_step) == 0:
+      if bm.ndim(delay_step) == 0:
        return self.local_delay_vars[identifier](delay_step)
       else:
        if len(indices) == 0:
-          indices = (jnp.arange(delay_step.size),)
+          indices = (bm.arange(delay_step.size),)
        return self.local_delay_vars[identifier](delay_step, *indices)
 
     else:
@@ -878,7 +877,7 @@ def __init__(
     # ------------
     if isinstance(conn, TwoEndConnector):
       self.conn = conn(pre.size, post.size)
-    elif isinstance(conn, (bm.ndarray, np.ndarray, jnp.ndarray)):
+    elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)):
       if (pre.num, post.num) != conn.shape:
        raise ValueError(f'"conn" is provided as a matrix, and it is expected '
                         f'to be an array with shape of (pre.num, post.num) = '
@@ -1157,11 +1156,11 @@ def _init_weights(
     return weight, conn_mask
 
   def _syn2post_with_all2all(self, syn_value, syn_weight):
-    if jnp.ndim(syn_weight) == 0:
+    if bm.ndim(syn_weight) == 0:
       if isinstance(self.mode, bm.BatchingMode):
-        post_vs = jnp.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
+        post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
       else:
-        post_vs = jnp.sum(syn_value)
+        post_vs = bm.sum(syn_value)
       if not self.conn.include_self:
        post_vs = post_vs - syn_value
       post_vs = syn_weight * post_vs
@@ -1173,7 +1172,7 @@ def _syn2post_with_one2one(self, syn_value, syn_weight):
     return syn_value * syn_weight
 
   def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
-    if jnp.ndim(syn_weight) == 0:
+    if bm.ndim(syn_weight) == 0:
       post_vs = (syn_weight * syn_value) @ conn_mat
     else:
       post_vs = syn_value @ (syn_weight * conn_mat)
@@ -1253,8 +1252,8 @@ def __init__(
 
     # variables
     self.V = variable(V_initializer, self.mode, self.varshape)
-    self.input = variable(jnp.zeros, self.mode, self.varshape)
-    self.spike = variable(lambda s: jnp.zeros(s, dtype=bool), self.mode, self.varshape)
+    self.input = variable(bm.zeros, self.mode, self.varshape)
+    self.spike = variable(lambda s: bm.zeros(s, dtype=bool), self.mode, self.varshape)
 
     # function
     if self.noise is None:
@@ -1271,8 +1270,8 @@ def derivative(self, V, t):
 
   def reset_state(self, batch_size=None):
     self.V.value = variable(self._V_initializer, batch_size, self.varshape)
-    self.spike.value = variable(lambda s: jnp.zeros(s, dtype=bool), batch_size, self.varshape)
-    self.input.value = variable(jnp.zeros, batch_size, self.varshape)
+    self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+    self.input.value = variable(bm.zeros, batch_size, self.varshape)
     for channel in self.nodes(level=1, include_self=False).subset(Channel).unique().values():
       channel.reset_state(self.V.value, batch_size=batch_size)
 
@@ -1286,7 +1285,7 @@ def update(self, tdi, *args, **kwargs):
     # update variables
     for node in channels.values():
       node.update(tdi, self.V.value)
-    self.spike.value = jnp.logical_and(V >= self.V_th, self.V < self.V_th)
+    self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
     self.V.value = V
 
   def register_implicit_nodes(self, *channels, **named_channels):
@@ -1295,7 +1294,7 @@ def register_implicit_nodes(self, *channels, **named_channels):
 
   def clear_input(self):
     """Useful for monitoring inputs. """
-    self.input.value = jnp.zeros_like(self.input.value)
+    self.input.value = bm.zeros_like(self.input.value)
 
 
 class Channel(DynamicalSystem):
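
The annotation changes above follow the same substitution: wherever `get_delay_data` previously accepted `jnp.DeviceArray`, it now accepts `jax.Array`. A hedged sketch of how such a signature reads after the change (the standalone helper below is hypothetical and only mirrors the annotation style and the scalar-vs-array branching):

from typing import Optional, Union

import jax
import brainpy.math as bm

def describe_delay_step(delay_step: Optional[Union[int, bm.Array, jax.Array]]) -> str:
  # A scalar (or None) delay step is uniform across neurons; an
  # array-valued step is heterogeneous, as in get_delay_data above.
  if delay_step is None or bm.ndim(delay_step) == 0:
    return 'uniform delay'
  return 'heterogeneous delay'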

brainpy/_src/dyn/channels/IH.py

Lines changed: 10 additions & 11 deletions
@@ -7,7 +7,6 @@
 
 from typing import Union, Callable
 
-import jax.numpy as jnp
 import brainpy.math as bm
 from brainpy._src.initialize import Initializer, parameter, variable
 from brainpy._src.integrators import odeint, JointEq
@@ -76,7 +75,7 @@ def __init__(
     self.E = parameter(E, self.varshape, allow_none=False)
 
     # variable
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(self.derivative, method=method)
@@ -96,10 +95,10 @@ def current(self, V):
     return self.g_max * self.p * (self.E - V)
 
   def f_p_inf(self, V):
-    return 1. / (1. + jnp.exp((V + 75.) / 5.5))
+    return 1. / (1. + bm.exp((V + 75.) / 5.5))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp(-0.086 * V - 14.59) + jnp.exp(0.0701 * V - 1.87))
+    return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))
 
 
 class Ih_De1996(IhChannel, CalciumChannel):
@@ -200,9 +199,9 @@ def __init__(
     self.g_inc = parameter(g_inc, self.varshape, allow_none=False)
 
     # variable
-    self.O = variable(jnp.zeros, self.mode, self.varshape)
-    self.OL = variable(jnp.zeros, self.mode, self.varshape)
-    self.P1 = variable(jnp.zeros, self.mode, self.varshape)
+    self.O = variable(bm.zeros, self.mode, self.varshape)
+    self.OL = variable(bm.zeros, self.mode, self.varshape)
+    self.P1 = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method)
@@ -229,7 +228,7 @@ def current(self, V, C_Ca, E_Ca):
 
   def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
     varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape)
-    self.P1.value = jnp.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
+    self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
     inf = self.f_inf(V)
     tau = self.f_tau(V)
     alpha = inf / tau
@@ -242,8 +241,8 @@ def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
       assert self.OL.shape[0] == batch_size
 
   def f_inf(self, V):
-    return 1 / (1 + jnp.exp((V + 75 - self.V_sh) / 5.5))
+    return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5))
 
   def f_tau(self, V):
-    return (20. + 1000 / (jnp.exp((V + 71.5 - self.V_sh) / 14.2) +
-                          jnp.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
+    return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) +
+                          bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
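
In the channel code the change is purely a namespace swap: `bm.exp` behaves like `jnp.exp` but also accepts BrainPy arrays directly, so the gating functions keep the same mathematics. An illustrative evaluation of the Ih gating expressions above (the membrane potentials are example values only):

import brainpy.math as bm

V = bm.asarray([-80., -65., -50.])   # example membrane potentials (mV)
p_inf = 1. / (1. + bm.exp((V + 75.) / 5.5))                             # steady-state activation
p_tau = 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))   # time constant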
