File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed
Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -112,6 +112,9 @@ def create_function(fun_name, args=None):
112112 elif fun_name == "gelu" :
113113 fx = gelu
114114 dfx = d_gelu
115+ elif fun_name == "telu" :
116+ fx = telu
117+ dfx = d_telu
115118 elif fun_name == "softplus" :
116119 fx = softplus
117120 dfx = d_softplus
@@ -294,6 +297,35 @@ def d_relu(x):
294297 return (x >= 0. ).astype (jnp .float32 )
295298
296299@jit
300+ def telu (x ):
301+ """
302+ Proposed by Fernandez and Mali 24, https://arxiv.org/abs/2412.20269 and https://arxiv.org/abs/2402.02790
303+ TeLU activation: f(x) = x * tanh(e^x)
304+
305+ Args:
306+ x: input (tensor) value
307+
308+ Returns:
309+ output (tensor) value
310+ """
311+ return x * jnp .tanh (jnp .exp (x ))
312+
313+ @jit
314+ def d_telu (x ):
315+ """
316+
317+ Derivative of TeLU: f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
318+
319+ Args:
320+ x: input (tensor) value
321+
322+ Returns:
323+ output (tensor) derivative value (with respect to input)
324+ """
325+ ex = jnp .exp (x )
326+ tanh_ex = jnp .tanh (ex )
327+ return tanh_ex + x * ex * (1.0 - tanh_ex ** 2 )
328+ @jit
297329def sine (x , omega_0 = 30 ):
298330 """
299331 f(x) = sin(x * omega_0).
You can’t perform that action at this time.
0 commit comments