
Commit 702381e

Upgrade Transformers to v4.46.x (#757)
Changes:
- re-copy changes in T5, MT5
- fix resize embeddings override
- add DistilBERT SDPA / flash attention
- add Mistral QA head conversion
- add HF custom pytest markers

Known issues:
- Electra test failure seems to also be present in HF: https://app.circleci.com/pipelines/github/huggingface/transformers/110747/workflows/60c508ce-1261-46b2-a321-363718877ead/jobs/1473377/tests
1 parent 7550099 commit 702381e

File tree

10 files changed: +539, -268 lines


conftest.py

Lines changed: 4 additions & 0 deletions
@@ -46,7 +46,11 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
     )
+    config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
     config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
+    config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
+    config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule")
+    config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
 
 
 def pytest_addoption(parser):
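For illustration only (not part of the commit), a test could opt into one of the newly registered markers like this; the test name and body are hypothetical:

import pytest


@pytest.mark.accelerate_tests  # registered above; select with `pytest -m accelerate_tests`
def test_accelerate_integration_placeholder():
    # Hypothetical test body; real accelerate-dependent assertions would go here.
    assert True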

hf_transformers

Submodule hf_transformers updated 892 files

pyproject.toml

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,11 @@
 [tool.black]
 line-length = 119
 target-version = ['py38', 'py39', 'py310']
+
+# copied from HF for testing
+[tool.pytest.ini_options]
+markers = [
+    "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
+    "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
+    "generate: marks tests that use the GenerationTesterMixin"
+]
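As a usage note (an assumption, not part of the diff), these markers can be used to deselect whole test groups; the tests/ path below is assumed:

import pytest

# Programmatic equivalent of `pytest -m "not flash_attn_test" tests/`:
# run the suite while skipping tests marked as flash attention tests.
pytest.main(["-m", "not flash_attn_test", "tests/"])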

setup.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
     "timeout-decorator",
     "torch",
     "torchvision",
-    "transformers~=4.45.2",
+    "transformers~=4.46.3",
 ]
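A minimal sketch (not part of the commit) of what the `~=4.46.3` compatible-release pin means, using packaging's PEP 440 semantics:

from importlib.metadata import version
from packaging.specifiers import SpecifierSet

# `~=4.46.3` is equivalent to `>=4.46.3, <4.47`.
assert version("transformers") in SpecifierSet("~=4.46.3")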

src/adapters/head_utils.py

Lines changed: 8 additions & 0 deletions
@@ -705,6 +705,14 @@
         },
         "layers": [None, "score"],
     },
+    "MistralForQuestionAnswering": {
+        "config": {
+            "head_type": "question_answering",
+            "layers": 1,
+            "activation_function": None,
+        },
+        "layers": [None, "qa_outputs"],
+    },
     # Electra
     "ElectraForTokenClassification": {
         "config": {

src/adapters/heads/model_mixin.py

Lines changed: 2 additions & 4 deletions
@@ -139,10 +139,8 @@ def tie_weights(self):
 
         super().tie_weights()
 
-    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
-        old_embeddings = self.get_input_embeddings()
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
-        self.set_input_embeddings(new_embeddings)
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
+        super()._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
 
         # if word embeddings are not tied, make sure that lm head is resized as well
         if not self.config.tie_word_embeddings:
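For context, `mean_resizing` mirrors the argument added to the public embedding-resizing API in Transformers v4.46, where new rows are initialized from the mean of the existing embeddings. A minimal usage sketch (assuming a BERT checkpoint; not part of the diff):

from adapters import AutoAdapterModel
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoAdapterModel.from_pretrained("bert-base-uncased")

tokenizer.add_tokens(["<new_token>"])
# The override above simply forwards mean_resizing to the parent implementation.
model.resize_token_embeddings(len(tokenizer), mean_resizing=True)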

src/adapters/models/distilbert/modeling_distilbert.py

Lines changed: 184 additions & 3 deletions
@@ -25,13 +25,26 @@
 import torch
 from torch import nn
 
-from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock
+from transformers.models.distilbert.modeling_distilbert import (
+    DistilBertFlashAttention2,
+    DistilBertSdpaAttention,
+    MultiHeadSelfAttention,
+    TransformerBlock,
+)
+from transformers.utils import is_flash_attn_2_available, logging
 
 from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
 from ...utils import prefix_attention_mask
 from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin
 
 
+if is_flash_attn_2_available():
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
 class MultiHeadSelfAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, MultiHeadSelfAttention):
     def forward(
         self,
@@ -66,18 +79,20 @@ def shape(x: torch.Tensor) -> torch.Tensor:
 
         def unshape(x: torch.Tensor) -> torch.Tensor:
             """group heads"""
-            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+            return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head)
 
         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
         k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
+        # >>> START AH Changes <<<
         q, k, v = match_attn_matrices_for_parallel(q, k, v)
         (mask,) = adjust_tensors_for_parallel(q, mask)
 
         k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
         bs = k.size(0)  # reset for Parallel block
         (q,) = adjust_tensors_for_parallel(k, q)
+        # >>> END AH Changes <<<
 
         mask_reshp = (bs, 1, 1, k.size(2))
 
@@ -105,6 +120,172 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         return (context,)
 
 
+class DistilBertSdpaAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertSdpaAttention):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Parameters:
+            query: torch.tensor(bs, seq_length, dim)
+            key: torch.tensor(bs, seq_length, dim)
+            value: torch.tensor(bs, seq_length, dim)
+            mask: torch.tensor(bs, seq_length)
+
+        Returns:
+            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
+            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
+        """
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
+                " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
+                " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
+                ' removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                query,
+                key,
+                value,
+                mask,
+                head_mask,
+                output_attentions,
+            )
+
+        batch_size, _, _ = query.size()
+        dim_per_head = self.dim // self.n_heads
+
+        def shape(x: torch.Tensor) -> torch.Tensor:
+            """separate heads"""
+            # keep first dim due to parallel composition
+            return x.view(x.shape[0], -1, self.n_heads, dim_per_head).transpose(1, 2)
+
+        def unshape(x: torch.Tensor) -> torch.Tensor:
+            """group heads"""
+            return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head)
+
+        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
+        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
+        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
+
+        # >>> START AH Changes <<<
+        q, k, v = match_attn_matrices_for_parallel(q, k, v)
+        (mask,) = adjust_tensors_for_parallel(q, mask)
+
+        k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
+        (q,) = adjust_tensors_for_parallel(k, q)
+        # >>> END AH Changes <<<
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
+            q = q.contiguous()
+            k = k.contiguous()
+            v = v.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=False,
+        )
+
+        attn_output = unshape(attn_output)
+        attn_output = self.out_lin(attn_output)
+
+        return (attn_output,)
+
+
+class DistilBertFlashAttention2WithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertFlashAttention2):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Parameters:
+            query: torch.tensor(bs, seq_length, dim)
+            key: torch.tensor(bs, seq_length, dim)
+            value: torch.tensor(bs, seq_length, dim)
+            mask: torch.tensor(bs, seq_length)
+
+        Returns:
+            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
+            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
+        """
+        batch_size, q_length, dim = query.size()
+
+        dim_per_head = self.dim // self.n_heads
+
+        def reshape(x: torch.Tensor) -> torch.Tensor:
+            """separate heads"""
+            return x.view(x.shape[0], -1, self.n_heads, dim_per_head)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        query_states = reshape(self.q_lin(query))
+        key_states = reshape(self.k_lin(key))
+        value_states = reshape(self.v_lin(value))
+
+        attn_dropout = self.config.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_lin.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_weights = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            mask,
+            q_length,
+            dropout=attn_dropout,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
+        attn_output = self.out_lin(attn_weights_reshaped)
+
+        if output_attentions:
+            return (attn_output, attn_weights)
+        else:
+            return (attn_output,)
+
+
 class TransformerBlockWithAdapters(DistilBertTransfomerBlockAdaptersMixin, TransformerBlock):
     def forward(
         self,
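As a standalone illustration of the shape handling the new SDPA variant relies on (plain torch, independent of the adapter mixins; tensor sizes below are arbitrary):

import torch
import torch.nn.functional as F

bs, n_heads, seq_len, dim_per_head = 2, 12, 16, 64

# "separate heads": (bs, seq_len, n_heads * dim_per_head) -> (bs, n_heads, seq_len, dim_per_head)
x = torch.randn(bs, seq_len, n_heads * dim_per_head)
q = k = v = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)

# Additive attention mask broadcastable to (bs, n_heads, seq_len, seq_len)
mask = torch.zeros(bs, 1, 1, seq_len)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
# "group heads": back to (bs, seq_len, n_heads * dim_per_head)
out = out.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)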
@@ -123,7 +304,7 @@ def forward(
             torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
         """
         adjust_tensors_for_parallel_(x, attn_mask)
-        attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1)  # type: ignore
+        attn_mask = prefix_attention_mask(attn_mask, dim=[2, 3], prefix_value=1)  # type: ignore
 
         # Self-Attention
         sa_output = self.attention(
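Taken together, the new DistilBERT attention variants are only exercised when the backbone is loaded with the corresponding attention implementation. A rough sketch of how that might look with the usual adapters init flow (assumed, not taken from the commit):

import adapters
from transformers import AutoModel

# "sdpa" is the Transformers default where supported; "flash_attention_2"
# additionally requires the flash-attn package and a supported GPU.
model = AutoModel.from_pretrained("distilbert-base-uncased", attn_implementation="sdpa")
adapters.init(model)  # should swap in the ...WithAdapters attention classes above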
