|
20 | 20 | import shutil
|
21 | 21 | import tempfile
|
22 | 22 | import unittest
|
| 23 | +from typing import Optional, Tuple, Type |
23 | 24 |
|
24 | 25 | import numpy as np
|
25 | 26 | import paddle
|
26 | 27 |
|
27 | 28 | from paddlenlp.transformers.configuration_utils import PretrainedConfig
|
28 | 29 | from paddlenlp.transformers.model_utils import PretrainedModel
|
29 |
| -from paddlenlp.utils.env import MODEL_HOME |
| 30 | +from paddlenlp.utils.env import CONFIG_NAME, LEGACY_CONFIG_NAME, MODEL_HOME |
30 | 31 |
|
31 | 32 | from ..testing_utils import slow
|
32 | 33 |
|
@@ -59,8 +60,8 @@ def check_two_model_parameter(first_model: PretrainedModel, second_model: Pretra
|
59 | 60 |
|
60 | 61 | class ModelTesterMixin:
|
61 | 62 | model_tester = None
|
62 |
| - base_model_class = None |
63 |
| - all_model_classes = () |
| 63 | + base_model_class: Optional[Type[PretrainedModel]] = None |
| 64 | + all_model_classes: Tuple[Type[PretrainedModel]] = () |
64 | 65 | all_generative_model_classes = ()
|
65 | 66 | test_resize_embeddings = True
|
66 | 67 | test_resize_position_embeddings = False
|
@@ -493,20 +494,50 @@ def test_model_name_list(self):
|
493 | 494 | model = self.base_model_class(**config)
|
494 | 495 | self.assertTrue(len(model.model_name_list) != 0)
|
495 | 496 |
|
| 497 | + def test_pretrained_config_save_load(self): |
| 498 | + |
| 499 | + if self.base_model_class is None or not self.base_model_class.constructed_from_pretrained_config(): |
| 500 | + return |
| 501 | + |
| 502 | + config_class = self.base_model_class.config_class |
| 503 | + with tempfile.TemporaryDirectory() as tempdir: |
| 504 | + config = config_class() |
| 505 | + |
| 506 | + config.save_pretrained(tempdir) |
| 507 | + |
| 508 | + # check the file exist |
| 509 | + self.assertFalse(os.path.exists(os.path.join(tempdir, LEGACY_CONFIG_NAME))) |
| 510 | + self.assertTrue(os.path.exists(os.path.join(tempdir, CONFIG_NAME))) |
| 511 | + |
| 512 | + # rename the CONFIG_NAME |
| 513 | + shutil.move(os.path.join(tempdir, CONFIG_NAME), os.path.join(tempdir, LEGACY_CONFIG_NAME)) |
| 514 | + |
| 515 | + loaded_config = config.__class__.from_pretrained(tempdir) |
| 516 | + self.assertEqual(config.hidden_size, loaded_config.hidden_size) |
| 517 | + |
| 518 | + def random_choice_pretrained_config_field(self) -> Optional[str]: |
| 519 | + |
| 520 | + if self.base_model_class is None or not self.base_model_class.constructed_from_pretrained_config(): |
| 521 | + return None |
| 522 | + |
| 523 | + config = self.base_model_class.config_class() |
| 524 | + fields = [key for key, value in config.to_dict() if value] |
| 525 | + return random.choice(fields) |
| 526 | + |
496 | 527 |
|
497 | 528 | class ModelTesterPretrainedMixin:
|
498 | 529 | base_model_class: PretrainedModel = None
|
499 | 530 | hf_remote_test_model_path: str = None
|
500 | 531 | paddlehub_remote_test_model_path: str = None
|
501 | 532 |
|
| 533 | + # Download from HF doesn't work in CI yet |
502 | 534 | @slow
|
503 | 535 | def test_model_from_pretrained_hf_hub(self):
|
504 | 536 | if self.hf_remote_test_model_path is None or self.base_model_class is None:
|
505 | 537 | return
|
506 | 538 | model = self.base_model_class.from_pretrained(self.hf_remote_test_model_path, from_hf_hub=True)
|
507 | 539 | self.assertIsNotNone(model)
|
508 | 540 |
|
509 |
| - @slow |
510 | 541 | def test_model_from_pretrained_paddle_hub(self):
|
511 | 542 | if self.paddlehub_remote_test_model_path is None or self.base_model_class is None:
|
512 | 543 | return
|
@@ -553,8 +584,6 @@ def test_pretrained_save_and_load(self):
|
553 | 584 | os.path.join(MODEL_HOME, model_name),
|
554 | 585 | tempdirname,
|
555 | 586 | )
|
556 |
| - files = os.listdir(tempdirname) |
557 |
| - |
558 | 587 | saved_model_state_file = os.path.join(
|
559 | 588 | tempdirname, self.base_model_class.resource_files_names["model_state"]
|
560 | 589 | )
|
|
0 commit comments