
Commit 0fd7a89

Add a learned positional embedding layer (#47)
1 parent 972eaec commit 0fd7a89

File tree

4 files changed: +445 -67 lines changed

examples/bert/bert_model.py

Lines changed: 3 additions & 67 deletions
@@ -19,6 +19,8 @@
 import tensorflow as tf
 from tensorflow import keras
 
+import keras_nlp
+
 
 def make_attention_mask(inputs, mask):
     """Make a 3D attention mask from a 2D input mask.
@@ -331,72 +333,6 @@ def call(self, query, key_value=None, attention_mask=None):
         return self._output_layer_norm(layer_output + attention_output)
 
 
-class PositionEmbedding(keras.layers.Layer):
-    """Creates a positional embedding.
-
-    Example:
-    ```python
-    position_embedding = PositionEmbedding(max_length=100)
-    inputs = keras.Input((100, 32), dtype=tf.float32)
-    outputs = position_embedding(inputs)
-    ```
-
-    Args:
-        max_length: The maximum size of the dynamic sequence.
-        initializer: The initializer to use for the embedding weights. Defaults
-            to "glorot_uniform".
-        seq_axis: The axis of the input tensor where we add the embeddings.
-
-    Reference: This layer creates a positional embedding as described in
-    [BERT: Pre-training of Deep Bidirectional Transformers for Language
-    Understanding](https://arxiv.org/abs/1810.04805).
-    """
-
-    def __init__(
-        self, max_length, initializer="glorot_uniform", seq_axis=1, **kwargs
-    ):
-        super().__init__(**kwargs)
-        if max_length is None:
-            raise ValueError("`max_length` must be an Integer, not `None`.")
-        self._max_length = max_length
-        self._initializer = keras.initializers.get(initializer)
-        self._seq_axis = seq_axis
-
-    def get_config(self):
-        config = {
-            "max_length": self._max_length,
-            "initializer": keras.initializers.serialize(self._initializer),
-            "seq_axis": self._seq_axis,
-        }
-        base_config = super().get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-    def build(self, input_shape):
-        dimension_list = input_shape.as_list()
-        width = dimension_list[-1]
-        weight_sequence_length = self._max_length
-
-        self._position_embeddings = self.add_weight(
-            "embeddings",
-            shape=[weight_sequence_length, width],
-            initializer=self._initializer,
-        )
-
-        super().build(input_shape)
-
-    def call(self, inputs):
-        input_shape = tf.shape(inputs)
-        actual_seq_len = input_shape[self._seq_axis]
-        position_embeddings = self._position_embeddings[:actual_seq_len, :]
-        new_shape = [1 for _ in inputs.get_shape().as_list()]
-        new_shape[self._seq_axis] = actual_seq_len
-        new_shape[-1] = position_embeddings.get_shape().as_list()[-1]
-        position_embeddings = tf.reshape(position_embeddings, new_shape)
-        return tf.broadcast_to(position_embeddings, input_shape)
-
-
 # TODO(mattdangerw): This class is needed for TPU friendly embeddings, we should
 # remove it entirely and fix tf.keras.layers.Embedding as needed.
 class OnDeviceEmbedding(keras.layers.Layer):
@@ -546,7 +482,7 @@ def __init__(
             name="word_embeddings",
         )
 
-        self._position_embedding_layer = PositionEmbedding(
+        self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
             initializer=initializer,
             max_length=max_sequence_length,
             name="position_embedding",

keras_nlp/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from keras_nlp.layers.fnet_encoder import FNetEncoder
+from keras_nlp.layers.position_embedding import PositionEmbedding
 from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
 from keras_nlp.layers.transformer_decoder import TransformerDecoder
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
keras_nlp/layers/position_embedding.py

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Position embedding implementation based on `keras.layers.Layer`."""

import tensorflow as tf
from tensorflow import keras

SEQUENCE_AXIS = -2


class PositionEmbedding(keras.layers.Layer):
    """Creates a layer which learns a position embedding for input sequences.

    This class assumes that in the input tensor, the last dimension corresponds
    to the features, and the dimension before the last corresponds to the
    sequence.

    This class accepts `RaggedTensor`s as inputs to process batches of
    sequences of different lengths. The one ragged dimension must be the
    dimension that corresponds to the sequence, that is, the penultimate
    dimension.

    Args:
        max_length: The maximum length of the dynamic sequence.
        initializer: The initializer to use for the embedding weights. Defaults
            to "glorot_uniform".

    Example:
    ```python
    token_embeddings = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embed_dim
    )
    position_embeddings = keras_nlp.layers.PositionEmbedding(
        max_length=max_length
    )

    embedded_tokens = token_embeddings(inputs)
    embedded_positions = position_embeddings(embedded_tokens)
    outputs = embedded_tokens + embedded_positions
    ```

    Reference:
        [BERT: Pre-training of Deep Bidirectional Transformers for Language
        Understanding](https://arxiv.org/abs/1810.04805).
    """

    def __init__(
        self,
        max_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if max_length is None:
            raise ValueError("`max_length` must be an Integer, not `None`.")
        self.max_length = int(max_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_length": self.max_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            "embeddings",
            shape=[self.max_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs):
        if isinstance(inputs, tf.RaggedTensor):
            bounding_shape = inputs.bounding_shape()
            position_embeddings = self._trim_and_broadcast_position_embeddings(
                bounding_shape,
            )
            # then apply row lengths to recreate the same ragged shape as inputs
            return tf.RaggedTensor.from_tensor(
                position_embeddings,
                inputs.nested_row_lengths(),
            )
        else:
            return self._trim_and_broadcast_position_embeddings(
                tf.shape(inputs),
            )

    def _trim_and_broadcast_position_embeddings(self, shape):
        sequence_length = shape[SEQUENCE_AXIS]
        # trim to match the length of the sequence
        position_embeddings = self.position_embeddings[:sequence_length, :]
        # then broadcast to add the missing dimensions to match "shape"
        return tf.broadcast_to(position_embeddings, shape)
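
The ragged branch of `call` broadcasts the trimmed embeddings to the input's bounding shape and then re-applies the input's row lengths, so a ragged input produces a ragged output with matching row lengths. A small sketch of that behavior; the lengths and feature size are placeholders, not values from the commit.

```python
import tensorflow as tf

import keras_nlp

position_embedding = keras_nlp.layers.PositionEmbedding(max_length=8)

# Two sequences of different lengths (3 and 1), each step with 4 features.
ragged_inputs = tf.ragged.constant(
    [
        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0]],
    ],
    ragged_rank=1,
)

outputs = position_embedding(ragged_inputs)
print(outputs.row_lengths())  # [3 1], same row lengths as the input
print(outputs.shape)          # (2, None, 4)
```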
