
Commit f4e5fa3

replace call to _prepare_4d_causal_attention_mask
1 parent: d59c15c

2 files changed: +4, -9 lines

src/petals/__init__.py

Lines changed: 1 addition & 7 deletions
@@ -17,13 +17,7 @@
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.3.0.dev1"
-
-
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+__version__ = "2.3.0.dev2"
 
 
 def _override_bfloat16_mode_default():

src/petals/models/llama/block.py

Lines changed: 3 additions & 2 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -244,8 +245,8 @@ def forward(
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-            attention_mask = LlamaModel._prepare_decoder_attention_mask(
-                None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
             )
 
         outputs = super().forward(
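For context, the imported helper builds the 4D additive causal mask that the now-removed LlamaModel._prepare_decoder_attention_mask used to produce. The snippet below is a minimal sketch, not part of the commit; it assumes transformers >= 4.35 (where _prepare_4d_causal_attention_mask lives in transformers.modeling_attn_mask_utils) and only illustrates the call's inputs and output shape, with toy sizes standing in for the real batch and sequence dimensions.

# Minimal sketch (not from the commit): what the new helper takes and returns.
# Assumes transformers >= 4.35; toy sizes replace the real batch/sequence dims.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_key_values_length = 2, 4, 3
seq_length_with_past = seq_length + past_key_values_length

# 2D padding mask: True where tokens may be attended to (as in the block above).
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool)
# Stand-in for hidden_states; only its dtype is used to build the float mask.
hidden_states = torch.zeros(batch_size, seq_length, 8)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
# Shape (batch_size, 1, seq_length, seq_length_with_past): 0 where attention is
# allowed, the dtype's most negative value where it is masked out.
print(mask_4d.shape)  # torch.Size([2, 1, 4, 7])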
