diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 3a74a1557cf9..d4c51cea1257 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -30,6 +30,7 @@ import torch from transformers import ( + AutoTokenizer, DistilBertForMaskedLM, DistilBertForMultipleChoice, DistilBertForQuestionAnswering, @@ -38,6 +39,7 @@ DistilBertModel, ) from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 class DistilBertModelTester: @@ -420,3 +422,45 @@ def test_inference_no_head_absolute_embedding(self): ) self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4)) + + @slow + def test_export(self): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + distilbert_model = "distilbert-base-uncased" + device = "cpu" + attn_implementation = "sdpa" + max_length = 64 + + tokenizer = AutoTokenizer.from_pretrained(distilbert_model) + inputs = tokenizer( + f"Paris is the {tokenizer.mask_token} of France.", + return_tensors="pt", + padding="max_length", + max_length=max_length, + ) + + model = DistilBertForMaskedLM.from_pretrained( + distilbert_model, + device_map=device, + attn_implementation=attn_implementation, + ) + + logits = model(**inputs).logits + eager_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices) + self.assertEqual( + eager_predicted_mask.split(), + ["capital", "birthplace", "northernmost", "centre", "southernmost"], + ) + + exported_program = torch.export.export( + model, + args=(inputs["input_ids"],), + kwargs={"attention_mask": inputs["attention_mask"]}, + strict=True, + ) + + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) + exported_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices) + self.assertEqual(eager_predicted_mask, exported_predicted_mask)