File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 1919import numpy as np
2020
2121from brainpy .math .jaxarray import JaxArray
22+ from brainpy .math .numpy_ops import tanh
2223
2324__all__ = [
2425 'celu' ,
@@ -63,7 +64,7 @@ def get(activation):
6364 return global_vars [activation ]
6465
6566
66- tanh = jnp .tanh
67+ # tanh = jnp.tanh
6768identity = lambda x : x
6869
6970
@@ -146,7 +147,7 @@ def gelu(x, approximate=True):
146147 x = x .value if isinstance (x , JaxArray ) else x
147148 if approximate :
148149 sqrt_2_over_pi = np .sqrt (2 / np .pi ).astype (x .dtype )
149- cdf = 0.5 * (1.0 + jnp . tanh (sqrt_2_over_pi * (x + 0.044715 * (x ** 3 ))))
150+ cdf = 0.5 * (1.0 + tanh (sqrt_2_over_pi * (x + 0.044715 * (x ** 3 ))))
150151 y = x * cdf
151152 else :
152153 y = jnp .array (x * (jax .lax .erf (x / np .sqrt (2 )) + 1 ) / 2 , dtype = x .dtype )
You can’t perform that action at this time.
0 commit comments