
Commit af7009e

[BAICHUAN][INF] fix attention_mask shape alignment issue (#4659) (#4673)
* fix attention_mask shape alignment issue
* fix flake8

1 parent c6a6e89

File tree

  • intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/baichuan.py

1 file changed: 13 additions, 0 deletions

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/baichuan.py

Lines changed: 13 additions & 0 deletions
@@ -168,6 +168,19 @@ def forward(
             print("Unsupported input shape")
             return
 
+        # broadcast attention mask if needed
+        if attention_mask.dim() < 4:
+            attention_mask = (
+                attention_mask.unsqueeze(0)
+                .expand(
+                    bs * beam,
+                    attention_mask.shape[0],
+                    attention_mask.shape[1],
+                    attention_mask.shape[2],
+                )
+                .contiguous()
+            )
+
         IPEXTransformerAttn.beam_size = beam
         first_token = True if past_key_value is None else False
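For context, a minimal standalone sketch of the broadcast this commit adds: a mask with fewer than 4 dimensions is expanded to the 4-D shape (bs * beam, 1, query_len, key_len) that the attention kernel expects. The values of bs, beam, and the mask shape below are hypothetical example inputs, not taken from the commit.

    import torch

    bs, beam = 2, 4
    # a 3-D mask as it may arrive from the model, e.g. (1, query_len, key_len)
    attention_mask = torch.zeros(1, 8, 8)

    # broadcast the mask to 4-D, mirroring the added code
    if attention_mask.dim() < 4:
        attention_mask = (
            attention_mask.unsqueeze(0)          # (1, 1, 8, 8)
            .expand(
                bs * beam,                       # repeat across batch * beam
                attention_mask.shape[0],
                attention_mask.shape[1],
                attention_mask.shape[2],
            )
            .contiguous()
        )

    print(attention_mask.shape)  # torch.Size([8, 1, 8, 8])

Note that .expand() alone returns a zero-stride view; the trailing .contiguous() materializes the broadcast copy so that downstream kernels that assume dense memory can consume it.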
