Commit ab195a6

create SpectralNormalization layer in mltu.tensorflow.layers
1 parent fdc5c9d commit ab195a6

1 file changed

mltu/tensorflow/layers.py

Lines changed: 95 additions & 1 deletion
@@ -1,5 +1,6 @@
 import tensorflow as tf
 from keras import layers
+from keras import backend as K

 class SelfAttention(layers.Layer):
     """ A self-attention layer for convolutional neural networks.
@@ -87,4 +88,97 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
         # Apply the gamma parameter to the attended value tensor and add it to the output tensor.
         attention_output = self.gamma * attention_output + inputs

-        return attention_output
+        return attention_output
+
+
+class SpectralNormalization(tf.keras.layers.Wrapper):
+    """Spectral Normalization wrapper. !!! This is not working yet !!!"""
+    def __init__(self, layer, power_iterations=1, eps=1e-12, **kwargs):
+        super(SpectralNormalization, self).__init__(layer, **kwargs)
+
+        if power_iterations <= 0:
+            raise ValueError(
+                "`power_iterations` should be greater than zero, got "
+                "`power_iterations={}`".format(power_iterations)
+            )
+        self.power_iterations = power_iterations
+        self.eps = eps
+        if not isinstance(layer, tf.keras.layers.Layer):
+            raise ValueError(
+                "Please initialize `SpectralNormalization` with a "
+                "`Layer` instance. You passed: {input}".format(input=layer))
+
+    def build(self, input_shape):
+        if not self.layer.built:
+            self.layer.build(input_shape)
+
+        self.w = self.layer.kernel
+        self.w_shape = self.w.shape.as_list()
+
+        # Persistent estimate of the first left singular vector, refined by
+        # power iteration on every training step.
+        self.u = self.add_weight(shape=(1, self.w_shape[-1]),
+                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
+                                 trainable=False,
+                                 name="sn_u",
+                                 dtype=tf.float32)
+
+        super(SpectralNormalization, self).build(input_shape)
+
+    def l2normalize(self, v, eps=1e-12):
+        return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
+
+    def power_iteration(self, W, u, rounds=1):
+        _u = u
+        for _ in range(rounds):
+            _v = self.l2normalize(K.dot(_u, K.transpose(W)), eps=self.eps)
+            _u = self.l2normalize(K.dot(_v, W), eps=self.eps)
+
+        return _u, _v
+
+    def call(self, inputs, training=None):
+        if training is None:
+            training = tf.keras.backend.learning_phase()
+
+        if training:
+            self.update_weights()
+            output = self.layer(inputs)
+            # Restore the raw kernel so the optimizer updates W, not W / sigma.
+            self.restore_weights()
+            return output
+
+        return self.layer(inputs)
+
+    def update_weights(self):
+        # Snapshot the raw kernel values: `self.w` aliases the kernel variable
+        # itself, so a copy is required for restore_weights() to undo the
+        # in-place normalization below.
+        self.w_value = tf.identity(self.w)
+        w_reshaped = tf.reshape(self.w_value, [-1, self.w_shape[-1]])
+
+        u_hat, v_hat = self.power_iteration(w_reshaped, self.u, self.power_iterations)
+
+        # sigma approximates the largest singular value (spectral norm) of the kernel.
+        sigma = K.dot(K.dot(v_hat, w_reshaped), K.transpose(u_hat))
+        self.u.assign(u_hat)
+
+        self.layer.kernel.assign(self.w_value / sigma)
+
+    def restore_weights(self):
+        self.layer.kernel.assign(self.w_value)
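
The wrapper is meant to be dropped around any kernel-bearing layer. A minimal usage sketch (not part of the commit, and assuming the `mltu.tensorflow.layers` module path from the commit message):

import tensorflow as tf
from mltu.tensorflow.layers import SpectralNormalization

# Wrap a Dense layer; spectral normalization is typically applied to a GAN
# discriminator's kernels to enforce an approximate 1-Lipschitz constraint.
dense_sn = SpectralNormalization(tf.keras.layers.Dense(10))

x = tf.random.normal((4, 32))
y = dense_sn(x, training=True)   # kernel is divided by sigma for this forward pass
print(y.shape)                   # (4, 10)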
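
The core of the layer is the power iteration that estimates sigma, the largest singular value of the kernel reshaped to a 2-D matrix. A self-contained sketch of the same computation (my illustration, not from the commit), checked against the exact value from `tf.linalg.svd`:

import tensorflow as tf

# Stand-in for tf.reshape(kernel, [-1, w_shape[-1]]), e.g. a 3x3x64 -> 64 conv kernel.
W = tf.random.normal((576, 64))
u = tf.random.normal((1, 64))    # plays the role of the layer's `sn_u` weight

for _ in range(20):              # more rounds than the layer's default of 1, for accuracy
    v = tf.math.l2_normalize(tf.matmul(u, W, transpose_b=True), epsilon=1e-12)
    u = tf.math.l2_normalize(tf.matmul(v, W), epsilon=1e-12)

sigma = tf.matmul(tf.matmul(v, W), u, transpose_b=True)    # shape (1, 1)
exact = tf.reduce_max(tf.linalg.svd(W, compute_uv=False))  # true spectral norm
print(float(sigma[0, 0]), float(exact))  # the two values should nearly agree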
