|
 import torch

 from transformers import (
+    AutoTokenizer,
     DistilBertForMaskedLM,
     DistilBertForMultipleChoice,
     DistilBertForQuestionAnswering,
     DistilBertForSequenceClassification,
     DistilBertForTokenClassification,
     DistilBertModel,
 )
 from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4


 class DistilBertModelTester:
@@ -420,3 +422,45 @@ def test_inference_no_head_absolute_embedding(self): |
         )

         self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if not is_torch_greater_or_equal_than_2_4:
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        distilbert_model = "distilbert-base-uncased"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 64
+
+        tokenizer = AutoTokenizer.from_pretrained(distilbert_model)
+        inputs = tokenizer(
+            f"Paris is the {tokenizer.mask_token} of France.",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = DistilBertForMaskedLM.from_pretrained(
+            distilbert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        logits = model(**inputs).logits
+        eager_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices)
+        self.assertEqual(
+            eager_predicted_mask.split(),
+            ["capital", "birthplace", "northernmost", "centre", "southernmost"],
+        )
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        exported_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices)
+        self.assertEqual(eager_predicted_mask, exported_predicted_mask)
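
For context (this sketch is not part of the diff): the new test exports DistilBertForMaskedLM with torch.export and checks that the exported program reproduces the eager predictions. A standalone sketch of the same flow, extended with a save/reload round trip via torch.export.save / torch.export.load; the filename is illustrative, and torch >= 2.4 is assumed as in the test:

import torch
from transformers import AutoTokenizer, DistilBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased", attn_implementation="sdpa")

inputs = tokenizer(
    f"Paris is the {tokenizer.mask_token} of France.",
    return_tensors="pt",
    padding="max_length",
    max_length=64,
)

# Export with static shapes, then serialize the ExportedProgram to disk.
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)
torch.export.save(exported, "distilbert_mlm.pt2")  # illustrative filename

# Reload and run the exported module; predictions should match the eager model.
reloaded = torch.export.load("distilbert_mlm.pt2")
logits = reloaded.module()(inputs["input_ids"], inputs["attention_mask"]).logits
print(tokenizer.decode(logits[0, 4].topk(5).indices))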