
Commit 1a8d26c

Updated model_utils.py with TeLU (#114)
Added TeLU activation, proposed by our group
1 parent 2d8c6e4 commit 1a8d26c

File tree

1 file changed: +32, -0 lines changed


ngclearn/utils/model_utils.py

Lines changed: 32 additions & 0 deletions
@@ -112,6 +112,9 @@ def create_function(fun_name, args=None):
     elif fun_name == "gelu":
         fx = gelu
         dfx = d_gelu
+    elif fun_name == "telu":
+        fx = telu
+        dfx = d_telu
     elif fun_name == "softplus":
         fx = softplus
         dfx = d_softplus
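
For context (this snippet is not part of the commit): with the branch above in place, the new activation should be retrievable by name through create_function. A minimal sketch, assuming create_function returns the activation/derivative pair (fx, dfx), which the diff does not show:

# Hypothetical usage; assumes create_function returns the pair (fx, dfx).
from ngclearn.utils.model_utils import create_function

fx, dfx = create_function("telu")  # fx -> telu, dfx -> d_telu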
@@ -294,6 +297,35 @@ def d_relu(x):
     return (x >= 0.).astype(jnp.float32)
 
 @jit
+def telu(x):
+    """
+    TeLU activation: f(x) = x * tanh(e^x)
+    Proposed by Fernandez and Mali (2024); see https://arxiv.org/abs/2412.20269 and https://arxiv.org/abs/2402.02790
+
+    Args:
+        x: input (tensor) value
+
+    Returns:
+        output (tensor) value
+    """
+    return x * jnp.tanh(jnp.exp(x))
+
+@jit
+def d_telu(x):
+    """
+
+    Derivative of TeLU: f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
+
+    Args:
+        x: input (tensor) value
+
+    Returns:
+        output (tensor) derivative value (with respect to input)
+    """
+    ex = jnp.exp(x)
+    tanh_ex = jnp.tanh(ex)
+    return tanh_ex + x * ex * (1.0 - tanh_ex ** 2)
+@jit
 def sine(x, omega_0=30):
     """
     f(x) = sin(x * omega_0).
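
As a quick sanity check (also not part of the commit), the hand-written d_telu can be compared against JAX autodiff. A minimal sketch, re-declaring telu and d_telu exactly as in the diff so it runs standalone:

# Standalone sanity check of the TeLU derivative against jax.grad.
import jax.numpy as jnp
from jax import grad, jit, vmap

@jit
def telu(x):
    # TeLU: f(x) = x * tanh(e^x), as added in this commit
    return x * jnp.tanh(jnp.exp(x))

@jit
def d_telu(x):
    # analytic derivative: tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
    ex = jnp.exp(x)
    tanh_ex = jnp.tanh(ex)
    return tanh_ex + x * ex * (1.0 - tanh_ex ** 2)

xs = jnp.linspace(-4.0, 4.0, 9)
autodiff = vmap(grad(telu))(xs)  # derivative via JAX autodiff
analytic = d_telu(xs)            # derivative via the hand-written formula
print(jnp.max(jnp.abs(autodiff - analytic)))  # ~0 up to float32 rounding

The two results should agree to within floating-point rounding, which is a useful check whenever an activation and its derivative are maintained as a separate pair, as create_function expects.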
