
Commit 0dd38e9

guangy10 (Guang Yang) authored and committed
Bert is ExecuTorch compatible (huggingface#34424)
Co-authored-by: Guang Yang <[email protected]>
1 parent ce021f1

File tree

1 file changed (+42, -0 lines)

tests/models/bert/test_modeling_bert.py

Lines changed: 42 additions & 0 deletions
@@ -16,6 +16,8 @@
 import tempfile
 import unittest
 
+from packaging import version
+
 from transformers import AutoTokenizer, BertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
@@ -749,3 +751,43 @@ def test_sdpa_ignored_mask(self):
         self.assertTrue(
             torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
         )
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        bert_model = "google-bert/bert-base-uncased"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(bert_model)
+        inputs = tokenizer(
+            "the man worked as a [MASK].",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = BertForMaskedLM.from_pretrained(
+            bert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+            use_cache=True,
+        )
+
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "barber", "mechanic", "salesman"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)
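The test above checks that the eager model and the exported program agree on the top-5 [MASK] predictions, and it stops at torch.export, which produces the ExportedProgram that ExecuTorch consumes. As a minimal sketch of the lowering step itself, assuming the separate executorch package and its exir.to_edge API (neither is part of this diff, and the bert.pte filename is only illustrative):

# A minimal sketch of lowering the exported program for the ExecuTorch
# runtime. Assumes the separate `executorch` package is installed; this
# step is not part of this commit.
import torch
from executorch.exir import to_edge
from transformers import AutoTokenizer, BertForMaskedLM

bert_model = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_model)
inputs = tokenizer("the man worked as a [MASK].", return_tensors="pt")
model = BertForMaskedLM.from_pretrained(bert_model, attn_implementation="sdpa")

# Same entry point the test exercises: capture the model as an ExportedProgram.
exported_program = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)

# Lower the ATen-level program to the ExecuTorch edge dialect and serialize
# it as a .pte file (hypothetical name) that the on-device runtime can load.
executorch_program = to_edge(exported_program).to_executorch()
with open("bert.pte", "wb") as f:
    f.write(executorch_program.buffer)

Because the new test is marked @slow, it only runs when slow tests are enabled, e.g. RUN_SLOW=1 python -m pytest tests/models/bert/test_modeling_bert.py -k test_export.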
