Commit bcb40f9

ganyi1996ppo and Wu, Hui authored
enable new cache strategy for GroupAttention (#3542)
* add the cache strategy for groupattention
* fix typo error
* fix flake8

Co-authored-by: Wu, Hui <[email protected]>
1 parent cb7be86 commit bcb40f9

File tree

2 files changed: +12 -56 lines changed

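The change below removes the GQA-specific cache allocation from GroupedAttention.py and instead lets the base attention class allocate its KV buffers for an explicit number of KV heads. To illustrate why the extra `kv_head` argument matters, here is a minimal, self-contained sketch (not the IPEX implementation; the helper name and the example sizes are hypothetical, only the shape layout `[cache_len, bs_beam, kv_head, head_dim]` follows the diff):

import torch

def alloc_kv_cache(cache_len, bs_beam, kv_head, head_dim,
                   device="cpu", dtype=torch.float16):
    # Shape layout as in the diff: [cache_len, bs_beam, kv_head, head_dim].
    # For plain multi-head attention kv_head equals the number of query heads;
    # for grouped-query attention kv_head is the (smaller) number of KV heads,
    # so the cache shrinks by a factor of num_attn_head / num_kv_head.
    return torch.empty(cache_len, bs_beam, kv_head, head_dim,
                       device=device, dtype=dtype)

# Hypothetical example: 32 query heads sharing 8 KV heads -> 4x smaller cache.
mha_cache = alloc_kv_cache(cache_len=2048, bs_beam=4, kv_head=32, head_dim=128)
gqa_cache = alloc_kv_cache(cache_len=2048, bs_beam=4, kv_head=8, head_dim=128)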

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Attention.py

Lines changed: 8 additions & 8 deletions
@@ -31,28 +31,28 @@ def pre_qkv(self, hidden_states, key_value_states, layer_past, **kwargs):
 
     def prepare_cache_for_greedy_search(self, hidden_states, layer_past):
         bs_beam, seq_len, _ = self.get_runtime_shape(hidden_states)
-        self.prepare_kv_cache(hidden_states)
+        self.prepare_kv_cache(hidden_states, self.num_attn_head)
 
         self.prev_seq_len = layer_past[0].size(2) if layer_past is not None else 0
         self.seq_len = self.prev_seq_len + 1 if self.prev_seq_len != 0 else seq_len
 
     def prepare_cache_for_beam_search(self, hidden_states, layer_past):
-        self.prepare_kv_prompt(hidden_states)
-        self.prepare_kv_cache(hidden_states)
+        self.prepare_kv_prompt(hidden_states, self.num_attn_head)
+        self.prepare_kv_cache(hidden_states, self.num_attn_head)
         if self.is_1st_token_beam_search():
             self.prev_seq_len = 0
             self.seq_len = 0
         else:
             self.seq_len = self.prev_seq_len + 1
 
-    def prepare_kv_prompt(self, hidden_states):
+    def prepare_kv_prompt(self, hidden_states, kv_head):
         bs_beam, seq_len, embed_dim = self.get_runtime_shape(hidden_states)
         if (
             self.runtime_cache.key_prompt is None
             or self.runtime_cache.value_prompt is None
             or IPEXTransformerAttn.timestamp == 0
         ):
-            out_shape = [bs_beam, seq_len, self.head_dim * self.num_attn_head]
+            out_shape = [bs_beam, seq_len, self.head_dim * kv_head]
             self.runtime_cache.key_prompt = torch.empty(
                 out_shape, device=hidden_states.device, dtype=hidden_states.dtype
             )
@@ -67,7 +67,7 @@ def ready_for_runtime_cache_update(self):
         response_cache_len = self.runtime_cache.key_cache.size(0) - prompt_len
         return response_cache_len == IPEXTransformerAttn.timestamp
 
-    def prepare_kv_cache(self, hidden_states):
+    def prepare_kv_cache(self, hidden_states, kv_head):
         bs_beam, seq_len, embed_dim = self.get_runtime_shape(hidden_states)
         batch_size = bs_beam // self.beam_size
         if (
@@ -81,7 +81,7 @@ def prepare_kv_cache(self, hidden_states):
                 if self.is_beam_search()
                 else self.runtime_cache_size + seq_len
             )
-            cache_shape = [cache_len, bs_beam, self.num_attn_head, self.head_dim]
+            cache_shape = [cache_len, bs_beam, kv_head, self.head_dim]
             self.runtime_cache.key_cache = torch.empty(
                 cache_shape, device=hidden_states.device, dtype=hidden_states.dtype
             )
@@ -93,7 +93,7 @@ def prepare_kv_cache(self, hidden_states):
         elif self.ready_for_runtime_cache_update():
             old_cache_len = self.runtime_cache.key_cache.size(0)
             cache_len = old_cache_len + self.runtime_cache_size
-            cache_shape = [cache_len, bs_beam, self.num_attn_head, self.head_dim]
+            cache_shape = [cache_len, bs_beam, kv_head, self.head_dim]
             key_cache_tmp = torch.empty(
                 cache_shape, device=hidden_states.device, dtype=hidden_states.dtype
             )

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/GroupedAttention.py

Lines changed: 4 additions & 48 deletions
@@ -2,7 +2,6 @@
 
 from .._transformer_configuration import IPEXTransformerConfig
 from .Attention import IPEXTransformerAttnOptimizedFp16
-from .BaseAttention import IPEXTransformerAttn
 
 
 class IPEXTransformerAttnOptimizedFp16Grouped(IPEXTransformerAttnOptimizedFp16):
@@ -19,54 +18,11 @@ def cat_qkv(self):
         else:
             pass
 
-    def prepare_kv_prompt(self, hidden_states):
-        if self.num_kv_group <= 1:
-            return super().prepare_kv_prompt(hidden_states)
-        bs_beam, seq_len, embed_dim = self.get_runtime_shape(hidden_states)
-        if (
-            self.runtime_cache.key_prompt is None
-            or self.runtime_cache.value_prompt is None
-            or IPEXTransformerAttn.timestamp == 0
-        ):
-            out_shape = [bs_beam, seq_len, self.head_dim * self.num_kv_head]
-            self.runtime_cache.key_prompt = torch.empty(
-                out_shape, device=hidden_states.device, dtype=hidden_states.dtype
-            )
-            self.runtime_cache.value_prompt = torch.empty(
-                out_shape, device=hidden_states.device, dtype=hidden_states.dtype
-            )
+    def prepare_kv_prompt(self, hidden_states, kv_head):
+        return super().prepare_kv_prompt(hidden_states, self.num_kv_head)
 
-    def prepare_kv_cache(self, hidden_states):
-        if self.num_kv_group <= 1:
-            return super().prepare_kv_cache(hidden_states)
-        bs_beam, seq_len, embed_dim = self.get_runtime_shape(hidden_states)
-        batch_size = bs_beam // self.beam_size
-        if (
-            self.runtime_cache.key_cache is None
-            or self.runtime_cache.value_cache is None
-            or batch_size != self.batch_size
-        ):
-            if self.is_beam_search():
-                cache_shape = [
-                    self.max_out_position,
-                    bs_beam,
-                    self.num_kv_head,
-                    self.head_dim,
-                ]
-            else:
-                cache_shape = [
-                    self.max_position,
-                    bs_beam,
-                    self.num_kv_head,
-                    self.head_dim,
-                ]
-            self.runtime_cache.key_cache = torch.empty(
-                cache_shape, device=hidden_states.device, dtype=hidden_states.dtype
-            )
-            self.runtime_cache.value_cache = torch.empty(
-                cache_shape, device=hidden_states.device, dtype=hidden_states.dtype
-            )
-            self.batch_size = batch_size
+    def prepare_kv_cache(self, hidden_states, kv_head):
+        return super().prepare_kv_cache(hidden_states, self.num_kv_head)
 
     def prepare_qkv_input_1st_token_beam_search(self, hidden_states, **kwargs):
         if self.num_kv_group <= 1:
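The net effect of the second file's diff is an override-and-delegate pattern: the base class owns the allocation logic and takes the KV-head count as an argument, while the grouped variant simply forwards its own `num_kv_head`. A stripped-down sketch of that pattern, with simplified stand-in classes rather than the actual IPEX ones (the allocation body is reduced to a single buffer for brevity):

import torch

class BaseAttention:
    def __init__(self, num_attn_head, num_kv_head, head_dim):
        self.num_attn_head = num_attn_head
        self.num_kv_head = num_kv_head
        self.head_dim = head_dim

    def prepare_kv_cache(self, hidden_states, kv_head):
        # The base class no longer hard-codes self.num_attn_head; the caller
        # decides how many KV heads the cache needs.
        bs_beam, seq_len, _ = hidden_states.shape
        shape = [seq_len, bs_beam, kv_head, self.head_dim]
        return torch.empty(shape, device=hidden_states.device,
                           dtype=hidden_states.dtype)

class GroupedAttention(BaseAttention):
    def prepare_kv_cache(self, hidden_states, kv_head):
        # GQA forwards its own KV-head count instead of duplicating the whole
        # allocation routine, mirroring the deletion in the diff above.
        return super().prepare_kv_cache(hidden_states, self.num_kv_head)

# Usage: the grouped cache is allocated with 8 KV heads even though the model
# has 32 query heads (sizes here are illustrative only).
hidden = torch.zeros(4, 16, 4096, dtype=torch.float16)
attn = GroupedAttention(num_attn_head=32, num_kv_head=8, head_dim=128)
cache = attn.prepare_kv_cache(hidden, attn.num_attn_head)  # shape [16, 4, 8, 128]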
