@@ -72,18 +72,20 @@ def create_function(fun_name, args=None):
     Activation function creation routine.

     Args:
-        fun_name: string name of activation function to produce
-            (Currently supports: "tanh", "relu", "lrelu", "identity")
+        fun_name: string name of activation function to produce;
+            Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6",
+            "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside",
+            "identity"

     Returns:
         function fx, first derivative of function (w.r.t. input) dfx
     """
-    fx = None
-    dfx = None
+    fx = None ## the function
+    dfx = None ## the first derivative of function w.r.t. its input
     if fun_name == "tanh":
         fx = tanh
         dfx = d_tanh
-    elif "kwta" in fun_name:
+    elif "bkwta" in fun_name:
         fx = bkwta
         dfx = bkwta #d_identity
     elif fun_name == "sigmoid":
@@ -98,6 +100,15 @@ def create_function(fun_name, args=None):
     elif fun_name == "relu6":
         fx = relu6
         dfx = d_relu6
+    elif fun_name == "elu":
+        fx = elu
+        dfx = d_elu
+    elif fun_name == "silu":
+        fx = silu
+        dfx = d_silu
+    elif fun_name == "gelu":
+        fx = gelu
+        dfx = d_gelu
     elif fun_name == "softplus":
         fx = softplus
         dfx = d_softplus
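With the new branches in place, the factory hands back ELU, SiLU, and GELU alongside their derivatives. A minimal usage sketch follows; the import path is an assumption for illustration only and should point at wherever `create_function` actually lives in this repo:

```python
import jax.numpy as jnp
## NOTE: module path below is assumed for illustration; adjust to this repo's layout.
from ngclearn.utils.model_utils import create_function

fx, dfx = create_function("gelu")   ## one of the newly supported activations
x = jnp.linspace(-3., 3., 7)
y = fx(x)      ## forward activation values
dy = dfx(x)    ## elementwise first derivative w.r.t. x
print(y.shape, dy.shape)  ## (7,) (7,)
```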
@@ -127,35 +138,35 @@ def bkwta(x, nWTA=5): #5 10 15 #K=50):
     return topK

 @partial(jit, static_argnums=[2, 3, 4])
-def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.):
+def normalize_matrix(data, wnorm, order=1, axis=0, scale=1.):
     """
     Normalizes the values in matrix to have a particular norm across each vector span.

     Args:
-        M: (2D) matrix to normalize
+        data: (2D) data matrix to normalize

-        wnorm: target norm for each
+        wnorm: target norm for each row/column of data matrix

         order: order of norm to use in normalization (Default: 1);
             note that `ord=1` results in the L1-norm, `ord=2` results in the L2-norm

         axis: 0 (apply to column vectors), 1 (apply to row vectors)

-        scale: step modifier to produce the projected matrix
+        scale: step modifier to produce the projected matrix (Unused)

     Returns:
         a normalized value matrix
     """
     if order == 2: ## denominator is L2 norm
-        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(data), axis=axis, keepdims=True)), 1e-8)
     else: ## denominator is L1 norm
-        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8)
+        wOrdSum = jnp.maximum(jnp.sum(jnp.abs(data), axis=axis, keepdims=True), 1e-8)
     m = (wOrdSum == 0.).astype(dtype=jnp.float32)
     wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1.
-    _M = M * (wnorm/wOrdSum)
-    #dM = ((wnorm/wOrdSum) - 1.) * M
-    #_M = M + dM * scale
-    return _M
+    _data = data * (wnorm/wOrdSum)
+    #d_data = ((wnorm/wOrdSum) - 1.) * data
+    #_data = data + d_data * scale
+    return _data

 @jit
 def clamp_min(x, min_val):
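For reference, a standalone sketch that mirrors the `normalize_matrix` logic in this hunk (not a call into the library): each column of a matrix is rescaled so its L2 norm equals the target `wnorm`, matching the `order == 2`, `axis = 0` branch above.

```python
import jax.numpy as jnp

data = jnp.asarray([[3., 0.],
                    [4., 2.]])
wnorm = 1.  ## target norm per column
## per-column L2 norms (sum over axis 0, kept as a 1 x 2 row for broadcasting)
col_norms = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(data), axis=0, keepdims=True)), 1e-8)
normed = data * (wnorm / col_norms)
print(jnp.sqrt(jnp.sum(jnp.square(normed), axis=0)))  ## ~[1. 1.]
```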
@@ -529,7 +540,7 @@ def elu(x, alpha=1.):
     return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)

 @jit
-def elu(x, alpha=1.):
+def d_elu(x, alpha=1.):
     mask = (x >= 0.)
     return mask + (1. - mask) * (jnp.exp(x) * alpha)

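The renamed derivative can be sanity-checked against autodiff; a small self-contained sketch using the `elu`/`d_elu` definitions shown in this diff:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def elu(x, alpha=1.):
    mask = (x >= 0.)
    return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)

@jit
def d_elu(x, alpha=1.):
    mask = (x >= 0.)
    return mask + (1. - mask) * (jnp.exp(x) * alpha)

x = jnp.asarray([-2., -0.5, 0.5, 2.])
print(d_elu(x))                                   ## hand-coded derivative
print(jax.vmap(jax.grad(lambda v: elu(v)))(x))    ## autodiff of scalar ELU; should match
```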