Skip to content

Commit d5ea988

Browse files
committed
[bug] fix bug
1 parent 457bfcd commit d5ea988

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

brainpy/_src/analysis/utils/others.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from typing import Union, Dict
44

5+
import jax
56
import jax.numpy as jnp
67
import numpy as np
7-
from jax import vmap
88
from jax.tree_util import tree_map
99

1010
import brainpy.math as bm
@@ -80,7 +80,7 @@ def get_sign(f, xs, ys):
8080

8181
def get_sign2(f, *xyz, args=()):
8282
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
83-
f = jax.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes))
83+
f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
8484
xyz = tuple((v.value if isinstance(v, bm.Array) else v) for v in xyz)
8585
XYZ = jnp.meshgrid(*xyz)
8686
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)

0 commit comments

Comments
 (0)