diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py
index 4b39126..c180752 100644
--- a/keras_nlp/models/gemma/gemma_attention.py
+++ b/keras_nlp/models/gemma/gemma_attention.py
@@ -155,15 +155,15 @@ class CachedGemmaAttention(keras.layers.Layer):
         query = self._apply_rope(query, cache_update_index)

         if cache is not None:
-            key_cache = cache[:, 0, ...]
-            value_cache = cache[:, 1, ...]
+            key_cache = cache[0]
+            value_cache = cache[1]
             key_update = self.key_dense(x)
             key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
             start = [0, cache_update_index, 0, 0]
             key = ops.slice_update(key_cache, start, key_update)
             value = ops.slice_update(value_cache, start, value_update)
-            cache = ops.stack((key, value), axis=1)
+            cache = [key, value]
         else:
             key = self.key_dense(x)
             key = self._apply_rope(key, cache_update_index)
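
Not part of the patch above: a minimal sketch of how one attention layer's cache is laid out and updated after this change, assuming Keras 3's `keras.ops` and made-up shapes; `key_update` and `value_update` stand in for the outputs of `key_dense` and `value_dense`.

from keras import ops

batch, max_length, num_heads, head_dim = 1, 8, 1, 4
# One layer's cache is now a plain [key, value] list of tensors, each shaped
# [batch, max_length, num_heads, head_dim], instead of a single stacked
# tensor with a leading key/value axis.
cache = [
    ops.zeros((batch, max_length, num_heads, head_dim)),
    ops.zeros((batch, max_length, num_heads, head_dim)),
]

# Stand-ins for key_dense(x) / value_dense(x) applied to one new token.
cache_update_index = 3
key_update = ops.ones((batch, 1, num_heads, head_dim))
value_update = ops.ones((batch, 1, num_heads, head_dim))

# Write the new token's projections into the caches at the update index.
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(cache[0], start, key_update)
value = ops.slice_update(cache[1], start, value_update)
cache = [key, value]  # replaces ops.stack((key, value), axis=1)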
diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py
index 26e9aad..d29238c 100644
--- a/keras_nlp/models/gemma/gemma_causal_lm.py
+++ b/keras_nlp/models/gemma/gemma_causal_lm.py
@@ -215,17 +215,17 @@ class GemmaCausalLM(CausalLM):
         # Each decoder layer has a cache; we update them separately.
         caches = []
         for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
+            current_cache = cache[i]
             x, next_cache = transformer_layer(
                 x,
                 cache=current_cache,
                 cache_update_index=cache_update_index,
             )
             caches.append(next_cache)
-        cache = ops.stack(caches, axis=1)
+
         hidden_states = x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)
-        return logits, hidden_states, cache
+        return logits, hidden_states, caches

     def _build_cache(self, token_ids):
         """Build an empty cache for use with `call_with_cache()`."""
@@ -234,11 +234,13 @@ class GemmaCausalLM(CausalLM):
         num_layers = self.backbone.num_layers
         num_heads = self.backbone.num_key_value_heads
         head_dim = self.backbone.head_dim
-        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
-        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        shape = [batch_size, max_length, num_heads, head_dim]
+        cache_list = []
+        for _ in range(0, num_layers):
+            cache_list.append([ops.zeros(shape, dtype=self.compute_dtype), ops.zeros(shape, dtype=self.compute_dtype)])
         # Seed the cache.
-        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
-        return hidden_states, cache
+        _, hidden_states, cache_list = self.call_with_cache(token_ids, cache_list, 0)
+        return hidden_states, cache_list

     def generate_step(
         self,
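
Not part of the patch above: a minimal sketch of the cache container that `_build_cache` produces after this change, assuming Keras 3's `keras.ops`; `build_empty_cache` is a hypothetical helper that mirrors the loop added above, not a KerasNLP API.

from keras import ops

def build_empty_cache(batch_size, num_layers, max_length, num_heads, head_dim, dtype="float32"):
    # One [key_cache, value_cache] pair per decoder layer, each tensor shaped
    # [batch_size, max_length, num_heads, head_dim], replacing the single
    # [batch_size, num_layers, 2, max_length, num_heads, head_dim] tensor.
    shape = [batch_size, max_length, num_heads, head_dim]
    return [
        [ops.zeros(shape, dtype=dtype), ops.zeros(shape, dtype=dtype)]
        for _ in range(num_layers)
    ]

cache = build_empty_cache(batch_size=1, num_layers=2, max_length=8, num_heads=1, head_dim=4)
# call_with_cache can then pick out cache[i] per layer and return the updated
# list directly, so no ops.stack over the layer axis is needed.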
diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py
index 0a91655..3ae7f8a 100644
--- a/keras_nlp/models/gemma/gemma_decoder_block.py
+++ b/keras_nlp/models/gemma/gemma_decoder_block.py
@@ -117,7 +117,7 @@ class GemmaDecoderBlock(keras.layers.Layer):
         batch_size = ops.shape(x)[0]
         input_length = output_length = ops.shape(x)[1]
         if cache is not None:
-            input_length = ops.shape(cache)[2]
+            input_length = ops.shape(cache[0])[1]

         causal_mask = compute_causal_mask(
             batch_size=batch_size,
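
Not part of the patch above: a short sketch of why the mask-length lookup in the decoder block changes, using the same illustrative shapes as the earlier sketches.

from keras import ops

batch, max_length, num_heads, head_dim = 1, 8, 1, 4
# Old layout: one stacked tensor per layer, shaped
# [batch, 2, max_length, num_heads, head_dim]; the cached length was axis 2.
stacked = ops.zeros((batch, 2, max_length, num_heads, head_dim))
old_input_length = ops.shape(stacked)[2]

# New layout: a [key, value] list; the same length is axis 1 of the key cache.
cache = [
    ops.zeros((batch, max_length, num_heads, head_dim)),
    ops.zeros((batch, max_length, num_heads, head_dim)),
]
new_input_length = ops.shape(cache[0])[1]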