Skip to content

Commit 75d92fb

Browse files
committed
fig bug
1 parent b5215cf commit 75d92fb

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

brainpy/dyn/rates/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from .base import RateModel
77

88
__all__ = [
9-
''
109
]
1110

1211
class JansenRitModel(RateModel):

brainpy/math/parallels.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from brainpy.base.base import Base
2828
from brainpy.base.collector import TensorCollector
2929
from brainpy.math.random import RandomState
30+
from brainpy.math.jaxarray import JaxArray
3031
from brainpy.tools.codes import change_func_name
3132

3233
__all__ = [
@@ -77,7 +78,7 @@ def vmap(func, dyn_vars=None, batched_vars=None,
7778
----------
7879
func : Base, function, callable
7980
The function or the module to compile.
80-
dyn_vars : dict
81+
dyn_vars : dict, sequence
8182
batched_vars : dict
8283
in_axes : optional, int, sequence of int
8384
Specify which input array axes to map over. If each positional argument to
@@ -207,13 +208,19 @@ def vmap(func, dyn_vars=None, batched_vars=None,
207208
axis_name=axis_name)
208209

209210
else:
211+
if isinstance(dyn_vars, JaxArray):
212+
dyn_vars = [dyn_vars]
213+
if isinstance(dyn_vars, (tuple, list)):
214+
dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)}
215+
assert isinstance(dyn_vars, dict)
216+
210217
# dynamical variables
211-
dyn_vars, rand_vars = TensorCollector(), TensorCollector()
218+
_dyn_vars, _rand_vars = TensorCollector(), TensorCollector()
212219
for key, val in dyn_vars.items():
213220
if isinstance(val, RandomState):
214-
rand_vars[key] = val
221+
_rand_vars[key] = val
215222
else:
216-
dyn_vars[key] = val
223+
_dyn_vars[key] = val
217224

218225
# in axes
219226
if in_axes is None:
@@ -249,8 +256,8 @@ def vmap(func, dyn_vars=None, batched_vars=None,
249256

250257
# jit function
251258
return _make_vmap(func=func,
252-
dyn_vars=dyn_vars,
253-
rand_vars=rand_vars,
259+
dyn_vars=_dyn_vars,
260+
rand_vars=_rand_vars,
254261
in_axes=in_axes,
255262
out_axes=out_axes,
256263
axis_name=axis_name,

0 commit comments

Comments
 (0)