diff --git a/setup.cfg b/setup.cfg
index ef35f8455..7f0a77f3b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.35.0  # 4.35.0 is the minimum that contains modeling_attn_mask_utils.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 1af8bf951..bd8861c87 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -17,13 +17,7 @@
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.3.0.dev1"
-
-
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+__version__ = "2.3.0.dev2"
 
 
 def _override_bfloat16_mode_default():
diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 97a115ab2..5324a5183 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -34,6 +34,9 @@ class _SkipTokensMixin:
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
         input_ids = input_ids[:, _skipped_tokens.get() :]
         _skipped_tokens.set(0)
+        if "past_key_values" in kwargs:
+            if kwargs["past_key_values"][0][0].shape == torch.Size([0]):
+                kwargs["past_key_values"] = None
         return super().prepare_inputs_for_generation(input_ids, **kwargs)
 
 
diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index f246bd867..e86b839c0 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,14 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_length,
+        )
+        attention_mask = attention_mask.bool()
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py
index a510abaa1..d36fcb7cc 100644
--- a/src/petals/models/falcon/block.py
+++ b/src/petals/models/falcon/block.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import (
     FalconAttention,
     FalconConfig,
@@ -418,7 +419,13 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None and self.config.alibi:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_length,
+        )
 
         outputs = super().forward(
             hidden_states,
diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index a8d433ded..6f539a841 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -84,8 +85,8 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -244,8 +245,11 @@ def forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
         )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
         )
 
         outputs = super().forward(
diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py
index 84cbfffe2..69b77cc1e 100644
--- a/tests/test_optimized_layers.py
+++ b/tests/test_optimized_layers.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
@@ -131,8 +132,8 @@ def forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
         )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
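
Note (not part of the patch): a minimal standalone sketch of how the transformers>=4.35.0 helper adopted above expands a 2D padding mask into a 4D causal mask. The tensor shapes below are illustrative assumptions, not values from the PR.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative shapes (assumptions, not taken from the PR).
batch_size, seq_length, past_length, hidden_size = 2, 3, 4, 8

# 2D padding mask over cached + current tokens (1 = attend).
attention_mask = torch.ones(batch_size, past_length + seq_length)
# inputs_embeds is used only for its dtype and device.
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_length,
)

# Float mask of shape (batch, 1, seq_length, past_length + seq_length):
# 0.0 where attention is allowed, dtype-min where it is blocked.
print(mask_4d.shape)

# The Bloom hunk above additionally calls .bool() on this result,
# which turns the blocked (nonzero) positions into True.
print(mask_4d.bool().shape)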