Skip to content

Commit 972eaec

Browse files
Add a sinusoidal embedding layer (#59)
* Add a sinusoidal embedding layer * renamed min_frequency to base_frequency and updated docstring * changed base_frequency to max_wavelength and added test to check correct values * renamed files * Fixup docstring formatting Co-authored-by: Matt Watson <[email protected]>
1 parent e3cdc95 commit 972eaec

File tree

3 files changed

+217
-0
lines changed

3 files changed

+217
-0
lines changed

keras_nlp/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
# limitations under the License.
1414

1515
from keras_nlp.layers.fnet_encoder import FNetEncoder
16+
from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
1617
from keras_nlp.layers.transformer_decoder import TransformerDecoder
1718
from keras_nlp.layers.transformer_encoder import TransformerEncoder
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Sinusoidal position embedding layer."""
16+
17+
import tensorflow as tf
18+
from tensorflow import keras
19+
20+
21+
class SinePositionEncoding(keras.layers.Layer):
    """Sinusoidal positional encoding layer.

    Computes the fixed position encoding defined and formulated in
    [Attention is All You Need](https://arxiv.org/abs/1706.03762): a mix of
    sine and cosine functions whose wavelengths increase geometrically
    across the feature axis.

    Takes as input an embedded token tensor. The input must have shape
    [batch_size, sequence_length, feature_size]; the sequence length is read
    from the second-to-last axis and the feature size from the last axis.
    Only the shape of the input is used, never its values. The layer returns
    a positional encoding the same size as the embedded token tensor, which
    can be added directly to the embedded token tensor.

    Args:
        max_wavelength: The maximum angular wavelength of the sine/cosine
            curves, as described in Attention is All You Need. Defaults to
            10000.

    Example:
    ```python
    # create a simple embedding layer with sinusoidal positional encoding
    seq_len = 100
    vocab_size = 1000
    embedding_dim = 32
    inputs = keras.Input((seq_len,), dtype=tf.float32)
    embedding = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embedding_dim
    )(inputs)
    positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding)
    outputs = embedding + positional_encoding
    ```

    References:
        [Attention is All You Need](https://arxiv.org/abs/1706.03762)
    """

    def __init__(self, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength

    def call(self, inputs):
        """Return the positional encoding broadcast to the shape of `inputs`."""
        dtype = self.compute_dtype
        full_shape = tf.shape(inputs)
        # The sequence axis is second to last; the feature axis is last.
        seq_length = full_shape[-2]
        hidden_size = full_shape[-1]

        positions = tf.cast(tf.range(seq_length), dtype)
        channels = tf.range(hidden_size)

        # Channels 2i and 2i + 1 share the exponent 2i / hidden_size, so each
        # sine/cosine pair uses the same timescale.
        exponents = tf.cast(2 * (channels // 2), dtype) / tf.cast(
            hidden_size, dtype
        )
        min_freq = tf.cast(1 / self.max_wavelength, dtype=dtype)
        timescales = tf.pow(min_freq, exponents)

        # Outer product: angles has shape [seq_length, hidden_size].
        angles = tf.expand_dims(positions, 1) * tf.expand_dims(timescales, 0)

        # Even channels take the sine, odd channels take the cosine.
        cos_mask = tf.cast(channels % 2, dtype)
        sin_mask = 1 - cos_mask
        encodings = tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask

        return tf.broadcast_to(encodings, full_shape)

    def get_config(self):
        """Return the base layer config extended with `max_wavelength`."""
        config = super().get_config()
        config["max_wavelength"] = self.max_wavelength
        return config
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for Sinusoidal Positional encoding."""
15+
16+
17+
import tensorflow as tf
18+
from tensorflow import keras
19+
20+
from keras_nlp.layers import sine_position_encoding
21+
22+
23+
class SinePositionEncodingTest(tf.test.TestCase):
    """Shape, value, config and dtype checks for `SinePositionEncoding`."""

    def test_valid_call(self):
        # Smoke test: the layer can be called inside a Sequential model.
        pos_encoding = sine_position_encoding.SinePositionEncoding()
        model = keras.Sequential(
            [
                keras.Input(shape=(4, 6)),
                pos_encoding,
            ]
        )
        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        inputs = tf.random.uniform(shape=[2, 4, 6])
        model(inputs)

    def test_static_layer_output_shape(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding()
        seq_length = 100
        hidden_size = 32
        inputs = keras.Input(shape=(seq_length, hidden_size))
        outputs = pos_encoding(inputs)

        # When using static positional encoding shapes, the output is expected
        # to be the same as the input shape in all dimensions.
        expected_output_shape = [None, seq_length, hidden_size]
        self.assertEqual(expected_output_shape, outputs.shape.as_list())

    def test_dynamic_layer_output_shape(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding()
        hidden_size = 32
        inputs = keras.Input(shape=(None, hidden_size))
        outputs = pos_encoding(inputs)

        # When using dynamic positional encoding shapes, the output is expected
        # to be the same as the input shape in all dimensions but may be None.
        expected_output_shape = [None, None, hidden_size]
        self.assertEqual(expected_output_shape, outputs.shape.as_list())

    # Extra leading dimensions before the sequence length axis.
    def test_multi_dimension_layer_output_shape(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding()
        seq_length = 100
        hidden_size = 32
        inputs = keras.Input(shape=(None, seq_length, hidden_size))
        outputs = pos_encoding(inputs)

        # When using multiple dimensions before sequence length, the output is
        # expected to be the same as the input shape in all dimensions.
        expected_output_shape = [None, None, seq_length, hidden_size]
        self.assertEqual(expected_output_shape, outputs.shape.as_list())

    def test_output_correct_values(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding()
        model = keras.Sequential(
            [
                keras.Input(shape=(4, 6)),
                pos_encoding,
            ]
        )
        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        inputs = tf.random.uniform(shape=[1, 4, 6])
        output = model(inputs)

        # Compare position encoding values for positions 0 and 3. At position
        # 0 every angle is 0, so sine channels are 0 and cosine channels are 1.
        expected_encoding_position_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
        expected_encoding_position_3 = [
            0.14112,
            -0.9899925,
            0.1387981,
            0.9903207,
            0.00646326,
            0.99997914,
        ]
        self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
        self.assertAllClose(output[0, 3, :], expected_encoding_position_3)

    def test_get_config_and_from_config(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding(
            max_wavelength=1000,
        )
        config = pos_encoding.get_config()
        expected_config_subset = {
            "max_wavelength": 1000,
        }
        # Merging the expected subset into the config must leave it unchanged,
        # i.e. the config already contains exactly these key/value pairs.
        self.assertEqual(config, {**config, **expected_config_subset})
        restored_pos_encoding = (
            sine_position_encoding.SinePositionEncoding.from_config(config)
        )
        self.assertEqual(
            restored_pos_encoding.get_config(),
            {**config, **expected_config_subset},
        )

    def test_float16_dtype(self):
        pos_encoding = sine_position_encoding.SinePositionEncoding(
            dtype="float16"
        )
        seq_length = 100
        hidden_size = 32
        inputs = keras.Input(shape=(seq_length, hidden_size))
        outputs = pos_encoding(inputs)

        # output dtype for this layer should be tf.float16.
        self.assertEqual(outputs.dtype, tf.float16)

0 commit comments

Comments
 (0)