Skip to content

Commit ae224a1

Browse files
authored
Adding ragged support to SinePositionEncoding (#751)
* adding support for tf.RaggedTensor
* formatting the code
* tests for ragged tensor
1 parent ee99548 commit ae224a1

File tree

2 files changed

+104
-6
lines changed

2 files changed

+104
-6
lines changed

keras_nlp/layers/sine_position_encoding.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ class SinePositionEncoding(keras.layers.Layer):
3131
positional encoding the same size as the embedded token tensor, which
3232
can be added directly to the embedded token tensor.
3333
34+
This layer optionally accepts `tf.RaggedTensor`s as inputs to process
35+
batches of sequences of different lengths. The one ragged dimension must be
36+
the dimension that corresponds to the sequence, that is, the penultimate
37+
dimension.
38+
3439
Args:
3540
max_wavelength: The maximum angular wavelength of the sine/cosine
3641
curves, as described in Attention is All You Need. Defaults to
@@ -65,10 +70,26 @@ def __init__(
6570
def call(self, inputs):
6671
# TODO(jbischof): replace `hidden_size` with `hidden_dim` for consistency
6772
# with other layers.
68-
input_shape = tf.shape(inputs)
69-
# length of sequence is the second last dimension of the inputs
70-
seq_length = input_shape[-2]
71-
hidden_size = input_shape[-1]
73+
if isinstance(inputs, tf.RaggedTensor):
74+
bounding_shape = inputs.bounding_shape()
75+
position_embeddings = (
76+
self._compute_trim_and_broadcast_position_embeddings(
77+
bounding_shape,
78+
)
79+
)
80+
# Then apply the row lengths to recreate the same ragged shape as the inputs.
81+
return tf.RaggedTensor.from_tensor(
82+
position_embeddings,
83+
inputs.nested_row_lengths(),
84+
)
85+
else:
86+
return self._compute_trim_and_broadcast_position_embeddings(
87+
tf.shape(inputs),
88+
)
89+
90+
def _compute_trim_and_broadcast_position_embeddings(self, shape):
91+
seq_length = shape[-2]
92+
hidden_size = shape[-1]
7293
position = tf.cast(tf.range(seq_length), self.compute_dtype)
7394
min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
7495
timescales = tf.pow(
@@ -85,7 +106,7 @@ def call(self, inputs):
85106
tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
86107
)
87108

88-
return tf.broadcast_to(positional_encodings, input_shape)
109+
return tf.broadcast_to(positional_encodings, shape)
89110

90111
def get_config(self):
91112
config = super().get_config()

keras_nlp/layers/sine_position_encoding_test.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
"""Tests for Sinusoidal Positional encoding."""
1515

16-
1716
import tensorflow as tf
1817
from tensorflow import keras
1918

@@ -92,6 +91,84 @@ def test_output_correct_values(self):
9291
self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
9392
self.assertAllClose(output[0, 3, :], expected_encoding_position_3)
9493

94+
def test_ragged_tensor_with_3_dimensions(self):
95+
feature_size = 2
96+
test_layer = sine_position_encoding.SinePositionEncoding()
97+
# Create a 3-dimensional ragged input (the first dimension is implicit).
98+
input_tensor = keras.Input(
99+
shape=(None, feature_size), dtype=tf.float32, ragged=True
100+
)
101+
output_tensor = test_layer(input_tensor)
102+
model = keras.Model(input_tensor, output_tensor)
103+
104+
input_data = tf.ragged.constant(
105+
[
106+
[[1.0, 1.0], [1.0, 1.0]],
107+
[],
108+
[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
109+
[[1.0, 1.0]],
110+
],
111+
ragged_rank=1,
112+
inner_shape=(2,),
113+
)
114+
expected_output_data = tf.ragged.constant(
115+
[
116+
[[0.0, 1.0], [0.84147096, 0.5403023]],
117+
[],
118+
[[0.0, 1.0], [0.84147096, 0.5403023], [0.9092974, -0.41614684]],
119+
[[0.0, 1.0]],
120+
],
121+
ragged_rank=1,
122+
inner_shape=(2,),
123+
)
124+
output_data = model.predict(input_data)
125+
self.assertAllClose(output_data, expected_output_data)
126+
127+
def test_ragged_tensor_with_4_dimensions(self):
128+
feature_size = 2
129+
test_layer = sine_position_encoding.SinePositionEncoding()
130+
# Create a 4-dimensional ragged input (the first dimension is implicit).
131+
input_tensor = keras.Input(
132+
shape=(None, None, feature_size), dtype=tf.float32, ragged=True
133+
)
134+
output_tensor = test_layer(input_tensor)
135+
model = keras.Model(input_tensor, output_tensor)
136+
137+
input_data = tf.ragged.constant(
138+
[
139+
[
140+
[[1.0, 1.0], [1.0, 1.0]],
141+
[],
142+
],
143+
[
144+
[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
145+
[[1.0, 1.0]],
146+
],
147+
],
148+
ragged_rank=2,
149+
inner_shape=(2,),
150+
)
151+
expected_output_data = tf.ragged.constant(
152+
[
153+
[
154+
[[0.0, 1.0], [0.84147096, 0.5403023]],
155+
[],
156+
],
157+
[
158+
[
159+
[0.0, 1.0],
160+
[0.84147096, 0.5403023],
161+
[0.9092974, -0.41614684],
162+
],
163+
[[0.0, 1.0]],
164+
],
165+
],
166+
ragged_rank=2,
167+
inner_shape=(2,),
168+
)
169+
output_data = model.predict(input_data)
170+
self.assertAllClose(output_data, expected_output_data)
171+
95172
def test_get_config_and_from_config(self):
96173
pos_encoding = sine_position_encoding.SinePositionEncoding(
97174
max_wavelength=1000,

0 commit comments

Comments
 (0)