Skip to content

Commit cd27761

Browse files
guangy10 (Guang Yang) and authored
Roberta is ExecuTorch compatible (#34425)
* Roberta is ExecuTorch compatible * [run_slow] roberta --------- Co-authored-by: Guang Yang <[email protected]>
1 parent 9bee9ff commit cd27761

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

tests/models/roberta/test_modeling_roberta.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import unittest
1818

19-
from transformers import RobertaConfig, is_torch_available
19+
from transformers import AutoTokenizer, RobertaConfig, is_torch_available
2020
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
2121

2222
from ...generation.test_utils import GenerationTesterMixin
@@ -41,6 +41,7 @@
4141
RobertaEmbeddings,
4242
create_position_ids_from_input_ids,
4343
)
44+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
4445

4546
ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
4647

@@ -576,3 +577,43 @@ def test_inference_classification_head(self):
576577
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
577578

578579
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
580+
581+
@slow
def test_export(self):
    """Check that RobertaForMaskedLM is torch.export-compatible (ExecuTorch path).

    Runs eager inference, verifies the top-5 mask-fill predictions, then exports
    the model with ``torch.export.export`` and asserts the exported program
    produces the same predictions as eager mode.
    """
    if not is_torch_greater_or_equal_than_2_4:
        self.skipTest(reason="This test requires torch >= 2.4 to run.")

    roberta_model = "FacebookAI/roberta-base"
    device = "cpu"
    attn_implementation = "sdpa"
    max_length = 512

    tokenizer = AutoTokenizer.from_pretrained(roberta_model)
    inputs = tokenizer(
        "The goal of life is <mask>.",
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
    )

    model = RobertaForMaskedLM.from_pretrained(
        roberta_model,
        device_map=device,
        attn_implementation=attn_implementation,
        use_cache=True,
    )

    # Locate the <mask> position from the tokenized input instead of
    # hard-coding index 6, so the test survives tokenizer changes that
    # shift token positions.
    mask_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)

    logits = model(**inputs).logits
    eager_predicted_mask = tokenizer.decode(logits[0, mask_index].topk(5).indices)
    self.assertEqual(eager_predicted_mask.split(), ["happiness", "love", "peace", "freedom", "simplicity"])

    # Export with static args/kwargs; strict=True traces with TorchDynamo
    # and errors out on any graph break.
    exported_program = torch.export.export(
        model,
        args=(inputs["input_ids"],),
        kwargs={"attention_mask": inputs["attention_mask"]},
        strict=True,
    )

    # The exported program must reproduce the eager top-5 predictions.
    result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
    exported_predicted_mask = tokenizer.decode(result.logits[0, mask_index].topk(5).indices)
    self.assertEqual(eager_predicted_mask, exported_predicted_mask)

0 commit comments

Comments (0)