diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils.py b/keras_hub/src/layers/modeling/transformer_layer_utils.py
index ebc8ff37be..8ffb8e18a4 100644
--- a/keras_hub/src/layers/modeling/transformer_layer_utils.py
+++ b/keras_hub/src/layers/modeling/transformer_layer_utils.py
@@ -90,3 +90,18 @@ def merge_padding_and_attention_mask(
         else:
             return ops.minimum(mask, attention_mask)
     return mask
+
+
+def compute_positions_from_mask(mask):
+    """Computes positions from provided padding mask.
+
+    Args:
+        mask: Tensor of shape `(batch_size, sequence_length)`. Padding mask,
+            1 for non-padding tokens, 0 for padding tokens.
+
+    Returns:
+        positions: Tensor of the same shape as `mask`, which contains indices
+            corresponding to positions of tokens in the sequence.
+    """
+    positions = ops.cumsum(mask, axis=-1)
+    return ops.subtract(positions, ops.greater_equal(positions, 1))
diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py
index 1c92950444..3c68370c96 100644
--- a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py
+++ b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py
@@ -41,3 +41,15 @@ def test_bad_mask_shapes(self):
                 padding_mask,
                 attention_mask,
             )
+
+    def test_compute_positions_from_mask(self):
+        mask = ops.array(
+            [
+                [False, False, True, True, False],
+                [True, False, True, False, True],
+            ]
+        )
+        output = utils.compute_positions_from_mask(mask)
+
+        expected_output = ops.array([[0, 0, 0, 1, 1], [0, 0, 1, 1, 2]])
+        self.assertAllEqual(output, expected_output)
diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index f66a4506ce..11833d08f8 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -97,9 +97,9 @@ def build(self, inputs_shape):
 
         self.built = True
 
-    def _apply_rope(self, x, start_index):
+    def _apply_rope(self, x, start_index, positions=None):
         """Rope rotate q or k."""
-        x = self.rope_layer(x, start_index=start_index)
+        x = self.rope_layer(x, start_index=start_index, positions=positions)
         # Gemma uses a different layout for positional embeddings.
         # The transformation below ensures the embeddings are numerically
         # equivalent to the original gemma implementation.
@@ -230,12 +230,13 @@ def call(
         self,
         x,
         attention_mask=None,
+        positions=None,
         cache=None,
         cache_update_index=0,
         training=False,
     ):
         query = self.query_dense(x)
-        query = self._apply_rope(query, cache_update_index)
+        query = self._apply_rope(query, cache_update_index, positions=positions)
 
         if cache is not None:
             key_cache = cache[:, 0, ...]
@@ -249,7 +250,7 @@ def call(
             cache = ops.stack((key, value), axis=1)
         else:
             key = self.key_dense(x)
-            key = self._apply_rope(key, cache_update_index)
+            key = self._apply_rope(key, cache_update_index, positions=positions)
             value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py
index b5f8575332..cbcf6cdafa 100644
--- a/keras_hub/src/models/gemma/gemma_backbone_test.py
+++ b/keras_hub/src/models/gemma/gemma_backbone_test.py
@@ -31,6 +31,13 @@ def test_backbone_basics(self):
             expected_output_shape=(2, 5, 16),
         )
 
+    def test_flexible_positions(self):
+        self.run_positions_test(
+            cls=GemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            vocabulary_size=self.init_kwargs["vocabulary_size"],
+        )
+
     @pytest.mark.large
     def test_saved_model(self):
         self.run_model_saving_test(
@@ -188,6 +195,13 @@ def test_backbone_basics(self):
             expected_output_shape=(2, 10, 16),
         )
 
+    def test_flexible_positions(self):
+        self.run_positions_test(
+            cls=GemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            vocabulary_size=self.init_kwargs["vocabulary_size"],
+        )
+
     def test_sliding_window(self):
         # Test sliding window correctness by hand.
         backbone = GemmaBackbone(**self.init_kwargs)
diff --git a/keras_hub/src/models/gemma/gemma_decoder_block.py b/keras_hub/src/models/gemma/gemma_decoder_block.py
index b93e1cebc1..53804fca4d 100644
--- a/keras_hub/src/models/gemma/gemma_decoder_block.py
+++ b/keras_hub/src/models/gemma/gemma_decoder_block.py
@@ -4,6 +4,9 @@
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
     compute_causal_mask,
 )
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_positions_from_mask,
+)
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
@@ -178,9 +181,14 @@ def call(
                 cache_update_index=cache_update_index,
             )
         else:
+            positions = None
+            if padding_mask is not None:
+                positions = compute_positions_from_mask(padding_mask)
+
             attention = self.attention(
                 normalized_x,
                 attention_mask=attention_mask,
+                positions=positions,
             )
 
         if self.use_post_attention_norm:
diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py
index f70ab78840..d07ec2e119 100644
--- a/keras_hub/src/tests/test_case.py
+++ b/keras_hub/src/tests/test_case.py
@@ -720,6 +720,47 @@ def compare(actual, expected):
             output = ops.argmax(output, axis=-1)
             self.assertAllEqual(output, expected_labels)
 
+    def run_positions_test(
+        self,
+        cls,
+        init_kwargs,
+        vocabulary_size,
+    ):
+        """Tests that conventional and flexible positions give same output."""
+        model = cls(**init_kwargs)
+
+        rng = np.random.default_rng(seed=42)
+        x1 = {
+            "token_ids": rng.integers(low=1, high=vocabulary_size, size=(2, 5)),
+            "padding_mask": np.array(
+                [
+                    [True] * 3 + [False] * 2,
+                    [True] * 2 + [False] * 3,
+                ]
+            ),
+        }
+        # Convert token_ids to list for easier manipulation.
+        token_ids_lst = x1["token_ids"].tolist()
+        x2 = {
+            "token_ids": np.array(
+                [
+                    [0] + token_ids_lst[0][:3] + [0],
+                    [0] * 2 + token_ids_lst[1][:2] + [0],
+                ]
+            ),
+            "padding_mask": np.array(
+                [
+                    [False] + [True] * 3 + [False],
+                    [False] * 2 + [True] * 2 + [False],
+                ]
+            ),
+        }
+
+        output_1 = model.predict(x1)
+        output_2 = model.predict(x2)
+        self.assertAllClose(output_1[0][:3], output_2[0][1:4])
+        self.assertAllClose(output_1[1][:2], output_2[1][2:4])
+
     def get_test_data_dir(self):
         return str(pathlib.Path(__file__).parent / "test_data")
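
Not part of the patch itself: a minimal, self-contained sketch of what the new `compute_positions_from_mask` utility computes, reusing the function body and the mask values from the tests added above and assuming the backend-agnostic `keras.ops` API plus NumPy. Each non-padding token receives a 0-based position that counts only the non-padding tokens before it, while padding tokens collapse onto the position of the most recent non-padding token (or 0 if none has appeared yet); these are the positions the Gemma decoder block now forwards to RoPE, so padding no longer shifts the rotations.

```python
import numpy as np
from keras import ops


def compute_positions_from_mask(mask):
    # Mirrors the utility added in transformer_layer_utils.py above:
    # a cumulative count of non-padding tokens, shifted so the first
    # real token sits at position 0. Padding entries repeat the position
    # of the most recent real token (or 0 before any real token).
    positions = ops.cumsum(mask, axis=-1)
    return ops.subtract(positions, ops.greater_equal(positions, 1))


mask = np.array(
    [
        [False, False, True, True, False],
        [True, False, True, False, True],
    ]
)
print(ops.convert_to_numpy(compute_positions_from_mask(mask)))
# [[0 0 0 1 1]
#  [0 0 1 1 2]]
```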