
Commit b218809

fix beam search perf regression (#4952) (#4959)
1 parent: 6cf810e

File tree

2 files changed: 14 additions & 14 deletions


intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/CacheUtils.py

Lines changed: 14 additions & 7 deletions
@@ -166,19 +166,26 @@ def get_kv_slice_for_decoding(
         layer_idx: int,
         key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # return the key and value cache for decoding in shape of BNFH
+        # return the key and value cache for decoding
+        prompt_len = (
+            0 if len(self.key_prompt) == 0 else self.key_prompt[layer_idx].size(2)
+        )
         seqlen = self.update_or_get_seq_cnt(layer_idx) + key.size(2)
         if self.cache_format == CacheFormat.FBNH:
-            key = self.key_cache[layer_idx][:seqlen, :, :, :].permute(1, 2, 0, 3)
-            value = self.value_cache[layer_idx][:seqlen, :, :, :].permute(
+            key = self.key_cache[layer_idx][prompt_len:seqlen, :, :, :].permute(
+                1, 2, 0, 3
+            )
+            value = self.value_cache[layer_idx][prompt_len:seqlen, :, :, :].permute(
                 1, 2, 0, 3
             )
         elif self.cache_format == CacheFormat.BNFH:
-            key = self.key_cache[layer_idx][:, :, :seqlen, :]
-            value = self.value_cache[layer_idx][:, :, :seqlen, :]
+            key = self.key_cache[layer_idx][:, :, prompt_len:seqlen, :]
+            value = self.value_cache[layer_idx][:, :, prompt_len:seqlen, :]
         elif self.cache_format == CacheFormat.BFNH:
-            key = self.key_cache[layer_idx][:, :seqlen, :, :].permute(0, 2, 1, 3)
-            value = self.value_cache[layer_idx][:, :seqlen, :, :].permute(
+            key = self.key_cache[layer_idx][:, prompt_len:seqlen, :, :].permute(
+                0, 2, 1, 3
+            )
+            value = self.value_cache[layer_idx][:, prompt_len:seqlen, :, :].permute(
                 0, 2, 1, 3
             )
         return key, value
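
In short, get_kv_slice_for_decoding now starts the slice at the prompt length, so it returns only the tokens produced during decoding. Below is a minimal, self-contained sketch of that slicing idea (dummy tensors and sizes, not code from the repository), assuming the BNFH layout [batch, num_heads, seq, head_dim]:

# Illustrative only; prompt_len, decoded and max_len are made-up sizes.
import torch

batch, num_heads, head_dim = 1, 2, 4
prompt_len, decoded = 5, 3            # 5 prompt tokens, 3 tokens decoded so far
max_len = 16                          # preallocated cache length

# Hypothetical preallocated per-layer cache; positions beyond seqlen are unused.
key_cache = torch.randn(batch, num_heads, max_len, head_dim)

seqlen = prompt_len + decoded

# Old behaviour: slice from position 0, so the prompt portion is included again.
key_with_prompt = key_cache[:, :, :seqlen, :]             # seq dim = 8

# New behaviour: start at prompt_len, returning only the decoded tokens.
key_decode_only = key_cache[:, :, prompt_len:seqlen, :]   # seq dim = 3

print(key_with_prompt.shape, key_decode_only.shape)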

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/XPUAttentionfp16.py

Lines changed: 0 additions & 7 deletions
@@ -239,12 +239,7 @@ def sdp(self, query, key, value, past_key_value, attention_mask, head_mask, alib
         key_prompt, value_prompt = past_key_value.get_prompt_for_beam_search(
             self.layer_idx
         )
-        prompt_length = key_prompt.size(2)
         curr_len = key.size(2)
-        # TODO: remove this after ifmha support combined kv cache with both prompt
-        # and decode in [bs, curr_len, num_head, head_dim] layout
-        key = key[:, :, prompt_length:, :]
-        value = value[:, :, prompt_length:, :]
         # TODO: remove this after ifmha support [bs, curr_len, num_head, head_dim] layout
         if (
             isinstance(past_key_value, IPEXStaticCache)
@@ -260,7 +255,6 @@ def sdp(self, query, key, value, past_key_value, attention_mask, head_mask, alib
                     0,
                 )
             )
-
         attention_output = torch.xpu.IpexSDP_Index(
             query,
             key_prompt,
@@ -401,7 +395,6 @@ def forward(
         value = value.view(
             [value.shape[0], value.shape[1], self.num_kv_heads, self.head_dim]
         )
-
         # apply rope to qk
         query, key, value = self.rotary_embedding(
             query, key, value, past_key_value, position_ids, self.layer_idx, curr_len
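
For context, a minimal sketch of how the work is divided after this change (illustrative stand-in tensors and sizes, not the extension's API): the prompt-phase KV is fetched once via get_prompt_for_beam_search and handed to the kernel separately, while the per-step KV returned by the cache already excludes the prompt, so the per-step prompt_length slicing removed above is no longer needed.

# Illustrative only; shapes follow the [bs, num_head, seq, head_dim] layout used above.
import torch

prompt_len, decoded, num_heads, head_dim = 5, 3, 2, 4

# Conceptually what get_prompt_for_beam_search returns: the cached prompt KV.
key_prompt = torch.randn(1, num_heads, prompt_len, head_dim)

# Conceptually what get_kv_slice_for_decoding now returns: decode-phase KV only.
key = torch.randn(1, num_heads, decoded, head_dim)

curr_len = key.size(2)        # already just the decoded tokens
assert curr_len == decoded    # no further key[:, :, prompt_len:, :] slicing needed

# The attention kernel then consumes the prompt and decode KV as separate inputs,
# as in the torch.xpu.IpexSDP_Index call shown in the diff above.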

0 commit comments
