Skip to content

Commit 4b02477

Browse files
authored
[XPU] set appropriate mask value for xpu (#9495)
1 parent 0b4b810 commit 4b02477

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddlenlp/transformers/llama/modeling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1546,8 +1546,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
15461546
expanded_attn_mask = expanded_attn_mask.astype("float32")
15471547
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
15481548
elif get_env_device() in ["xpu", "gcu"]:
1549+
min_val = paddle.finfo(dtype).min if get_env_device() == "gcu" else -1e37 # mask value for xpu
15491550
x = paddle.to_tensor(0.0, dtype=dtype)
1550-
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
1551+
y = paddle.to_tensor(min_val, dtype=dtype)
15511552
expanded_attn_mask = expanded_attn_mask.astype(dtype)
15521553
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
15531554
else:

0 commit comments

Comments
 (0)