|
16 | 16 |
|
17 | 17 | import unittest |
18 | 18 |
|
19 | | -from transformers import AlbertConfig, is_torch_available |
| 19 | +from packaging import version |
| 20 | + |
| 21 | +from transformers import AlbertConfig, AutoTokenizer, is_torch_available |
20 | 22 | from transformers.models.auto import get_values |
21 | 23 | from transformers.testing_utils import require_torch, slow, torch_device |
22 | 24 |
|
@@ -342,3 +344,45 @@ def test_inference_no_head_absolute_embedding(self): |
342 | 344 | ) |
343 | 345 |
|
344 | 346 | self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4)) |
| 347 | + |
| 348 | + @slow |
| 349 | + def test_export(self): |
| 350 | + if version.parse(torch.__version__) < version.parse("2.4.0"): |
| 351 | + self.skipTest(reason="This test requires torch >= 2.4 to run.") |
| 352 | + |
| 353 | + albert_model = "albert/albert-base-v2" |
| 354 | + device = "cpu" |
| 355 | + attn_implementation = "sdpa" |
| 356 | + max_length = 64 |
| 357 | + |
| 358 | + tokenizer = AutoTokenizer.from_pretrained(albert_model) |
| 359 | + inputs = tokenizer( |
| 360 | + f"Paris is the {tokenizer.mask_token} of France.", |
| 361 | + return_tensors="pt", |
| 362 | + padding="max_length", |
| 363 | + max_length=max_length, |
| 364 | + ) |
| 365 | + |
| 366 | + model = AlbertForMaskedLM.from_pretrained( |
| 367 | + albert_model, |
| 368 | + device_map=device, |
| 369 | + attn_implementation=attn_implementation, |
| 370 | + ) |
| 371 | + |
| 372 | + logits = model(**inputs).logits |
| 373 | + eg_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices) |
| 374 | + self.assertEqual( |
| 375 | + eg_predicted_mask.split(), |
| 376 | + ["capital", "capitol", "comune", "arrondissement", "bastille"], |
| 377 | + ) |
| 378 | + |
| 379 | + exported_program = torch.export.export( |
| 380 | + model, |
| 381 | + args=(inputs["input_ids"],), |
| 382 | + kwargs={"attention_mask": inputs["attention_mask"]}, |
| 383 | + strict=True, |
| 384 | + ) |
| 385 | + |
| 386 | + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) |
| 387 | + ep_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices) |
| 388 | + self.assertEqual(eg_predicted_mask, ep_predicted_mask) |
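
Note: the graph above is exported with static shapes, which is why the test pads every input to max_length=64. A minimal standalone sketch of reusing such an exported program on a different sentence (hypothetical example text, not part of this change; assumes torch >= 2.4 and the same checkpoint as the test):

import torch
from transformers import AlbertForMaskedLM, AutoTokenizer

# Hypothetical sketch mirroring the test above, not part of this diff.
model_id = "albert/albert-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AlbertForMaskedLM.from_pretrained(model_id, attn_implementation="sdpa")

def encode(text):
    # Pad to the fixed length the graph will be traced with.
    return tokenizer(text, return_tensors="pt", padding="max_length", max_length=64)

inputs = encode(f"Paris is the {tokenizer.mask_token} of France.")
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)

# The exported graph has static shapes, so new inputs must use the same padded length.
new_inputs = encode(f"Berlin is the {tokenizer.mask_token} of Germany.")
out = exported.module()(new_inputs["input_ids"], new_inputs["attention_mask"])
mask_pos = (new_inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0].item()
print(tokenizer.decode(out.logits[0, mask_pos].topk(5).indices))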