Commit 99922f8

Minor improvements to the position embedding docs (#180)
I'm using this as an example in a code style guide, so adding a few minor fixes I found while doing so.
1 parent 1844b46 commit 99922f8

2 files changed: +29 −16 lines


keras_nlp/layers/position_embedding.py

Lines changed: 28 additions & 15 deletions
@@ -21,15 +21,19 @@


 class PositionEmbedding(keras.layers.Layer):
-    """Creates a layer which learns a position embedding for inputs sequences.
+    """A layer which learns a position embedding for inputs sequences.

     This class assumes that in the input tensor, the last dimension corresponds
     to the features, and the dimension before the last corresponds to the
     sequence.

-    This class accepts `RaggedTensor`s as inputs to process batches of sequences
-    of different lengths. The one ragged dimension must be the dimension that
-    corresponds to the sequence, that is, the penultimate dimension.
+    This layer optionally accepts `tf.RaggedTensor`s as inputs to process
+    batches of sequences of different lengths. The one ragged dimension must be
+    the dimension that corresponds to the sequence, that is, the penultimate
+    dimension.
+
+    This layer does not support masking, but can be combined with a
+    `keras.layers.Embedding` for padding mask support.

     Args:
         sequence_length: The maximum length of the dynamic sequence.
@@ -38,17 +42,25 @@ class PositionEmbedding(keras.layers.Layer):
         seq_axis: The axis of the input tensor where we add the embeddings.

     Examples:
+
+    Called directly on input.
+    >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10)
+    >>> layer(tf.zeros((8, 10, 16))).shape
+    TensorShape([8, 10, 16])
+
+    Combine with a token embedding.
     ```python
-    token_embeddings = layers.Embedding(
+    seq_length = 50
+    vocab_size = 5000
+    embed_dim = 128
+    inputs = keras.Input(shape=(seq_length,))
+    token_embeddings = keras.layers.Embedding(
         input_dim=vocab_size, output_dim=embed_dim
-    )
+    )(inputs)
     position_embeddings = keras_nlp.layers.PositionEmbedding(
-        sequence_length=sequence_length
-    )
-
-    embedded_tokens = token_embeddings(inputs)
-    embedded_positions = position_embeddings(embedded_tokens)
-    outputs = embedded_tokens + embedded_positions
+        sequence_length=seq_length
+    )(token_embeddings)
+    outputs = token_embeddings + position_embeddings
     ```

     Reference:
@@ -107,8 +119,9 @@ def call(self, inputs):
        )

    def _trim_and_broadcast_position_embeddings(self, shape):
-        sequence_length = shape[SEQUENCE_AXIS]
-        # trim to match the length of the sequence
-        position_embeddings = self.position_embeddings[:sequence_length, :]
+        input_length = shape[SEQUENCE_AXIS]
+        # trim to match the length of the input sequence, which might be less
+        # than the sequence_length of the layer.
+        position_embeddings = self.position_embeddings[:input_length, :]
         # then broadcast to add the missing dimensions to match "shape"
         return tf.broadcast_to(position_embeddings, shape)
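
The revised docstring highlights two behaviors that are easy to miss: padding masks come from pairing this layer with `keras.layers.Embedding(mask_zero=True)` rather than from the layer itself, and the renamed `input_length` variable reflects that inputs shorter than `sequence_length` are handled by trimming the learned embeddings. A minimal sketch of both, plus the ragged case, using made-up sizes and data (not part of the commit):

```python
import tensorflow as tf
from tensorflow import keras
import keras_nlp

vocab_size, max_length, embed_dim = 5000, 50, 128  # illustrative sizes only

embedding = keras.layers.Embedding(
    input_dim=vocab_size, output_dim=embed_dim, mask_zero=True
)
position_embedding = keras_nlp.layers.PositionEmbedding(sequence_length=max_length)

# Token ids padded with 0s; mask_zero=True gives the token embeddings a
# padding mask, which PositionEmbedding itself does not compute.
token_ids = tf.constant([[5, 12, 9, 0, 0], [3, 7, 0, 0, 0]])
token_embeddings = embedding(token_ids)            # shape (2, 5, 128)

# The input length (5) is shorter than sequence_length (50); the layer
# trims its learned position embeddings to the input length.
position_embeddings = position_embedding(token_embeddings)
outputs = token_embeddings + position_embeddings   # shape (2, 5, 128)

# A tf.RaggedTensor works too, as long as the one ragged dimension is the
# sequence (penultimate) dimension.
ragged_embeddings = tf.RaggedTensor.from_tensor(
    tf.random.uniform((2, 5, embed_dim)), lengths=[5, 3]
)
ragged_positions = position_embedding(ragged_embeddings)
```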

keras_nlp/layers/token_and_position_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class TokenAndPositionEmbedding(keras.layers.Layer):
     seq_length = 50
     vocab_size = 5000
     embed_dim = 128
-    inputs = tf.keras.Input(shape=(seq_length,))
+    inputs = keras.Input(shape=(seq_length,))
     embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(
         vocabulary_size=vocab_size,
         sequence_length=seq_length,
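
The hunk above ends mid-example as rendered here. For reference, a sketch of how that docstring snippet is typically completed, assuming the layer also takes an `embedding_dim` argument (an assumption about the API; it is not shown in this diff):

```python
import tensorflow as tf
from tensorflow import keras
import keras_nlp

seq_length = 50
vocab_size = 5000
embed_dim = 128

inputs = keras.Input(shape=(seq_length,))
# `embedding_dim` is assumed here; only vocabulary_size and sequence_length
# appear in the hunk above.
embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=vocab_size,
    sequence_length=seq_length,
    embedding_dim=embed_dim,
)
outputs = embedding_layer(inputs)
model = keras.Model(inputs, outputs)
print(model(tf.zeros((4, seq_length))).shape)  # expected: (4, 50, 128)
```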
