Commit 433b4ce

Author: lingzhi98

add kv cache related keras_nlp patch (#414)
1 parent 0708d7a commit 433b4ce

File tree: 2 files changed (+86, -1 lines)

example/gemma/README.md

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,13 @@ export KAGGLE_KEY=xxxxxxxx
 Mark `intel-extension-for-openxla` folder as \<WORKSPACE\>, then
 ```bash
 cd <WORKSPACE>/example/gemma/
-pip install keras-nlp==0.10.0 keras==3.3.2
+pip install keras==3.3.2
+git clone https://github.com/keras-team/keras-nlp.git
+cd keras-nlp
+git checkout v0.10.0
+git apply ../keras_nlp.patch
+python setup.py install
+cd ..
 pip install -r ../../test/requirements.txt
 ```
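
As a quick sanity check after the install steps above (not part of the README; the snippet below is illustrative and assumes the `python setup.py install` step succeeded), you can confirm that the locally built, patched keras-nlp is the one on the import path:

```python
# Illustrative check: the patched keras-nlp built from the v0.10.0 tag should import cleanly.
import keras_nlp

print(keras_nlp.__version__)  # expected to report 0.10.0
```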

example/gemma/keras_nlp.patch

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py
index 4b39126..c180752 100644
--- a/keras_nlp/models/gemma/gemma_attention.py
+++ b/keras_nlp/models/gemma/gemma_attention.py
@@ -155,15 +155,15 @@ class CachedGemmaAttention(keras.layers.Layer):
         query = self._apply_rope(query, cache_update_index)

         if cache is not None:
-            key_cache = cache[:, 0, ...]
-            value_cache = cache[:, 1, ...]
+            key_cache = cache[0]
+            value_cache = cache[1]
             key_update = self.key_dense(x)
             key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
             start = [0, cache_update_index, 0, 0]
             key = ops.slice_update(key_cache, start, key_update)
             value = ops.slice_update(value_cache, start, value_update)
-            cache = ops.stack((key, value), axis=1)
+            cache = [key, value]
         else:
             key = self.key_dense(x)
             key = self._apply_rope(key, cache_update_index)
diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py
index 26e9aad..d29238c 100644
--- a/keras_nlp/models/gemma/gemma_causal_lm.py
+++ b/keras_nlp/models/gemma/gemma_causal_lm.py
@@ -215,17 +215,17 @@ class GemmaCausalLM(CausalLM):
         # Each decoder layer has a cache; we update them separately.
         caches = []
         for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
+            current_cache = cache[i]
             x, next_cache = transformer_layer(
                 x,
                 cache=current_cache,
                 cache_update_index=cache_update_index,
             )
             caches.append(next_cache)
-        cache = ops.stack(caches, axis=1)
+
         hidden_states = x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)
-        return logits, hidden_states, cache
+        return logits, hidden_states, caches

     def _build_cache(self, token_ids):
         """Build an empty cache for use with `call_with_cache()`."""
@@ -234,11 +234,13 @@ class GemmaCausalLM(CausalLM):
         num_layers = self.backbone.num_layers
         num_heads = self.backbone.num_key_value_heads
         head_dim = self.backbone.head_dim
-        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
-        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        shape = [batch_size, max_length, num_heads, head_dim]
+        cache_list = []
+        for _ in range(0, num_layers):
+            cache_list.append([ops.zeros(shape, dtype=self.compute_dtype), ops.zeros(shape, dtype=self.compute_dtype)])
         # Seed the cache.
-        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
-        return hidden_states, cache
+        _, hidden_states, cache_list = self.call_with_cache(token_ids, cache_list, 0)
+        return hidden_states, cache_list

     def generate_step(
         self,
diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py
index 0a91655..3ae7f8a 100644
--- a/keras_nlp/models/gemma/gemma_decoder_block.py
+++ b/keras_nlp/models/gemma/gemma_decoder_block.py
@@ -117,7 +117,7 @@ class GemmaDecoderBlock(keras.layers.Layer):
         batch_size = ops.shape(x)[0]
         input_length = output_length = ops.shape(x)[1]
         if cache is not None:
-            input_length = ops.shape(cache)[2]
+            input_length = ops.shape(cache[0])[1]

         causal_mask = compute_causal_mask(
             batch_size=batch_size,
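
Taken together, the patch switches the Gemma KV cache from a single stacked tensor of shape `[batch_size, num_layers, 2, max_length, num_heads, head_dim]` to a plain Python list holding one `[key, value]` pair of `[batch_size, max_length, num_heads, head_dim]` tensors per decoder layer, so `call_with_cache` and `_build_cache` no longer need `ops.stack` or per-layer tensor slicing. A minimal sketch of the two layouts, assuming only `keras.ops` and purely illustrative sizes (none of the numbers come from a real Gemma config):

```python
from keras import ops

# Illustrative sizes only; not taken from a real Gemma configuration.
batch_size, num_layers, max_length, num_heads, head_dim = 1, 2, 8, 4, 16

# Old layout: one stacked tensor holding every layer's key and value cache.
stacked = ops.zeros([batch_size, num_layers, 2, max_length, num_heads, head_dim])
key_cache_layer0 = stacked[:, 0, 0, ...]  # needs tensor slicing per layer

# Patched layout: a Python list with one [key, value] pair per decoder layer.
shape = [batch_size, max_length, num_heads, head_dim]
cache_list = [[ops.zeros(shape), ops.zeros(shape)] for _ in range(num_layers)]
key_cache_layer0 = cache_list[0][0]  # plain list indexing, no stack/slice ops
```

Keeping the per-layer caches as separate tensors is what lets the attention layer read `cache[0]`/`cache[1]` directly and the decoder block take the cached sequence length from `ops.shape(cache[0])[1]`.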
