 
 
 class PositionEmbedding(keras.layers.Layer):
-    """Creates a layer which learns a position embedding for inputs sequences.
+    """A layer which learns a position embedding for inputs sequences.
 
     This class assumes that in the input tensor, the last dimension corresponds
     to the features, and the dimension before the last corresponds to the
     sequence.
 
-    This class accepts `RaggedTensor`s as inputs to process batches of sequences
-    of different lengths. The one ragged dimension must be the dimension that
-    corresponds to the sequence, that is, the penultimate dimension.
+    This layer optionally accepts `tf.RaggedTensor`s as inputs to process
+    batches of sequences of different lengths. The one ragged dimension must be
+    the dimension that corresponds to the sequence, that is, the penultimate
+    dimension.
+
+    This layer does not support masking, but can be combined with a
+    `keras.layers.Embedding` for padding mask support.
 
     Args:
         sequence_length: The maximum length of the dynamic sequence.
@@ -38,17 +42,25 @@ class PositionEmbedding(keras.layers.Layer):
         seq_axis: The axis of the input tensor where we add the embeddings.
 
     Examples:
+
+    Called directly on input.
+    >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10)
+    >>> layer(tf.zeros((8, 10, 16))).shape
+    TensorShape([8, 10, 16])
+
+    Combine with a token embedding.
     ```python
-    token_embeddings = layers.Embedding(
+    seq_length = 50
+    vocab_size = 5000
+    embed_dim = 128
+    inputs = keras.Input(shape=(seq_length,))
+    token_embeddings = keras.layers.Embedding(
         input_dim=vocab_size, output_dim=embed_dim
-    )
+    )(inputs)
     position_embeddings = keras_nlp.layers.PositionEmbedding(
-        sequence_length=sequence_length
-    )
-
-    embedded_tokens = token_embeddings(inputs)
-    embedded_positions = position_embeddings(embedded_tokens)
-    outputs = embedded_tokens + embedded_positions
+        sequence_length=seq_length
+    )(token_embeddings)
+    outputs = token_embeddings + position_embeddings
     ```
 
     Reference:
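
The docstring above states that the layer can consume `tf.RaggedTensor` inputs with the ragged dimension on the sequence (penultimate) axis. Below is a minimal sketch of that usage, assuming the `keras_nlp` package from this change is installed and that the ragged call path returns a ragged output with the same row lengths as the input; the tensor values are made up for illustration.

```python
import tensorflow as tf
import keras_nlp

# Sketch only: exercises the ragged input path described in the docstring.
layer = keras_nlp.layers.PositionEmbedding(sequence_length=10)

# Two sequences of different lengths. The ragged (penultimate) dimension is
# the sequence dimension; the last dimension holds the features.
ragged_inputs = tf.ragged.constant(
    [
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # length 3
        [[7.0, 8.0]],                          # length 1
    ],
    ragged_rank=1,
)

# Assumed behavior: the output is ragged with the same shape as the input,
# holding the learned position embedding for each position in each row.
outputs = layer(ragged_inputs)
print(outputs.shape)  # (2, None, 2)
```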
@@ -107,8 +119,9 @@ def call(self, inputs):
         )
 
     def _trim_and_broadcast_position_embeddings(self, shape):
-        sequence_length = shape[SEQUENCE_AXIS]
-        # trim to match the length of the sequence
-        position_embeddings = self.position_embeddings[:sequence_length, :]
+        input_length = shape[SEQUENCE_AXIS]
+        # trim to match the length of the input sequence, which might be less
+        # than the sequence_length of the layer.
+        position_embeddings = self.position_embeddings[:input_length, :]
         # then broadcast to add the missing dimensions to match "shape"
         return tf.broadcast_to(position_embeddings, shape)
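
The rename to `input_length` in the second hunk reflects that a batch can be shorter than the layer's configured `sequence_length`. The standalone TensorFlow sketch below (made-up shapes, not the library code) shows the trim-then-broadcast step in isolation.

```python
import tensorflow as tf

# Assumed layer configuration and input shape, chosen for illustration.
sequence_length, width = 10, 4
position_embeddings = tf.random.normal([sequence_length, width])

input_shape = (8, 6, 4)  # batch of 8 sequences, each of length 6
SEQUENCE_AXIS = -2
input_length = input_shape[SEQUENCE_AXIS]  # 6, shorter than sequence_length

# Trim the learned table to the actual input length...
trimmed = position_embeddings[:input_length, :]    # shape (6, 4)
# ...then broadcast across the batch dimension to match the input shape.
broadcast = tf.broadcast_to(trimmed, input_shape)  # shape (8, 6, 4)
print(broadcast.shape)  # (8, 6, 4)
```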