
Commit beb3f6a

Fix masking for TokenAndPositionEmbedding (#140)
Also fix up a nearby test.
1 parent 401201a commit beb3f6a

File tree

2 files changed: +21, -41 lines changed


keras_nlp/layers/token_and_position_embedding.py

Lines changed: 4 additions & 0 deletions
@@ -91,6 +91,7 @@ def __init__(
             max_length=max_length,
             initializer=embeddings_initializer,
         )
+        self.supports_masking = self.token_embedding.supports_masking

     def get_config(self):
         config = super().get_config()
@@ -112,3 +113,6 @@ def call(self, inputs):
         embedded_positions = self.position_embedding(embedded_tokens)
         outputs = embedded_tokens + embedded_positions
         return outputs
+
+    def compute_mask(self, inputs, mask=None):
+        return self.token_embedding.compute_mask(inputs, mask=mask)
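Before this change the layer exposed neither supports_masking nor compute_mask, so the padding mask produced by its token embedding (when built with mask_zero=True) was dropped instead of reaching downstream layers. A minimal sketch of the behavior after the fix, assuming a keras_nlp build from around this commit (where the constructor still takes max_length):

import tensorflow as tf
from keras_nlp.layers import TokenAndPositionEmbedding

# With mask_zero=True, token id 0 is treated as padding.
layer = TokenAndPositionEmbedding(
    vocabulary_size=5,
    max_length=4,
    embedding_dim=3,
    mask_zero=True,
)
token_ids = tf.constant([[1, 2, 0, 0]])
embedded = layer(token_ids)

# The mask computed by the token embedding is now attached to the output
# and will propagate to any mask-consuming layer that follows.
print(embedded._keras_mask)  # expected: [[ True  True False False]]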

keras_nlp/layers/token_and_position_embedding_test.py

Lines changed: 17 additions & 41 deletions
@@ -112,51 +112,27 @@ def test_dense_tensor(self):
         )
         # Create a 2-dimensional input
         # (the first dimension is implicit).
-        input_tensor = tf.keras.Input(
-            shape=(sequence_length,), dtype=tf.float32, ragged=True
-        )
-        output_tensor = test_layer(input_tensor)
-        model = tf.keras.Model(input_tensor, output_tensor)
+        inputs = tf.keras.Input(shape=(sequence_length,), dtype="int32")
+        outputs = test_layer(inputs)
+        model = tf.keras.Model(inputs, outputs)

-        input_data = tf.constant(
-            [
-                [1.0, 1.0, 1.0, 1.0],
-                [1.0, 1.0, 1.0, 1.0],
-                [1.0, 1.0, 1.0, 1.0],
-                [1.0, 1.0, 1.0, 1.0],
-            ],
-        )
-        expected_output_data = tf.constant(
-            [
-                [
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                ],
-                [
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                ],
-                [
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                ],
-                [
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                    [2.0, 2.0, 2.0],
-                ],
-            ],
-        )
+        input_data = tf.ones((2, sequence_length), dtype="int32")
+        expected_output_data = tf.ones((2, sequence_length, embedding_dim)) * 2
         output_data = model.predict(input_data)
         self.assertAllClose(output_data, expected_output_data)

+    def test_mask_propagation(self):
+        test_layer = TokenAndPositionEmbedding(
+            vocabulary_size=5,
+            max_length=4,
+            embedding_dim=3,
+            mask_zero=True,
+        )
+        input_data = tf.constant([[1, 0], [1, 0]])
+        mask = input_data != 0
+        outputs = test_layer(input_data)
+        self.assertAllEqual(outputs._keras_mask, mask)
+
     def test_save_model(self):
         vocabulary_size = 5
         sequence_length = 4
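The simplified dense-tensor test leans on the fixture defined earlier in the file: all-ones token ids through a layer whose embeddings start as ones give a token vector of ones plus a position vector of ones, i.e. an all-twos output. A rough reconstruction of that arithmetic, assuming the fixture uses embeddings_initializer="ones" (the fixture itself is outside this hunk):

import tensorflow as tf
from keras_nlp.layers import TokenAndPositionEmbedding

sequence_length = 4
embedding_dim = 3

# Assumed fixture: both the token and position embeddings start as ones.
test_layer = TokenAndPositionEmbedding(
    vocabulary_size=5,
    max_length=sequence_length,
    embedding_dim=embedding_dim,
    embeddings_initializer="ones",
)

output = test_layer(tf.ones((2, sequence_length), dtype="int32"))
print(output.shape)  # (2, 4, 3)
# Each entry is 1 (token embedding) + 1 (position embedding) = 2,
# matching expected_output_data above.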
