Skip to content

Commit 2bdbf2d

Browse files
committed
mask to bool for bloom fwd
1 parent c6db638 commit 2bdbf2d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/petals/models/bloom/block.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ def forward(
2727
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
2828
if alibi is None:
2929
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
30-
fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
30+
31+
# _prepare_4d only needs inputs_embeds.dtype. And it is changed to bool before .forward() anyways
32+
fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
33+
3134
attention_mask = _prepare_4d_causal_attention_mask(
3235
attention_mask=attention_mask,
3336
input_shape=(batch_size, seq_length),
3437
inputs_embeds=fake_inputs_embeds,
3538
past_key_values_length=past_length,
3639
)
40+
attention_mask = attention_mask.bool() # consistent with https://github.com/huggingface/transformers/pull/27086
3741
return super().forward(
3842
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
3943
)

0 commit comments

Comments (0)