
Commit 0d91bbd

Bump transformers and accelerate versions (#554)
Bump the pinned versions of transformers and accelerate, and remove the falcon-rw-1b CI tests.
1 parent d59c15c commit 0d91bbd

File tree: 11 files changed, +127 −29 lines


.github/workflows/run-tests.yaml

Lines changed: 0 additions & 2 deletions
@@ -14,8 +14,6 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -34,10 +34,10 @@ python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.41.1
-    accelerate>=0.22.0
+    accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
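
Note: for anyone upgrading an existing environment, a minimal sketch of how the new pins could be verified before running Petals. This helper is illustrative only and not part of the commit; it assumes the stdlib importlib.metadata plus the packaging dependency, both available with the requirements above.

from importlib.metadata import version as installed_version
from packaging.version import parse

# Illustrative check mirroring the updated pins in setup.cfg (not shipped with Petals).
assert parse(installed_version("accelerate")) >= parse("0.27.2"), "accelerate is older than 0.27.2"
assert parse(installed_version("transformers")) == parse("4.37.1"), "transformers must be exactly 4.37.1"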

src/petals/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -17,13 +17,13 @@
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "2.3.0.dev1"
+__version__ = "2.3.0.dev2"


 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"


 def _override_bfloat16_mode_default():
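
Note: the assert above mirrors the setup.cfg pin, so any transformers release in [4.37.1, 4.38.0) is accepted at import time. A minimal, illustrative sketch of the existing escape hatch; the PETALS_IGNORE_DEPENDENCY_VERSION variable is real, the snippet itself is not part of the commit.

import os

# Setting the variable before importing petals skips the version assert,
# e.g. when testing an unreleased transformers build.
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"
import petals  # the 4.37.x check above is bypassed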

src/petals/client/inference_session.py

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._position = 0
         self._max_length = max_length
         self.output_ids = None
+        self.past_key_values = None

     @property
     def num_blocks(self) -> int:

src/petals/client/remote_generation.py

Lines changed: 26 additions & 5 deletions
@@ -1,11 +1,13 @@
 import contextlib
 import dataclasses
 from contextvars import ContextVar
-from typing import ContextManager, List, Optional
+from typing import Any, ContextManager, Dict, List, Optional, Tuple

 import torch
 import transformers
 from hivemind.utils.logging import get_logger
+from torch import Tensor
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation.utils import ModelOutput

 from petals.client.inference_session import InferenceSession

@@ -15,15 +17,29 @@
 logger = get_logger(__name__)


-@dataclasses.dataclass(frozen=True)
-class RemotePastKeyValues:
-    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
+class RemotePastKeyValues(Cache):
+    """only keeps the number of seen tokens. pretends to be a legit cache"""

-    hypo_ids: Optional[torch.LongTensor] = None
+    def __init__(self) -> None:
+        super().__init__()
+        self.seen_tokens = 0
+        self.hypo_ids: Optional[torch.LongTensor] = None

     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def update_seen(self, new_seen: int) -> None:
+        self.seen_tokens += new_seen
+
+    def reorder_cache(self, beam_idx):
+        pass
+

 _skipped_tokens = ContextVar("skipped_tokens", default=0)

@@ -113,6 +129,11 @@ def generate(
         # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
         _skipped_tokens.set(max(0, n_prev_tokens - 1))

+        if self._supports_cache_class and "past_key_values" not in kwargs:
+            past_key_values = RemotePastKeyValues()
+            past_key_values.update_seen(session.position)
+            kwargs["past_key_values"] = past_key_values
+
         result = super().generate(inputs, *args, **kwargs)

         sequences = result.sequences if isinstance(result, ModelOutput) else result
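
Note: a short sketch of how the RemotePastKeyValues stub added above behaves on its own. The class and its methods are the ones from this diff; the surrounding driver code is illustrative only.

from petals.client.remote_generation import RemotePastKeyValues

cache = RemotePastKeyValues()
cache.update_seen(5)                   # e.g. a 5-token prompt already processed by the servers
cache.update_seen(1)                   # one more generated token
assert cache.get_seq_length() == 6     # used by prepare_inputs_for_generation to trim already-seen tokens
assert cache.get_max_length() is None  # no client-side limit; the real KV cache lives on the servers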

src/petals/models/bloom/block.py

Lines changed: 8 additions & 1 deletion
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple

 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


@@ -26,7 +27,13 @@ def forward(
             attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_length,
+        )
+        attention_mask = attention_mask.bool()
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
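
Note: for reference, a standalone sketch of what the replacement helper returns. The function and keyword arguments are the ones used above (transformers 4.37); the tensor sizes are made up.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, past_length, seq_length, hidden_size = 1, 2, 3, 8
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)   # the helper reads the mask dtype from this
attention_mask = torch.ones(batch_size, past_length + seq_length)

mask = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_length,
)
print(mask.shape)         # torch.Size([1, 1, 3, 5]): (batch, 1, query_length, key_length)
print(mask.bool()[0, 0])  # True marks masked positions, which is what the wrapped BloomBlock expects after .bool()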

src/petals/models/bloom/model.py

Lines changed: 59 additions & 1 deletion
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel

@@ -92,12 +93,16 @@ def forward(
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )

@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+    _supports_cache_class = True

     config_class = DistributedBloomConfig

@@ -118,6 +124,58 @@ def __init__(self, config: DistributedBloomConfig):
         # Initialize weights and apply final processing
         self.post_init()

+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ) -> dict:
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _temporary_reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
+
     def get_output_embeddings(self):
         return self.lm_head

src/petals/models/llama/block.py

Lines changed: 8 additions & 4 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,

@@ -84,8 +85,8 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)

@@ -244,8 +245,11 @@ def forward(
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
         )

         outputs = super().forward(
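
Note: the cos/sin slicing change follows the rotary-embedding layout in transformers 4.37, where LlamaRotaryEmbedding returns 2-D (seq_len, head_dim) tensors instead of the older 4-D (1, 1, seq_len, head_dim) ones. A minimal shape check; the dimensions are made up.

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

rotary_emb = LlamaRotaryEmbedding(dim=64, max_position_embeddings=128)
value_states = torch.zeros(1, 1, 16, 64)   # (batch, heads, kv_seq_len, head_dim); only dtype/device are used
cos, sin = rotary_emb(value_states, seq_len=16)

print(cos.shape)           # torch.Size([16, 64]) in 4.37.x (was [1, 1, 16, 64] in older releases)
print(cos[16 - 1:].shape)  # torch.Size([1, 64]): the kind of slice used above when q_len == 1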

src/petals/models/llama/model.py

Lines changed: 6 additions & 1 deletion
@@ -90,16 +90,21 @@ def forward(
             hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
         )

+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Remove prefix
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]

         # Add last hidden state
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.view(output_shape)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )

src/petals/utils/peft.py

Lines changed: 1 addition & 3 deletions
@@ -26,9 +26,7 @@


 def check_peft_repository(repo_id: str) -> bool:
-    fs = HfFileSystem()
-    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
-    return len(list_of_files) > 0
+    return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")


 def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
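
Note: the simplified check relies on HfFileSystem being an fsspec filesystem, so the standard exists() call replaces the previous glob. A minimal usage sketch; the repo id is a placeholder, and the filename assumes peft's SAFETENSORS_WEIGHTS_NAME resolves to adapter_model.safetensors, as in current peft releases.

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
# True iff the repo exposes a safetensors adapter at its root, i.e. it looks like a PEFT adapter repo.
has_adapter = fs.exists("some-user/some-peft-adapter/adapter_model.safetensors")  # placeholder repo id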
