 import tempfile
 import unittest

+from packaging import version
+
 from transformers import AutoTokenizer, BertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
@@ -749,3 +751,43 @@ def test_sdpa_ignored_mask(self):
         self.assertTrue(
             torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
         )
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        bert_model = "google-bert/bert-base-uncased"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(bert_model)
+        inputs = tokenizer(
+            "the man worked as a [MASK].",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = BertForMaskedLM.from_pretrained(
+            bert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+            use_cache=True,
+        )
+
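+        # eager forward pass; token index 6 is the [MASK] position in the tokenized prompt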
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "barber", "mechanic", "salesman"])
+
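+        # export the model with torch.export and check the exported program reproduces the eager predictions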
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)