
Commit 4f7b5e2

guangy10 (Guang Yang) authored and committed
Albert is ExecuTorch compatible (huggingface#34476)
Co-authored-by: Guang Yang <[email protected]>
1 parent d6147a0 commit 4f7b5e2

File tree: 1 file changed, +45 −1 lines changed


tests/models/albert/test_modeling_albert.py

Lines changed: 45 additions & 1 deletion
@@ -16,7 +16,9 @@
 
 import unittest
 
-from transformers import AlbertConfig, is_torch_available
+from packaging import version
+
+from transformers import AlbertConfig, AutoTokenizer, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device

@@ -342,3 +344,45 @@ def test_inference_no_head_absolute_embedding(self):
         )
 
         self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        albert_model = "albert/albert-base-v2"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 64
+
+        tokenizer = AutoTokenizer.from_pretrained(albert_model)
+        inputs = tokenizer(
+            f"Paris is the {tokenizer.mask_token} of France.",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = AlbertForMaskedLM.from_pretrained(
+            albert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices)
+        self.assertEqual(
+            eg_predicted_mask.split(),
+            ["capital", "capitol", "comune", "arrondissement", "bastille"],
+        )
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)
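
The new test_export only checks the torch.export round trip: the exported program must produce the same top-5 mask predictions as eager mode. The ExecuTorch lowering itself happens downstream of this export. Below is a minimal sketch of that follow-on step, assuming the executorch pip package and its exir.to_edge / to_executorch APIs; neither is part of this commit.

import torch
from transformers import AlbertForMaskedLM, AutoTokenizer

# Assumption: `pip install executorch` provides this lowering entry point.
from executorch.exir import to_edge

model_id = "albert/albert-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AlbertForMaskedLM.from_pretrained(model_id, attn_implementation="sdpa")

inputs = tokenizer(
    f"Paris is the {tokenizer.mask_token} of France.",
    return_tensors="pt",
    padding="max_length",
    max_length=64,
)

# Same export call as in the test above.
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)

# Lower to the Edge dialect, convert to an ExecuTorch program, and
# serialize the flatbuffer payload for on-device runtimes.
et_program = to_edge(exported).to_executorch()
with open("albert.pte", "wb") as f:
    f.write(et_program.buffer)

Because test_export is marked @slow, it is skipped in normal CI runs; enable slow tests to exercise it, e.g. RUN_SLOW=1 python -m pytest tests/models/albert/test_modeling_albert.py -k test_export.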
