@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import RobertaConfig, is_torch_available
+from transformers import AutoTokenizer, RobertaConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -41,6 +41,7 @@
         RobertaEmbeddings,
         create_position_ids_from_input_ids,
     )
+    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 
 ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
 
@@ -576,3 +577,43 @@ def test_inference_classification_head(self):
         # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
 
         self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if not is_torch_greater_or_equal_than_2_4:
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        roberta_model = "FacebookAI/roberta-base"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(roberta_model)
+        inputs = tokenizer(
+            "The goal of life is <mask>.",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = RobertaForMaskedLM.from_pretrained(
+            roberta_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+            use_cache=True,
+        )
+
+        logits = model(**inputs).logits
+        eager_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eager_predicted_mask.split(), ["happiness", "love", "peace", "freedom", "simplicity"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        exported_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eager_predicted_mask, exported_predicted_mask)
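
For context, a minimal sketch (not part of this diff) of the same torch.export flow run outside the test, round-tripping the ExportedProgram through torch.export.save / torch.export.load. It assumes torch >= 2.4, the FacebookAI/roberta-base checkpoint used above, and an illustrative output path roberta_mlm.pt2:

# Minimal sketch, not part of the diff: export roberta-base for masked LM and
# serialize/reload the ExportedProgram with torch.export.save / torch.export.load.
# Assumes torch >= 2.4; the path "roberta_mlm.pt2" is illustrative only.
import torch
from transformers import AutoTokenizer, RobertaForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
model = RobertaForMaskedLM.from_pretrained(
    "FacebookAI/roberta-base", attn_implementation="sdpa", use_cache=True
)
inputs = tokenizer(
    "The goal of life is <mask>.",
    return_tensors="pt",
    padding="max_length",
    max_length=512,
)

# Export with the same fixed-shape args/kwargs as the test above.
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)
torch.export.save(exported, "roberta_mlm.pt2")

# Reload and run the exported graph as a regular module; the top predictions
# for the <mask> position (index 6 here) should match the eager model.
reloaded = torch.export.load("roberta_mlm.pt2")
logits = reloaded.module()(inputs["input_ids"], inputs["attention_mask"]).logits
print(tokenizer.decode(logits[0, 6].topk(5).indices))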