# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FNet encoder block implementation based on `keras.layers.Layer`."""

import tensorflow as tf
from tensorflow import keras


class FNetEncoder(keras.layers.Layer):
    """FNet encoder.

    This class follows the architecture of the FNet encoder layer in the
    paper "FNet: Mixing Tokens with Fourier Transforms"
    (https://arxiv.org/abs/2105.03824). Users can instantiate multiple
    instances of this class to stack up an encoder; a stacking sketch
    follows the main example below.

    Note on padding: in the official FNet code, padding tokens are added
    to the input, but the padding mask is discarded, i.e., all tokens are
    mixed, padded or not. This is because certain frequencies would be
    zeroed out if a padding mask were applied in every encoder layer.
    Hence, `call()` does not take a padding mask as input.
    Args:
        intermediate_dim: int. The hidden size of the feedforward network.
        dropout: float, defaults to 0. The dropout rate, applied in the
            feedforward network.
        activation: string or `tf.keras.activations` activation, defaults
            to "relu". The activation function of the feedforward network.
        layer_norm_epsilon: float, defaults to 1e-5. The epsilon value used
            in the layer normalization components.
        kernel_initializer: string or `tf.keras.initializers` initializer,
            defaults to "glorot_uniform". The kernel initializer for the
            dense layers.
        bias_initializer: string or `tf.keras.initializers` initializer,
            defaults to "zeros". The bias initializer for the dense layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    Examples:

    ```python
    # Create a single FNet encoder layer.
    encoder = keras_nlp.layers.FNetEncoder(intermediate_dim=64)

    # Create a simple model containing the encoder.
    input = tf.keras.Input(shape=[10, 64])
    output = encoder(input)
    model = tf.keras.Model(inputs=input, outputs=output)

    # Call encoder on the inputs.
    input_data = tf.random.uniform(shape=[1, 10, 64])
    output = model(input_data)
    ```
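
    A minimal stacking sketch, as mentioned above (the depth and
    `intermediate_dim` here are illustrative, not values prescribed by
    the paper):

    ```python
    inputs = tf.keras.Input(shape=[10, 64])
    x = inputs
    for _ in range(2):
        x = keras_nlp.layers.FNetEncoder(intermediate_dim=128)(x)
    stacked_model = tf.keras.Model(inputs=inputs, outputs=x)
    ```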

    References:
        [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
    """

    def __init__(
        self,
        intermediate_dim,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)

    def build(self, input_shape):
        # Create layers based on input shape.
        feature_size = input_shape[-1]

        # Layer Norm layers.
        self._mixing_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )
        self._output_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )

        # Feedforward layers.
        self._intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dense = keras.layers.Dense(
            feature_size,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dropout = keras.layers.Dropout(rate=self.dropout)

    def call(self, inputs):
        """Forward pass of the FNetEncoder.

        Args:
            inputs: a Tensor. The input data to the FNetEncoder layer,
                should be of shape [batch_size, sequence_length, feature_dim].

        Returns:
            A Tensor of the same shape as the `inputs`.
        """

        def fourier_transform(input):
            # Apply a 2D FFT to the input and keep only the real part.
            # The FFT expects complex inputs, so cast to complex64 first.
            input = tf.cast(input, tf.complex64)
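            # Note: `tf.signal.fft2d` applies the 2D DFT over the two
            # innermost axes (sequence and feature), which matches the
            # paper's Fourier mixing along both the token and hidden
            # dimensions.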
            mixing_output = tf.math.real(tf.signal.fft2d(input))
            return mixing_output

        def add_and_norm(input1, input2, norm_layer):
            return norm_layer(input1 + input2)

        def feed_forward(input):
            x = self._intermediate_dense(input)
            x = self._output_dense(x)
            return self._output_dropout(x)

        mixing_output = fourier_transform(inputs)

        mixing_output = add_and_norm(
            inputs, mixing_output, self._mixing_layer_norm
        )

        feed_forward_output = feed_forward(mixing_output)

        x = add_and_norm(
            mixing_output, feed_forward_output, self._output_layer_norm
        )
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config
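

# A quick serialization sanity check (a sketch, not part of the library):
# `get_config()` plus the inherited `from_config()` should round-trip the
# layer, which is what Keras relies on when saving and loading models.
#
#     layer = FNetEncoder(intermediate_dim=64, dropout=0.1)
#     clone = FNetEncoder.from_config(layer.get_config())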