|
27 | 27 | from brainpy.base.base import Base |
28 | 28 | from brainpy.base.collector import TensorCollector |
29 | 29 | from brainpy.math.random import RandomState |
| 30 | +from brainpy.math.jaxarray import JaxArray |
30 | 31 | from brainpy.tools.codes import change_func_name |
31 | 32 |
|
32 | 33 | __all__ = [ |
@@ -77,7 +78,7 @@ def vmap(func, dyn_vars=None, batched_vars=None, |
77 | 78 | ---------- |
78 | 79 | func : Base, function, callable |
79 | 80 | The function or the module to compile. |
80 | | - dyn_vars : dict |
| 81 | + dyn_vars : dict, sequence |
81 | 82 | batched_vars : dict |
82 | 83 | in_axes : optional, int, sequence of int |
83 | 84 | Specify which input array axes to map over. If each positional argument to |
@@ -207,13 +208,19 @@ def vmap(func, dyn_vars=None, batched_vars=None, |
207 | 208 | axis_name=axis_name) |
208 | 209 |
|
209 | 210 | else: |
| 211 | + if isinstance(dyn_vars, JaxArray): |
| 212 | + dyn_vars = [dyn_vars] |
| 213 | + if isinstance(dyn_vars, (tuple, list)): |
| 214 | + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} |
| 215 | + assert isinstance(dyn_vars, dict) |
| 216 | + |
210 | 217 | # dynamical variables |
211 | | - dyn_vars, rand_vars = TensorCollector(), TensorCollector() |
| 218 | + _dyn_vars, _rand_vars = TensorCollector(), TensorCollector() |
212 | 219 | for key, val in dyn_vars.items(): |
213 | 220 | if isinstance(val, RandomState): |
214 | | - rand_vars[key] = val |
| 221 | + _rand_vars[key] = val |
215 | 222 | else: |
216 | | - dyn_vars[key] = val |
| 223 | + _dyn_vars[key] = val |
217 | 224 |
|
218 | 225 | # in axes |
219 | 226 | if in_axes is None: |
@@ -249,8 +256,8 @@ def vmap(func, dyn_vars=None, batched_vars=None, |
249 | 256 |
|
250 | 257 | # jit function |
251 | 258 | return _make_vmap(func=func, |
252 | | - dyn_vars=dyn_vars, |
253 | | - rand_vars=rand_vars, |
| 259 | + dyn_vars=_dyn_vars, |
| 260 | + rand_vars=_rand_vars, |
254 | 261 | in_axes=in_axes, |
255 | 262 | out_axes=out_axes, |
256 | 263 | axis_name=axis_name, |
|
0 commit comments