Skip to content

Commit 8367907

Browse files
guangy10Guang Yang
authored andcommitted
DistilBERT is ExecuTorch compatible (huggingface#34475)
* DistillBERT is ExecuTorch compatible * [run_slow] distilbert * [run_slow] distilbert --------- Co-authored-by: Guang Yang <[email protected]>
1 parent 3269346 commit 8367907

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/models/distilbert/test_modeling_distilbert.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import torch
3131

3232
from transformers import (
33+
AutoTokenizer,
3334
DistilBertForMaskedLM,
3435
DistilBertForMultipleChoice,
3536
DistilBertForQuestionAnswering,
@@ -38,6 +39,7 @@
3839
DistilBertModel,
3940
)
4041
from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
42+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
4143

4244

4345
class DistilBertModelTester:
@@ -420,3 +422,45 @@ def test_inference_no_head_absolute_embedding(self):
420422
)
421423

422424
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
425+
426+
@slow
427+
def test_export(self):
428+
if not is_torch_greater_or_equal_than_2_4:
429+
self.skipTest(reason="This test requires torch >= 2.4 to run.")
430+
431+
distilbert_model = "distilbert-base-uncased"
432+
device = "cpu"
433+
attn_implementation = "sdpa"
434+
max_length = 64
435+
436+
tokenizer = AutoTokenizer.from_pretrained(distilbert_model)
437+
inputs = tokenizer(
438+
f"Paris is the {tokenizer.mask_token} of France.",
439+
return_tensors="pt",
440+
padding="max_length",
441+
max_length=max_length,
442+
)
443+
444+
model = DistilBertForMaskedLM.from_pretrained(
445+
distilbert_model,
446+
device_map=device,
447+
attn_implementation=attn_implementation,
448+
)
449+
450+
logits = model(**inputs).logits
451+
eager_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices)
452+
self.assertEqual(
453+
eager_predicted_mask.split(),
454+
["capital", "birthplace", "northernmost", "centre", "southernmost"],
455+
)
456+
457+
exported_program = torch.export.export(
458+
model,
459+
args=(inputs["input_ids"],),
460+
kwargs={"attention_mask": inputs["attention_mask"]},
461+
strict=True,
462+
)
463+
464+
result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
465+
exported_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices)
466+
self.assertEqual(eager_predicted_mask, exported_predicted_mask)

0 commit comments

Comments
 (0)