Skip to content

Commit b17ca27

Browse files
authored
Add FNet Encoder Layer (#43)
* Add rough code for FNet Encoder * Format code * Minor doc-string changes * Format __init__.py * Address review comments - 1 * Add detailed comment about padding masks * Add kernel and bias initialisers * Add unit tests for the layer * Address review comments - 2 * Address review comments - 3 * Address review comments - 4 * Minor change * Correct doc-string
1 parent ba8b150 commit b17ca27

File tree

3 files changed

+310
-0
lines changed

3 files changed

+310
-0
lines changed

keras_nlp/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from keras_nlp.layers.fnet_encoder import FNetEncoder
1516
from keras_nlp.layers.transformer_decoder import TransformerDecoder
1617
from keras_nlp.layers.transformer_encoder import TransformerEncoder

keras_nlp/layers/fnet_encoder.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""FNet encoder block implementation based on `keras.layers.Layer`."""
16+
17+
import tensorflow as tf
18+
from tensorflow import keras
19+
20+
21+
class FNetEncoder(keras.layers.Layer):
    """FNet encoder.

    This class follows the architecture of FNet encoder layer in paper
    "FNet: Mixing Tokens with Fourier Transforms"
    (https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances
    of this class to stack up the encoder.

    Note on padding: In the official FNet code, padding tokens are added to
    the input. However, the padding masks are deleted, i.e., mixing of
    all tokens is done. This is because certain frequencies will be zeroed
    out if we apply padding masks in every encoder layer. Hence, we don't
    take padding mask as input in the call() function.

    Args:
        intermediate_dim: int. The hidden size of feedforward network.
        dropout: float, defaults to 0. The dropout value, applied in the
            feedforward network.
        activation: string or `tf.keras.activations`, defaults to "relu". The
            activation function of feedforward network.
        layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
            normalization components.
        kernel_initializer: "string" or `tf.keras.initializers` initializer,
            defaults to "glorot_uniform". The kernel initializer for the dense
            layers.
        bias_initializer: "string" or `tf.keras.initializers` initializer,
            defaults to "zeros". The bias initializer for the dense layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    Examples:

    ```python
    # Create a single FNet encoder layer.
    encoder = keras_nlp.layers.FNetEncoder(
        intermediate_dim=64)

    # Create a simple model containing the encoder.
    inputs = tf.keras.Input(shape=[10, 64])
    outputs = encoder(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # Call encoder on the inputs.
    input_data = tf.random.uniform(shape=[1, 10, 64])
    output = model(input_data)
    ```

    References:
        [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
    """

    def __init__(
        self,
        intermediate_dim,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        # Resolve identifiers eagerly so that an invalid activation or
        # initializer name fails at construction time rather than in build().
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)

    def build(self, input_shape):
        """Create the sub-layers once the input feature size is known.

        Args:
            input_shape: shape of the layer input; its last entry is used as
                the output size of the feedforward block.
        """
        # The feedforward block projects back to the input feature size so
        # that the residual connection in call() is shape-compatible.
        feature_size = input_shape[-1]

        # Layer Norm layers.
        self._mixing_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )
        self._output_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )

        # Feedforward layers.
        self._intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dense = keras.layers.Dense(
            feature_size,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dropout = keras.layers.Dropout(rate=self.dropout)

    def call(self, inputs):
        """Forward pass of the FNetEncoder.

        Args:
            inputs: a Tensor. The input data to FNetEncoder, should be
                of shape [batch_size, sequence_length, feature_dim].

        Returns:
            A Tensor of the same shape as the `inputs`.
        """

        def fourier_transform(x):
            # `tf.signal.fft2d` only accepts complex tensors, so the input is
            # promoted to complex64 before the transform.
            spectrum = tf.signal.fft2d(tf.cast(x, tf.complex64))
            # Only the real part is kept. It is float32 regardless of the
            # input dtype, so cast it back; otherwise the residual addition
            # below fails with a dtype mismatch when the layer runs in a
            # non-float32 dtype (e.g. under a mixed precision policy).
            return tf.cast(tf.math.real(spectrum), x.dtype)

        def add_and_norm(residual, sublayer_output, norm_layer):
            return norm_layer(residual + sublayer_output)

        def feed_forward(x):
            hidden = self._intermediate_dense(x)
            projected = self._output_dense(hidden)
            return self._output_dropout(projected)

        # Token mixing sub-layer followed by residual + layer norm.
        mixing_output = add_and_norm(
            inputs, fourier_transform(inputs), self._mixing_layer_norm
        )

        # Feedforward sub-layer followed by residual + layer norm.
        return add_and_norm(
            mixing_output,
            feed_forward(mixing_output),
            self._output_layer_norm,
        )

    def get_config(self):
        """Return the constructor arguments as a serializable config dict."""
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for FNet Encoder."""
15+
16+
import os
17+
18+
import tensorflow as tf
19+
from tensorflow import keras
20+
21+
from keras_nlp.layers import fnet_encoder
22+
23+
24+
class FNetEncoderTest(tf.test.TestCase):
    """Tests for construction, config, training and saving of FNetEncoder."""

    def test_valid_call(self):
        """The layer can be called inside a Sequential model."""
        layer = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model = keras.Sequential([keras.Input(shape=(4, 6)), layer])
        batch = tf.random.uniform(shape=[2, 4, 6])
        model(batch)

    def test_get_config_and_from_config(self):
        """get_config() holds the constructor args and round-trips."""
        layer = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
            kernel_initializer="HeNormal",
            bias_initializer="Zeros",
        )
        config = layer.get_config()

        he_normal = keras.initializers.serialize(keras.initializers.HeNormal())
        zeros = keras.initializers.serialize(keras.initializers.Zeros())
        expected_subset = {
            "intermediate_dim": 4,
            "dropout": 0,
            "activation": "relu",
            "layer_norm_epsilon": 1e-5,
            "kernel_initializer": he_normal,
            "bias_initializer": zeros,
        }
        # Overlaying the expected entries on the config must change nothing,
        # i.e. the expected entries are a subset of the actual config.
        self.assertEqual(config, dict(config, **expected_subset))

        restored = fnet_encoder.FNetEncoder.from_config(config)
        self.assertEqual(
            restored.get_config(), dict(config, **expected_subset)
        )

    def test_value_error_when_invalid_kernel_initializer(self):
        """An unknown initializer name is rejected at construction time."""
        bad_kwargs = {
            "intermediate_dim": 4,
            "dropout": 0.5,
            "kernel_initializer": "Invalid",
        }
        with self.assertRaises(ValueError):
            fnet_encoder.FNetEncoder(**bad_kwargs)

    def test_one_training_step_of_fnet_encoder(self):
        """A single optimizer step runs through the layer."""
        layer = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model_inputs = keras.Input(shape=(4, 6))
        hidden = layer(model_inputs)
        probs = keras.layers.Dense(1, activation="sigmoid")(hidden)
        model = keras.Model(inputs=model_inputs, outputs=probs)

        features = tf.random.uniform(shape=[2, 4, 6])
        targets = tf.cast(features[:, :, 0] >= 0.5, dtype=tf.int32)

        criterion = keras.losses.BinaryCrossentropy(from_logits=False)
        adam = keras.optimizers.Adam()
        with tf.GradientTape() as tape:
            loss = criterion(targets, model(features))
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.assertGreater(len(gradients), 1)
        adam.apply_gradients(zip(gradients, model.trainable_variables))

    def test_checkpointing_fnet_encoder(self):
        """Restoring a checkpoint makes two encoders produce equal outputs."""
        source = fnet_encoder.FNetEncoder(intermediate_dim=4)
        target = fnet_encoder.FNetEncoder(intermediate_dim=4)

        batch = tf.random.uniform(shape=[2, 4, 6])
        # Call both layers once so that their weights get created.
        source(batch)
        target(batch)

        # The weights of the two encoders are different.
        source_row = source._output_dense.trainable_variables[0][0]
        target_row = target._output_dense.trainable_variables[0][0]
        self.assertFalse(all(source_row == target_row))

        source_ckpt = tf.train.Checkpoint(source)
        target_ckpt = tf.train.Checkpoint(target)
        ckpt_path = source_ckpt.save(self.get_temp_dir())
        target_ckpt.restore(ckpt_path)

        self.assertAllClose(source(batch), target(batch))

    def test_save_model(self):
        """A model containing the layer survives save() / load_model()."""
        layer = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model = keras.Sequential([keras.Input(shape=(4, 6)), layer])

        batch = tf.random.uniform(shape=[2, 4, 6])
        model(batch)

        save_dir = os.path.join(self.get_temp_dir(), "model")
        model.save(save_dir)
        reloaded = keras.models.load_model(save_dir)

        expected = model(batch)
        actual = reloaded(batch)
        self.assertAllClose(expected, actual)

0 commit comments

Comments
 (0)