diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 5c87fbea8ee7..666df0848819 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -16,6 +16,8 @@ import tempfile import unittest +from packaging import version + from transformers import AutoTokenizer, BertConfig, is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import ( @@ -823,3 +825,43 @@ def test_sdpa_ignored_mask(self): self.assertTrue( torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4) ) + + @slow + def test_export(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + bert_model = "google-bert/bert-base-uncased" + device = "cpu" + attn_implementation = "sdpa" + max_length = 512 + + tokenizer = AutoTokenizer.from_pretrained(bert_model) + inputs = tokenizer( + "the man worked as a [MASK].", + return_tensors="pt", + padding="max_length", + max_length=max_length, + ) + + model = BertForMaskedLM.from_pretrained( + bert_model, + device_map=device, + attn_implementation=attn_implementation, + use_cache=True, + ) + + logits = model(**inputs).logits + eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices) + self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "barber", "mechanic", "salesman"]) + + exported_program = torch.export.export( + model, + args=(inputs["input_ids"],), + kwargs={"attention_mask": inputs["attention_mask"]}, + strict=True, + ) + + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) + ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices) + self.assertEqual(eg_predicted_mask, ep_predicted_mask)