Skip to content

Commit b3ec2e8

Browse files
authored
[Arc][LLM] fix qwen greedy search on Arc (#4166)
* fix qwen on Arc
1 parent 4647b75 commit b3ec2e8

File tree

1 file changed

+4
-1
lines changed
  • intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules

1 file changed

+4
-1
lines changed

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/model_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ def qwen_sdp(self, query, key, value, attention_mask, head_mask, alibi):
133133
attention_mask.logical_not(), torch.finfo(query.dtype).min
134134
)
135135
if not ipex._C._has_2d_block_array(0):
136-
return self.naive_sdp(query, key, value, attention_mask, head_mask, alibi)
136+
attn_output, attn_weight = self.naive_sdp(query, key, value, attention_mask, head_mask, alibi)
137+
if not self.is_beam_search():
138+
attn_output = attn_output.permute(1, 0, 2)
139+
return attn_output, attn_weight
137140
key, value, key_prompt, value_prompt = self.sdp_kv_preprocess(key, value)
138141
(
139142
dropout,

0 commit comments

Comments
 (0)