
Commit 1fbbf93

Author: Alexander Ororbia (committed)
added silu/swish/elu to model_utils
1 parent 53ed773 commit 1fbbf93

File tree

1 file changed: +68 -8 lines changed


ngclearn/utils/model_utils.py

Lines changed: 68 additions & 8 deletions
@@ -453,6 +453,47 @@ def inverse_logistic(x, clip_bound=0.03): # 0.03
     x_ = jnp.clip(x_, clip_bound, 1.0 - clip_bound)
     return jnp.log( x_/((1.0 - x_) + 1e-6) )

+@jit
+def swish(x, beta):
+    """
+    Applies the Swish parameterized activation, proposed in Ramachandran et al., 2017
+    ("Searching for Activation Functions").
+
+    Args:
+        x: data to transform via the Swish function
+
+        beta: coefficient to weight the input x by
+
+    Returns:
+        output of the Swish activation
+    """
+    return x * sigmoid(x * beta)
+
+@jit
+def d_swish(x, beta):
+    # df/dx = beta * [ 1/(exp(-x) + 1) + (exp(-x) * x) / (exp(-x) + 1)^2 ]
+    #       = beta * [ sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) ]
+    exp_neg_x = jnp.exp(-x)
+    _x = (1./(exp_neg_x + 1.)) + (exp_neg_x * x)/jnp.square(exp_neg_x+1)
+    return _x * beta
+
+@jit
+def silu(x):
+    """
+    Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
+
+    Args:
+        x: data to transform via the SiLU function
+
+    Returns:
+        output of the SiLU activation
+    """
+    return swish(x, beta=1.)
+
+@jit
+def d_silu(x):
+    return d_swish(x, beta=1.)
+
 @jit
 def gelu(x):
     """
@@ -464,14 +505,33 @@ def gelu(x):
     Returns:
         output of the GeLU activation
     """
-    return x * sigmoid(x * 1.702) ## approximate GeLU
+    return swish(x, beta=1.702) ## approximate GeLU

 @jit
 def d_gelu(x):
     # df/dx = 1.702 * [ 1/(exp(-x) + 1) + (exp(-x) * x) / (exp(-x) + 1)^2]
-    exp_neg_x = jnp.exp(-x)
-    _x = (1./(exp_neg_x + 1.)) + (exp_neg_x * x)/jnp.square(exp_neg_x+1)
-    return _x * 1.702
+    return d_swish(x, beta=1.702)
+
+@jit
+def elu(x, alpha=1.):
+    """
+    Applies the exponential linear unit (ELU) activation.
+
+    Args:
+        x: data to transform via the ELU function
+
+        alpha: scale coefficient applied to the exponential (negative) branch
+
+    Returns:
+        output of the ELU activation
+    """
+    mask = x >= 0.
+    return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)
+
+@jit
+def d_elu(x, alpha=1.):
+    mask = (x >= 0.)
+    return mask + (1. - mask) * (jnp.exp(x) * alpha)

 @jit
 def softmax(x, tau=0.0):
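
A similar sketch (again not part of the commit, and assuming the second definition above is the derivative helper d_elu, as the surrounding d_* helpers suggest) compares the refactored gelu and the new elu against their jax.nn counterparts; x * sigmoid(1.702 * x) only approximates the exact GeLU, so a loose tolerance is used:

# hypothetical check script -- not in the commit
import jax
import jax.numpy as jnp
from ngclearn.utils.model_utils import gelu, elu, d_elu

x = jnp.linspace(-3., 3., 13)
print(jnp.allclose(gelu(x), jax.nn.gelu(x, approximate=False), atol=5e-2))  # approximate GeLU
print(jnp.allclose(elu(x, alpha=1.), jax.nn.elu(x, alpha=1.0)))             # same branch formula
print(d_elu(jnp.asarray([-1., 2.])))  # ~[0.3679, 1.]: alpha * exp(x) below zero, 1 otherwise
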
@@ -553,24 +613,24 @@ def layer_normalize(x, shift=0., scale=1.):
     return _x * scale + shift

 @jit
-def drop_out(dkey, input, rate=0.0):
+def drop_out(dkey, data, rate=0.0):
     """
     Applies a drop-out transform to an input matrix.

     Args:
         dkey: Jax randomness key for this operator

-        input: data to apply random/drop-out mask to
+        data: input data to apply random/drop-out mask to

         rate: probability of a dimension being dropped

     Returns:
         output as well as binary mask
     """
-    eps = random.uniform(dkey, shape=input.shape, minval=0.0, maxval=1.0)
+    eps = random.uniform(dkey, shape=data.shape, minval=0.0, maxval=1.0)
     mask = (eps <= (1.0 - rate)).astype(jnp.float32)
     mask = mask * (1.0 / (1.0 - rate)) ## apply inverted dropout scheme
-    output = input * mask
+    output = data * mask
     return output, mask
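
Finally, a brief usage sketch for the renamed drop_out signature (not part of the commit; the key seed and shapes are illustrative only). Because of the inverted-dropout rescaling, surviving entries are multiplied by 1/(1 - rate):

# hypothetical usage -- not in the commit
import jax
import jax.numpy as jnp
from ngclearn.utils.model_utils import drop_out

dkey = jax.random.PRNGKey(1234)               # illustrative seed
data = jnp.ones((4, 10))
output, mask = drop_out(dkey, data, rate=0.5)
print(output.shape, mask.shape)               # (4, 10) (4, 10)
print(jnp.unique(mask))                       # [0. 2.] for rate = 0.5 (survivors scaled by 2)
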
