
Commit bf06081

fix bugs
1 parent 75f6fce commit bf06081

File tree: 15 files changed, +59 -1663 lines

brainpy/_src/dyn/layers/pooling.py (6 additions, 0 deletions)

@@ -82,6 +82,7 @@ def __init__(
 
   def update(self, *args):
     x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
     window_shape = self._infer_shape(x.ndim, self.kernel_size)
     stride = self._infer_shape(x.ndim, self.stride)
     padding = (self.padding
@@ -258,6 +259,7 @@ def __init__(
 
   def update(self, *args):
     x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
     window_shape = self._infer_shape(x.ndim, self.kernel_size)
     strides = self._infer_shape(x.ndim, self.stride)
     padding = (self.padding if isinstance(self.padding, str) else
@@ -356,6 +358,7 @@ def __init__(
 
   def update(self, *args):
     x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
     x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
     if x.ndim < x_dim:
       raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
@@ -521,6 +524,7 @@ def __init__(
 class _AvgPoolNd(_MaxPoolNd):
   def update(self, *args):
     x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
     x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
     if x.ndim < x_dim:
       raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
@@ -694,6 +698,7 @@ def _adaptive_pool1d(x, target_size: int, operation: Callable):
   Returns:
     A JAX array of shape `(target_size, )`.
   """
+  x = bm.as_jax(x)
   size = jnp.size(x)
   num_head_arrays = size % target_size
   num_block = size // target_size
@@ -767,6 +772,7 @@ def update(self, *args):
       or `(..., dim_1, dim_2)`.
     """
     x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
 
     # channel axis
     channel_axis = self.channel_axis
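Every hunk in this file is the same one-line fix: update() (and _adaptive_pool1d) now normalizes its input with bm.as_jax before any shape inference, so the JAX windowing primitives underneath receive a raw jax.Array rather than a brainpy.math.Array wrapper. A minimal sketch of the conversion; the 2x2 mean pool below is illustrative, not the file's code:

import brainpy.math as bm
import jax.numpy as jnp

x = bm.random.rand(4, 28, 28, 3)   # a brainpy.math.Array wrapper
raw = bm.as_jax(x)                 # plain jax.Array, same buffer
assert isinstance(raw, jnp.ndarray)

# Downstream window ops act on the unwrapped value, e.g. a crude 2x2 mean pool:
pooled = raw.reshape(4, 14, 2, 14, 2, 3).mean(axis=(2, 4))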

brainpy/_src/dyn/neurons/reduced_models.py (1 addition, 1 deletion)

@@ -580,7 +580,7 @@ def __init__(
       b: Union[float, ArrayType, Initializer, Callable] = 1.,
       tau: Union[float, ArrayType, Initializer, Callable] = 10.,
       tau_w: Union[float, ArrayType, Initializer, Callable] = 30.,
-      tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = 30.,
+      tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
       R: Union[float, ArrayType, Initializer, Callable] = 1.,
       V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
       w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
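tau_ref previously defaulted to 30. despite being Optional, silently imposing a 30 ms refractory period; None now means no refractory dynamics unless the user opts in. A hypothetical sketch of how such an optional constant is guarded (illustrative, not the file's code):

import brainpy.math as bm

def in_refractory(t, t_last_spike, tau_ref):
  # tau_ref=None disables the refractory window entirely; a float keeps
  # the neuron clamped for tau_ref ms after each spike.
  if tau_ref is None:
    return bm.zeros_like(t_last_spike, dtype=bool)
  return (t - t_last_spike) <= tau_ref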

brainpy/_src/initialize/random_inits.py (1 addition, 1 deletion)

@@ -329,7 +329,7 @@ def __call__(self, shape, dtype=None):
     n_cols = np.prod(shape) // n_rows
     matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
     norm_dst = self.rng.normal(size=matrix_shape)
-    q_mat, r_mat = jnp.linalg.qr(norm_dst)
+    q_mat, r_mat = jnp.linalg.qr(bm.as_jax(norm_dst))
     # Enforce Q is uniformly distributed
     q_mat *= jnp.sign(jnp.diag(r_mat))
     if n_rows < n_cols:
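The draw from the BrainPy RNG is now unwrapped before jnp.linalg.qr, which expects a raw JAX array. The surrounding code is the standard QR trick for orthogonal initialization; a standalone sketch in plain JAX (the function name and signature are ours):

import jax
import jax.numpy as jnp

def orthogonal(key, n_rows, n_cols):
  # Sample a Gaussian matrix and orthogonalize it with QR.
  a = jax.random.normal(key, (max(n_rows, n_cols), min(n_rows, n_cols)))
  q, r = jnp.linalg.qr(a)
  # Multiplying columns by sign(diag(r)) makes Q uniform (Haar) over the
  # orthogonal group; QR's sign convention would otherwise bias the draw.
  q = q * jnp.sign(jnp.diag(r))
  return q if n_rows >= n_cols else q.T

w = orthogonal(jax.random.PRNGKey(0), 128, 64)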

brainpy/_src/losses/comparison.py (32 additions, 18 deletions)

@@ -11,9 +11,9 @@
 from typing import Tuple
 
 import jax.numpy as jnp
+from jax.lax import scan
 from jax.scipy.special import logsumexp
 from jax.tree_util import tree_map
-from jax.lax import scan
 
 import brainpy.math as bm
 from brainpy.types import ArrayType
@@ -106,7 +106,7 @@ def _cel(_pred, _tar):
     loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1)
     return _reduce(outputs=loss, reduction=reduction)
 
-  r = tree_map(_cel, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array))
+  r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf)
   return _multi_return(r)
@@ -128,7 +128,7 @@ def crs(_prd, _tar):
     logits = jnp.take_along_axis(_prd, _tar, -1).squeeze(-1)
     return logsumexp(bm.as_jax(_prd), axis=-1) - logits
 
-  r = tree_map(crs, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array))
+  r = tree_map(crs, predicts, targets, is_leaf=_is_leaf)
   return _multi_return(r)
@@ -142,9 +142,14 @@ def cross_entropy_sigmoid(predicts, targets):
   Returns:
     (batch, ...) tensor of the cross-entropies for each entry.
   """
-  r = tree_map(lambda pred, tar: jnp.maximum(pred, 0) - pred * tar + jnp.log(1 + jnp.exp(-jnp.abs(pred))),
-               predicts,
-               targets)
+  r = tree_map(
+    lambda pred, tar: bm.as_jax(
+      bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred)))
+    ),
+    predicts,
+    targets,
+    is_leaf=_is_leaf
+  )
   return _multi_return(r)
@@ -201,7 +206,7 @@ def loss(pred, tar):
     norm = jnp.linalg.norm(bm.as_jax(diff), ord=1, axis=1, keepdims=False)
     return _reduce(outputs=norm, reduction=reduction)
 
-  r = tree_map(loss, logits, targets, is_leaf=lambda x: isinstance(x, bm.Array))
+  r = tree_map(loss, logits, targets, is_leaf=_is_leaf)
   return _multi_return(r)
@@ -228,7 +233,9 @@ def l2_loss(predicts, targets):
   ----------
   .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning.
   """
-  r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, predicts, targets)
+  r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2,
+               predicts,
+               targets)
   return _multi_return(r)
@@ -243,7 +250,10 @@ def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'):
   Returns:
     tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error.
   """
-  r = tree_map(lambda a, b: _reduce(jnp.abs(a - b), reduction=reduction, axis=axis), x, y)
+  r = tree_map(lambda a, b: _reduce(bm.abs(a - b), reduction=reduction, axis=axis),
+               x,
+               y,
+               is_leaf=_is_leaf)
   return _multi_return(r)
@@ -260,7 +270,8 @@ def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'):
   """
   r = tree_map(lambda a, b: _reduce((a - b) ** 2, reduction, axis=axis),
                predicts,
-               targets)
+               targets,
+               is_leaf=_is_leaf)
   return _multi_return(r)
@@ -276,7 +287,9 @@ def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean'
     tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error.
   """
   r = tree_map(lambda a, b: _reduce((jnp.log1p(a) - jnp.log1p(b)) ** 2, reduction, axis=axis),
-               predicts, targets, is_leaf=_is_leaf)
+               predicts,
+               targets,
+               is_leaf=_is_leaf)
   return _multi_return(r)
@@ -309,12 +322,13 @@ def huber_loss(predicts, targets, delta: float = 1.0):
   def _loss(y_predict, y_target):
     # 0.5 * err^2 if |err| <= d
     # 0.5 * d^2 + d * (|err| - d) if |err| > d
-    diff = jnp.abs(y_predict - y_target)
-    return jnp.where(diff > delta,
-                     delta * (diff - .5 * delta),
-                     0.5 * diff ** 2)
+    diff = bm.abs(y_predict - y_target)
+    r = bm.where(diff > delta,
+                 delta * (diff - .5 * delta),
+                 0.5 * diff ** 2)
+    return bm.as_jax(r)
 
-  r = tree_map(_loss, targets, predicts)
+  r = tree_map(_loss, targets, predicts, is_leaf=_is_leaf)
   return _multi_return(r)
@@ -382,7 +396,7 @@ def loss(pred, tar):
     log_not_p = bm.log_sigmoid(-pred)
     return -tar * log_p - (1. - tar) * log_not_p
 
-  r = tree_map(loss, logits, labels, is_leaf=lambda x: isinstance(x, bm.Array))
+  r = tree_map(loss, logits, labels, is_leaf=_is_leaf)
   return _multi_return(r)
@@ -433,7 +447,7 @@ def loss(pred, tar):
     errors = bm.as_jax(pred - tar)
     return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)
 
-  r = tree_map(loss, predicts, targets, is_leaf=lambda x: isinstance(x, bm.Array))
+  r = tree_map(loss, predicts, targets, is_leaf=_is_leaf)
   return _multi_return(r)
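The recurring change in this file swaps ad-hoc is_leaf=lambda x: isinstance(x, bm.Array) predicates for the shared _is_leaf helper and threads it through every tree_map call. The predicate matters because brainpy.math.Array is registered as a pytree node: without it, tree_map recurses into the wrapper instead of handing the whole array to the loss function. A minimal sketch of the pattern, assuming bm.Array wrapping as in this codebase:

from jax.tree_util import tree_map
import brainpy.math as bm

def _is_leaf(a):
  # Stop tree_map at BrainPy arrays so the loss sees them whole,
  # rather than recursing into their inner jax.Array.
  return isinstance(a, bm.Array)

preds = {'readout': bm.ones((4, 10))}
tars = {'readout': bm.zeros((4, 10))}
mse = tree_map(lambda p, t: ((p - t) ** 2).mean(), preds, tars, is_leaf=_is_leaf)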

brainpy/_src/math/__init__.py (2 additions, 2 deletions)

@@ -41,9 +41,9 @@
 # high-level numpy operations
 from .arraycreation import *
 from .arrayinterporate import *
-# from .arraycompatible import *
+from .arraycompatible import *
 from .others import *
-from . import random
+from . import random, linalg, fft
 
 # operators
 from .operators import *
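This re-enables the NumPy-compatible operations (arraycompatible) and exposes linalg and fft as submodules of brainpy.math. A small usage check, assuming these submodules mirror the corresponding jax.numpy namespaces:

import brainpy.math as bm

x = bm.random.rand(8)
n = bm.linalg.norm(bm.as_jax(x))  # assumed to mirror jnp.linalg.norm
s = bm.fft.fft(bm.as_jax(x))      # assumed to mirror jnp.fft.fft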

brainpy/_src/math/_utils.py (3 additions, 8 deletions)

@@ -4,7 +4,6 @@
 from typing import Callable
 
 import jax
-import numpy as np
 from jax.tree_util import tree_map
 
 from .ndarray import Array
@@ -18,8 +17,8 @@ def _as_jax_array_(obj):
   return obj.value if isinstance(obj, Array) else obj
 
 
-def _return(x):
-  return Array(x) if _return_bp_array else x
+def _return(a):
+  return Array(a) if isinstance(a, jax.Array) and a.ndim > 1 else a
 
 
 _return_bp_array = True
@@ -50,10 +49,6 @@ def wrap(op):
   return wrap
 
 
-def _as_brainpy_array(a):
-  return Array(a) if isinstance(a, (np.ndarray, jax.Array)) else a
-
-
 def _is_leaf(a):
   return isinstance(a, Array)
@@ -65,7 +60,7 @@ def new_fun(*args, **kwargs):
     if len(kwargs):
       kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
     r = fun(*args, **kwargs)
-    return tree_map(_as_brainpy_array, r) if _return_bp_array else r
+    return tree_map(_return, r) if _return_bp_array else r
 
   new_fun.__doc__ = getattr(fun, "__doc__", None)
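Net effect of this file: wrapped numpy-style functions unwrap their arguments, run on raw JAX values, and (per the new _return) re-wrap only multi-dimensional results, leaving scalars and 1-D values plain. A self-contained sketch of the unwrap-call-rewrap decorator; Box and wraps_numpy are our names, not the file's:

import functools
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

class Box:  # stand-in for brainpy.math.Array
  def __init__(self, value):
    self.value = value

def _unwrap(a):
  return a.value if isinstance(a, Box) else a

def _rebox(a):
  # Mirrors the new _return: re-wrap only multi-dimensional jax arrays.
  return Box(a) if isinstance(a, jax.Array) and a.ndim > 1 else a

def wraps_numpy(fun):
  @functools.wraps(fun)
  def new_fun(*args, **kwargs):
    args = tree_map(_unwrap, args, is_leaf=lambda a: isinstance(a, Box))
    return tree_map(_rebox, fun(*args, **kwargs))
  return new_fun

matmul = wraps_numpy(jnp.matmul)
out = matmul(Box(jnp.ones((2, 3))), jnp.ones((3, 4)))  # Box of shape (2, 4)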

brainpy/_src/math/object_transform/autograd.py (11 additions, 9 deletions)

@@ -18,9 +18,9 @@
 from jax.util import safe_map
 
 from brainpy import errors, tools, check
-from brainpy._src.math.object_transform.base import BrainPyObject
-from brainpy._src.math.object_transform.abstract import ObjectTransform
 from brainpy._src.math.ndarray import Array, Variable, add_context, del_context
+from brainpy._src.math.object_transform.abstract import ObjectTransform
+from brainpy._src.math.object_transform.base import BrainPyObject
 
 __all__ = [
   'grad',  # gradient of scalar function
@@ -75,7 +75,7 @@ def __init__(
       _argnums = tuple(a + 2 for a in _argnums)
     if len(self._grad_vars) > 0:
       _argnums = (0,) + _argnums
-    self.nonvar_argnums = argnums
+    self._nonvar_argnums = argnums
     self.return_value = return_value
     self.has_aux = has_aux
@@ -134,10 +134,12 @@ def __call__(self, *args, **kwargs):
     # old_dyn_vs = [v.value for v in self._dyn_vars]
     try:
       add_context(self.name)
-      grads, (outputs, new_grad_vs, new_dyn_vs) = self._call([v.value for v in self._grad_vars],
-                                                             [v.value for v in self._dyn_vars],
-                                                             *args,
-                                                             **kwargs)
+      grads, (outputs, new_grad_vs, new_dyn_vs) = self._call(
+        [v.value for v in self._grad_vars],
+        [v.value for v in self._dyn_vars],
+        *args,
+        **kwargs
+      )
       del_context(self._name)
     except UnexpectedTracerError as e:
       del_context(self._name)
@@ -155,11 +157,11 @@ def __call__(self, *args, **kwargs):
 
     # check returned grads
     if len(self._grad_vars) > 0:
-      if self.nonvar_argnums is None:
+      if self._nonvar_argnums is None:
        grads = self._grad_tree.unflatten(grads)
      else:
        var_grads = self._grad_tree.unflatten(grads[0])
-        arg_grads = grads[1] if isinstance(self.nonvar_argnums, int) else grads[1:]
+        arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
        grads = (var_grads, arg_grads)
 
     # check returned value
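nonvar_argnums becomes a private attribute; it is internal bookkeeping for splitting returned gradients between tracked variables and positional arguments, as the unflatten logic above shows. This class backs the public bm.grad transform; a usage sketch, assuming the grad_vars/argnums interface reflected in the code:

import brainpy.math as bm

w = bm.Variable(bm.ones(3))

def loss(x):
  return bm.sum(w * x)

# With both grad_vars and argnums given, the call returns a
# (gradients-for-w, gradients-for-x) pair.
f_grad = bm.grad(loss, grad_vars=w, argnums=0)
var_grads, arg_grads = f_grad(bm.ones(3).value)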

brainpy/_src/train/offline.py (0 additions, 4 deletions)

@@ -84,10 +84,6 @@ def __init__(
     for node in self.train_nodes:
       node.offline_fit_by = fit_method
 
-    # initialize the fitting method
-    for node in self.train_nodes:
-      node.offline_init()
-
   def __repr__(self):
     name = self.__class__.__name__
     prefix = ' ' * len(name)
