Skip to content

Commit 401201a

Browse files
Fix mask propagation of transformer layers (#139)
1 parent: ad47068 · commit: 401201a

File tree

4 files changed

+25
-0
lines changed

4 files changed

+25
-0
lines changed

keras_nlp/layers/transformer_decoder.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,7 @@ def __init__(
9494
self.kernel_initializer = keras.initializers.get(kernel_initializer)
9595
self.bias_initializer = keras.initializers.get(bias_initializer)
9696
self._built = False
97+
self.supports_masking = True
9798

9899
def _build(self, input_shape):
99100
# Create layers based on input shape.

keras_nlp/layers/transformer_decoder_test.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -162,6 +162,18 @@ def test_checkpointing_transformer_decoder(self):
162162
)
163163
self.assertAllClose(decoder1_output, decoder2_output)
164164

165+
def test_mask_propagation(self):
166+
decoder = transformer_decoder.TransformerDecoder(
167+
intermediate_dim=4,
168+
num_heads=2,
169+
)
170+
decoder_sequence = tf.random.uniform(shape=[1, 4, 6])
171+
encoder_sequence = tf.random.uniform(shape=[1, 4, 6])
172+
mask = tf.constant([[True, True, False, False]])
173+
decoder_sequence._keras_mask = mask
174+
outputs = decoder(decoder_sequence, encoder_sequence)
175+
self.assertAllEqual(outputs._keras_mask, mask)
176+
165177
def test_save_model(self):
166178
encoder_input = keras.Input(shape=[4, 6])
167179
decoder_input = keras.Input(shape=[4, 6])

keras_nlp/layers/transformer_encoder.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,7 @@ def __init__(
8989
self.kernel_initializer = keras.initializers.get(kernel_initializer)
9090
self.bias_initializer = keras.initializers.get(bias_initializer)
9191
self._built = False
92+
self.supports_masking = True
9293

9394
def _build(self, input_shape):
9495
# Create layers based on input shape.

keras_nlp/layers/transformer_encoder_test.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,17 @@ def test_one_training_step_of_transformer_encoder(self):
111111
self.assertGreater(len(grad), 1)
112112
optimizer.apply_gradients(zip(grad, model.trainable_variables))
113113

114+
def test_mask_propagation(self):
115+
encoder = transformer_encoder.TransformerEncoder(
116+
intermediate_dim=4,
117+
num_heads=2,
118+
)
119+
inputs = tf.random.uniform(shape=[1, 4, 6])
120+
mask = tf.constant([[True, True, False, False]])
121+
inputs._keras_mask = mask
122+
outputs = encoder(inputs)
123+
self.assertAllEqual(outputs._keras_mask, mask)
124+
114125
def test_checkpointing_transformer_encoder(self):
115126
encoder1 = transformer_encoder.TransformerEncoder(
116127
intermediate_dim=4,

0 commit comments

Comments
 (0)