Skip to content

Commit 11083a0

Browse files
authored
fix a bug in math.activations.py (#196)
fix a bug in math.activations.py
2 parents a7e03c1 + 0bcf9da commit 11083a0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

brainpy/math/activations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020

2121
from brainpy.math.jaxarray import JaxArray
22+
from brainpy.math.numpy_ops import tanh
2223

2324
__all__ = [
2425
'celu',
@@ -63,7 +64,7 @@ def get(activation):
6364
return global_vars[activation]
6465

6566

66-
tanh = jnp.tanh
67+
# tanh = jnp.tanh
6768
identity = lambda x: x
6869

6970

@@ -146,7 +147,7 @@ def gelu(x, approximate=True):
146147
x = x.value if isinstance(x, JaxArray) else x
147148
if approximate:
148149
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
149-
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
150+
cdf = 0.5 * (1.0 + tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
150151
y = x * cdf
151152
else:
152153
y = jnp.array(x * (jax.lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype)

0 commit comments

Comments
 (0)