Skip to content

Commit d6147a0

Browse files
guangy10 (Guang Yang) authored and committed
MobileBERT is ExecuTorch compatible (huggingface#34473)
Co-authored-by: Guang Yang <[email protected]>
1 parent 1bbe27a commit d6147a0

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

tests/models/mobilebert/test_modeling_mobilebert.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import unittest
1818

19-
from transformers import MobileBertConfig, is_torch_available
19+
from packaging import version
20+
21+
from transformers import AutoTokenizer, MobileBertConfig, MobileBertForMaskedLM, is_torch_available
2022
from transformers.models.auto import get_values
2123
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
2224

@@ -384,3 +386,42 @@ def test_inference_no_head(self):
384386
upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
385387

386388
self.assertTrue(lower_bound and upper_bound)
389+
390+
@slow
def test_export(self):
    """Verify MobileBertForMaskedLM is `torch.export` compatible.

    Runs the model eagerly on a fill-mask prompt, checks the top-5
    predictions at the mask position, then exports the model with
    ``torch.export.export`` and asserts the exported program decodes
    the exact same predictions.
    """
    # torch.export with the signature used below is only available from torch 2.4.
    if version.parse(torch.__version__) < version.parse("2.4.0"):
        self.skipTest(reason="This test requires torch >= 2.4 to run.")

    checkpoint = "google/mobilebert-uncased"
    device = "cpu"
    attn_implementation = "eager"
    max_length = 512

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(
        f"the man worked as a {tokenizer.mask_token}.",
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
    )

    model = MobileBertForMaskedLM.from_pretrained(
        checkpoint,
        device_map=device,
        attn_implementation=attn_implementation,
    )

    # Eager baseline: decode the 5 most likely tokens at the mask position
    # (token index 6 in this prompt — TODO confirm if the prompt changes).
    eager_logits = model(**encoded).logits
    eg_predicted_mask = tokenizer.decode(eager_logits[0, 6].topk(5).indices)
    self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "mechanic", "teacher", "clerk"])

    # Export the model and replay the identical inputs through the exported module.
    exported_program = torch.export.export(
        model,
        args=(encoded["input_ids"],),
        kwargs={"attention_mask": encoded["attention_mask"]},
        strict=True,
    )
    exported_out = exported_program.module().forward(encoded["input_ids"], encoded["attention_mask"])
    ep_predicted_mask = tokenizer.decode(exported_out.logits[0, 6].topk(5).indices)

    # The exported program must agree with eager execution token-for-token.
    self.assertEqual(eg_predicted_mask, ep_predicted_mask)

0 commit comments

Comments (0)