@@ -453,6 +453,47 @@ def inverse_logistic(x, clip_bound=0.03): # 0.03
     x_ = jnp.clip(x_, clip_bound, 1.0 - clip_bound)
     return jnp.log(x_ / ((1.0 - x_) + 1e-6))
 
+@jit
+def swish(x, beta):
+    """
+    Applies the Swish parameterized activation, proposed in Ramachandran et al., 2017
+    ("Searching for Activation Functions").
+
+    Args:
+        x: data to transform via the Swish activation
+
+        beta: coefficient to scale the sigmoid gate's input by
+
+    Returns:
+        output of the Swish activation
+    """
+    return x * sigmoid(x * beta)
+
+@jit
+def d_swish(x, beta):
+    ## df/dx = sigmoid(beta * x) + beta * x * sigmoid(beta * x) * (1 - sigmoid(beta * x))
+    exp_neg_bx = jnp.exp(-x * beta)
+    return (1. / (exp_neg_bx + 1.)) + (beta * x * exp_neg_bx) / jnp.square(exp_neg_bx + 1.)
+
+@jit
+def silu(x):
+    """
+    Applies the sigmoid-weighted linear unit (SiLU or SiL) activation, i.e.,
+    Swish with beta fixed to 1.
+
+    Args:
+        x: data to transform via the SiLU activation
+
+    Returns:
+        output of the SiLU activation
+    """
+    return swish(x, beta=1.)
+
+@jit
+def d_silu(x):
+    return d_swish(x, beta=1.)
+
 @jit
 def gelu(x):
     """
@@ -464,14 +505,33 @@ def gelu(x):
     Returns:
         output of the GeLU activation
     """
-    return x * sigmoid(x * 1.702) ## approximate GeLU
+    return swish(x, beta=1.702) ## approximate GeLU
 
 @jit
 def d_gelu(x):
-    # df/dx = 1.702 * [ 1/(exp(-x) + 1) + (exp(-x) * x) / (exp(-x) + 1)^2]
-    exp_neg_x = jnp.exp(-x)
-    _x = (1. / (exp_neg_x + 1.)) + (exp_neg_x * x) / jnp.square(exp_neg_x + 1)
-    return _x * 1.702
+    ## d/dx of the approximate GeLU, i.e., the Swish derivative at beta = 1.702
+    return d_swish(x, beta=1.702)
+
+@jit
+def elu(x, alpha=1.):
+    """
+    Applies the exponential linear unit (ELU) activation.
+
+    Args:
+        x: data to transform via the ELU activation
+
+        alpha: coefficient that scales the exponential branch for negative inputs
+
+    Returns:
+        output of the ELU activation
+    """
+    mask = (x >= 0.)
+    return x * mask + ((jnp.exp(x) - 1.) * alpha) * (1. - mask)
+
+@jit
+def d_elu(x, alpha=1.):
+    ## d/dx of ELU: 1 for x >= 0, alpha * exp(x) otherwise
+    mask = (x >= 0.)
+    return mask + (1. - mask) * (jnp.exp(x) * alpha)
 
 @jit
 def softmax(x, tau=0.0):
@@ -553,24 +613,24 @@ def layer_normalize(x, shift=0., scale=1.):
     return _x * scale + shift
 
 @jit
-def drop_out(dkey, input, rate=0.0):
+def drop_out(dkey, data, rate=0.0):
     """
     Applies a drop-out transform to an input matrix.
 
     Args:
         dkey: Jax randomness key for this operator
 
-        input: data to apply random/drop-out mask to
+        data: input data to apply random/drop-out mask to
 
         rate: probability of a dimension being dropped
 
     Returns:
         output as well as binary mask
     """
-    eps = random.uniform(dkey, shape=input.shape, minval=0.0, maxval=1.0)
+    eps = random.uniform(dkey, shape=data.shape, minval=0.0, maxval=1.0)
     mask = (eps <= (1.0 - rate)).astype(jnp.float32)
     mask = mask * (1.0 / (1.0 - rate)) ## apply inverted dropout scheme
-    output = input * mask
+    output = data * mask
     return output, mask
 
 
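As a quick reference, below is a minimal usage sketch of the helpers touched by this patch. The import path is an assumption (it presumes these functions sit in a model_utils-style module of this library) and is not part of the diff itself:

    ## usage sketch -- the import path below is a guess, adjust to the actual module
    from ngclearn.utils.model_utils import swish, silu, gelu, elu, d_elu, drop_out
    import jax.numpy as jnp
    from jax import random

    x = jnp.linspace(-3., 3., 7)
    print(swish(x, beta=1.702))  ## same gate used by the approximate GeLU
    print(silu(x))               ## Swish with beta = 1
    print(elu(x, alpha=1.))      ## identity for x >= 0, alpha * (exp(x) - 1) otherwise

    key = random.PRNGKey(1234)
    out, mask = drop_out(key, jnp.ones((2, 4)), rate=0.5)  ## kept units rescaled by 1/(1 - rate)
    print(out)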