import tensorflow as tf
from keras import layers
from keras import backend as K


class SelfAttention(layers.Layer):
    """A self-attention layer for convolutional neural networks.
    ...

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        ...
        # Scale the attended value tensor by the learned gamma parameter and
        # add the residual (input) connection.
        attention_output = self.gamma * attention_output + inputs

        return attention_output
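
# The residual form `gamma * attention_output + inputs` follows the SAGAN-style
# self-attention design: gamma is a learnable scalar, typically initialized to
# zero so the layer starts as an identity mapping and learns to attend gradually.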


class SpectralNormalization(tf.keras.layers.Wrapper):
    """Spectral normalization wrapper for a Keras layer's kernel.

    NOTE: still experimental / work in progress.
    """

    def __init__(self, layer, power_iterations=1, eps=1e-12, **kwargs):
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `SpectralNormalization` with a "
                "`Layer` instance. You passed: {}".format(layer)
            )
        super().__init__(layer, **kwargs)

        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        self.eps = eps

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)

        # `self.w` aliases the wrapped layer's kernel variable.
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Persistent estimate of the dominant singular vector, refined by a
        # few steps of power iteration at every training step.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=tf.float32,
        )

        super().build(input_shape)

    def l2normalize(self, v, eps=1e-12):
        # Scale `v` to unit L2 norm; `eps` guards against division by zero.
        return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

    def power_iteration(self, W, u, rounds=1):
        """Approximate the dominant singular vectors of W by power iteration."""
        _u = u

        for _ in range(rounds):
            # Alternate v <- normalize(u W^T) and u <- normalize(v W).
            _v = self.l2normalize(K.dot(_u, K.transpose(W)), eps=self.eps)
            _u = self.l2normalize(K.dot(_v, W), eps=self.eps)

        return _u, _v

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()

        if training:
            # Normalize the kernel for this forward pass, then restore it so
            # the optimizer's update `W <- W - alpha * grad` is applied to the
            # unnormalized weights.
            self.update_weights()
            output = self.layer(inputs)
            self.restore_weights()
            return output

        return self.layer(inputs)

    def update_weights(self):
        # `self.w` aliases `self.layer.kernel`, so snapshot the current
        # (unnormalized) values before the assign below overwrites them;
        # `restore_weights()` puts this snapshot back after the forward pass.
        self.w_orig = tf.identity(self.w)

        # Flatten the kernel to 2-D so the spectral norm is taken over a
        # (fan_in, fan_out) matrix.
        w_reshaped = tf.reshape(self.w_orig, [-1, self.w_shape[-1]])

        u_hat, v_hat = self.power_iteration(w_reshaped, self.u, self.power_iterations)

        # sigma = v W u^T approximates the largest singular value of W.
        sigma = K.dot(K.dot(v_hat, w_reshaped), K.transpose(u_hat))
        self.u.assign(u_hat)

        self.layer.kernel.assign(self.w_orig / sigma)

    def restore_weights(self):
        # Put the unnormalized kernel back so gradients update the raw weights.
        self.layer.kernel.assign(self.w_orig)
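

# Example usage (a minimal sketch; the Dense layer and shapes below are
# illustrative, not part of the original code): wrap a layer so its kernel is
# divided by its approximate spectral norm on each training forward pass.
if __name__ == "__main__":
    sn_dense = SpectralNormalization(tf.keras.layers.Dense(64), power_iterations=1)

    x = tf.random.normal((8, 32))
    y = sn_dense(x, training=True)  # kernel is normalized, used, then restored
    print(y.shape)  # expected: (8, 64)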