Skip to content

Commit e6a1bca

Browse files
authored
[Bug fixes] fix config attributes backward compatibility (#4237)
* fix config comp * update comp * complete fix config comp * add deprecated warning * add common testing * add turn-on keys * update modeling * fix typo
1 parent 1c7ddf0 commit e6a1bca

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,27 @@ def _post_init(self, original_init, *args, **kwargs):
279279
init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
280280
self.config = init_dict
281281

282+
def __getattr__(self, name):
283+
"""
284+
called when the attribute name is missed in the model
285+
286+
Args:
287+
name: the name of attribute
288+
289+
Returns: the value of attribute
290+
291+
"""
292+
try:
293+
return super(PretrainedModel, self).__getattr__(name)
294+
except AttributeError:
295+
result = getattr(self.config, name)
296+
297+
logger.warning(
298+
f"Do not access config from `model.{name}` which will be deprecated after v2.6.0, "
299+
f"Instead, do `model.config.{name}`"
300+
)
301+
return result
302+
282303
@property
283304
def base_model(self):
284305
"""

tests/transformers/test_modeling_common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class ModelTesterMixin:
6767
test_resize_position_embeddings = False
6868
test_mismatched_shapes = True
6969
test_missing_keys = True
70+
test_model_compatibility_keys = False
7071
use_test_inputs_embeds = False
7172
use_test_model_name_list = True
7273
is_encoder_decoder = False
@@ -542,6 +543,31 @@ def random_choice_pretrained_config_field(self) -> Optional[str]:
542543
fields = [key for key, value in config.to_dict() if value]
543544
return random.choice(fields)
544545

546+
def test_for_missed_attribute(self):
547+
if not self.test_model_compatibility_keys:
548+
self.skipTest(f"Do not test model_compatibility_keys on {self.base_model_class}")
549+
return
550+
551+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
552+
for model_class in self.all_model_classes:
553+
if not model_class.constructed_from_pretrained_config():
554+
continue
555+
556+
model = self._make_model_instance(config, model_class)
557+
558+
all_maps: dict = copy.deepcopy(model_class.config_class.attribute_map)
559+
all_maps.update(model_class.config_class.standard_config_map)
560+
561+
for old_attribute, new_attribute in all_maps.items():
562+
old_value = getattr(model, old_attribute)
563+
new_value = getattr(model, new_attribute)
564+
565+
# eg: dropout can be an instance of nn.Dropout, so we should check it attribute
566+
if type(new_value) != type(old_value):
567+
continue
568+
569+
self.assertEqual(old_value, new_value)
570+
545571

546572
class ModelTesterPretrainedMixin:
547573
base_model_class: PretrainedModel = None

0 commit comments

Comments
 (0)