
Commit 896833c

Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319)
* fix tests
* better fix for python<3.11
* fixes
* style
1 parent a63bc17 commit 896833c

7 files changed: +48 -83 lines changed

src/transformers/cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -2232,7 +2232,7 @@ def _prefetch_next_layer(self, layer_idx: int) -> None:

     def _prefetch_layer_in_context(self, layer_idx: int) -> None:
         """Performs the actual copy of the layer to device cache."""
-        if len(self.key_cache) >= layer_idx:
+        if len(self.key_cache) > layer_idx:
             self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True)
             self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True)
         # The layer was not yet initialized
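
The change above tightens the prefetch bounds check: with `>=`, a `layer_idx` equal to `len(self.key_cache)` would pass the guard and then index past the end of the cache. A minimal sketch of the off-by-one, with plain lists standing in for the real cache tensors:

    key_cache = ["layer0", "layer1"]  # len(key_cache) == 2
    layer_idx = 2                     # this layer has not been initialized yet

    # Old guard: 2 >= 2 is True, so key_cache[2] would raise IndexError.
    # New guard: 2 > 2 is False, so we take the "not yet initialized" path instead.
    if len(key_cache) > layer_idx:
        print(key_cache[layer_idx])
    else:
        print("layer not yet initialized, skip the copy")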

src/transformers/integrations/executorch.py

Lines changed: 34 additions & 58 deletions
@@ -11,7 +11,6 @@
 # specific language governing permissions and limitations under the License.

 import logging
-from contextlib import contextmanager
 from typing import Callable, Optional

 import torch

@@ -110,14 +109,13 @@ def export(
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

-        with patch_mask_interface():
-            exported_program = torch.export.export(
-                self.model,
-                args=(example_input_ids, example_cache_position),
-                kwargs={},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+        exported_program = torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
         return exported_program

     @staticmethod

@@ -456,24 +454,6 @@ def forward(
         return outputs.logits


-@contextmanager
-def patch_mask_interface():
-    """
-    Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail
-    with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles:
-    Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`.
-    Note that this seem to be an issue only for python<3.11.
-    """
-    import transformers
-
-    original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-    transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
-    try:
-        yield
-    finally:
-        transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original
-
-
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,

@@ -515,14 +495,13 @@ def convert_and_export_with_cache(
         )

     if is_torch_greater_or_equal("2.6.0"):
-        with patch_mask_interface():
-            exported_program = torch.export.export(
-                TorchExportableModuleWithStaticCache(model),
-                args=(example_input_ids, example_cache_position),
-                kwargs={},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+        exported_program = torch.export.export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
     else:
         if dynamic_shapes is not None:
             logging.warning(

@@ -534,14 +513,13 @@ def convert_and_export_with_cache(
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
         # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
-        with patch_mask_interface():
-            exported_program = torch.export._trace._export(
-                TorchExportableModuleWithStaticCache(model),
-                args=(example_input_ids,),
-                kwargs={"cache_position": example_cache_position},
-                pre_dispatch=False,
-                strict=True,
-            )
+        exported_program = torch.export._trace._export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )
     return exported_program


@@ -634,10 +612,9 @@ def _export_encoder(self, encoder_input_ids):

         # Export the encoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_encoder = torch.export.export(
-                    wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
-                )
+            exported_encoder = torch.export.export(
+                wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
+            )

         return exported_encoder

@@ -657,17 +634,16 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi

         # Export the decoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_decoder = torch.export.export(
-                    wrapped_decoder,
-                    (decoder_input_ids, encoder_hidden_states, cache_position),
-                    dynamic_shapes={
-                        "decoder_input_ids": None,
-                        "encoder_hidden_states": {1: encoder_seq_len_dim},
-                        "cache_position": None,
-                    },
-                    strict=True,
-                )
+            exported_decoder = torch.export.export(
+                wrapped_decoder,
+                (decoder_input_ids, encoder_hidden_states, cache_position),
+                dynamic_shapes={
+                    "decoder_input_ids": None,
+                    "encoder_hidden_states": {1: encoder_seq_len_dim},
+                    "cache_position": None,
+                },
+                strict=True,
+            )

         return exported_decoder
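
With the Python<3.11 workaround moved into `masking_utils.py` (next file), the `patch_mask_interface` context manager is deleted and every export site calls `torch.export.export` directly. A rough sketch of that call pattern on a toy module (not a real transformers model; only illustrative), assuming a recent torch with the public export API:

    import torch

    class ToyModule(torch.nn.Module):
        # Stand-in for an exportable wrapper: takes input ids and a cache position.
        def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
            return input_ids.float().sum(dim=-1) + cache_position.float()

    example_input_ids = torch.tensor([[1]], dtype=torch.long)
    example_cache_position = torch.tensor([0], dtype=torch.long)

    exported_program = torch.export.export(
        ToyModule(),
        args=(example_input_ids, example_cache_position),
        kwargs={},
        dynamic_shapes=None,  # the real code passes per-input dynamic shape specs here
        strict=True,
    )
    print(exported_program)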

src/transformers/masking_utils.py

Lines changed: 5 additions & 1 deletion
@@ -623,7 +623,11 @@ def _preprocess_mask_arguments(
         return True, attention_mask, None, None

     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
-    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS:
+    # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
+    # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
+    # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
+    # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
+    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
         return True, None, None, None

     # Move the mask to correct device, and potentially switch dtype for efficiency
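
This is the actual fix for the dynamo failure described in the deleted `patch_mask_interface` docstring: on Python<3.11, `key in mapping_object` on a `Mapping` subclass dispatches to `Mapping.__contains__`, a stdlib frame dynamo skips during full-graph tracing, while membership on the plain backing dict traces fine. A toy illustration, assuming a dict-backed registry similar in spirit to `AttentionMaskInterface` (the names below are made up):

    from collections.abc import Mapping

    class ToyMaskInterface(Mapping):
        """Toy registry wrapping a plain dict, like the real interface wraps `_global_mapping`."""

        def __init__(self, **functions):
            self._global_mapping = dict(functions)

        def __getitem__(self, key):
            return self._global_mapping[key]

        def __iter__(self):
            return iter(self._global_mapping)

        def __len__(self):
            return len(self._global_mapping)

    REGISTRY = ToyMaskInterface(sdpa=lambda: "sdpa mask", eager=lambda: "eager mask")

    # `"sdpa" in REGISTRY` goes through Mapping.__contains__ and trips dynamo's
    # skipfile rules on Python<3.11; checking the raw dict avoids that frame entirely.
    print("sdpa" in REGISTRY._global_mapping)   # True
    print("flash" in REGISTRY._global_mapping)  # False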

tests/models/cohere/test_modeling_cohere.py

Lines changed: 3 additions & 3 deletions
@@ -232,8 +232,8 @@ def test_batched_small_model_logits(self):

         EXPECTED_LOGITS = torch.Tensor(
             [
-                [[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
-                [[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
+                [[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
+                [[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
             ]
         ).to(device=torch_device, dtype=torch.float16)

@@ -251,4 +251,4 @@ def test_batched_small_model_logits(self):
         output = model(**inputs)

         logits = output.logits
-        torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3)
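
The batched Cohere test now compares the last three positions (`logits[:, -3:, :3]`) rather than the first three, presumably because in a padded batch the leading positions depend on how much padding each sequence received, while the final positions always correspond to real tokens. A small shape-only sketch of the two slices (random tensor, purely illustrative):

    import torch

    logits = torch.randn(2, 7, 5)        # (batch, seq_len, vocab) stand-in
    first_three = logits[:, :3, :3]      # old slice: may land on padding positions
    last_three = logits[:, -3:, :3]      # new slice: always the final (real) tokens
    print(first_three.shape, last_three.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3, 3])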

tests/models/csm/test_modeling_csm.py

Lines changed: 0 additions & 1 deletion
@@ -150,7 +150,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
     test_headmasking = False
     test_resize_embeddings = False
     test_resize_embeddings_untied = False
-    test_torch_exportable = True

     def setUp(self):
         self.model_tester = CsmModelTester(self)

tests/models/mixtral/test_modeling_mixtral.py

Lines changed: 4 additions & 19 deletions
@@ -402,24 +402,12 @@ def test_small_model_logits_batched(self):
         #
         # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_LOGITS_LEFT = {
-            7: torch.Tensor(
-                [[0.1904, 0.0500, 0.7187], [0.1933, 0.0515, 0.7187], [0.2001, 0.0559, 0.7148]],
-            ).to(torch_device),
-            8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to(
-                torch_device
-            ),
-            9: torch.Tensor([[0.1904, 0.0513, 0.7227], [0.1943, 0.0518, 0.7227], [0.1982, 0.0557, 0.7148]]).to(
-                torch_device
-            ),
-        }
-
         EXPECTED_LOGITS_LEFT_UNPADDED = {
             7: torch.Tensor(
                 [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]],
             ).to(torch_device),
-            8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to(
-                torch_device
+            8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(
+                torch_device,
             ),
             9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(
                 torch_device

@@ -430,8 +418,8 @@ def test_small_model_logits_batched(self):
             7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(
                 torch_device
             ),
-            8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(
-                torch_device
+            8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(
+                torch_device,
             ),
             9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(
                 torch_device

@@ -442,9 +430,6 @@ def test_small_model_logits_batched(self):
         logits = model(dummy_input, attention_mask=attention_mask).logits
         logits = logits.float()

-        torch.testing.assert_close(
-            logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
-        )
         torch.testing.assert_close(
             logits[0, -3:, -3:],
             EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version],

tests/test_modeling_common.py

Lines changed: 1 addition & 0 deletions
@@ -4461,6 +4461,7 @@ def test_torch_compile_for_training(self):
         del loss

         model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+
         # forward compilation
         set_seed(42)
         loss = model(**inputs).loss
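
The only change in this file is a blank line after the `torch.compile` call, but the surrounding test is the scenario the commit title refers to: compiling a model with `fullgraph=True` and running training steps. A minimal, self-contained sketch of that pattern (a tiny linear model, not the test's real setup):

    import torch

    model = torch.nn.Linear(4, 4)
    inputs = torch.randn(2, 4)
    targets = torch.randn(2, 4)

    compiled = torch.compile(model, fullgraph=True, mode="reduce-overhead")

    # forward compiles on the first call; backward compiles on the first .backward()
    loss = torch.nn.functional.mse_loss(compiled(inputs), targets)
    loss.backward()
    print(float(loss))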
