@@ -38,8 +38,11 @@ def _from_pretrained(
3838 ** kwargs ,
3939 ) -> "BaseModel" :
4040 # Little hack but in practice NO-ONE is creating 5 inherited classes for their framework :D
41- if inspect .signature (cls .__init__ ).parameters .get ("config" ):
41+ init_parameters = inspect .signature (cls .__init__ ).parameters
42+ if init_parameters .get ("config" ):
4243 return cls (config = kwargs .get ("config" ))
44+ if init_parameters .get ("kwargs" ):
45+ return cls (** kwargs )
4346 return cls ()
4447
4548
@@ -68,6 +71,11 @@ def __init__(self, config: Optional[Dict] = None):
6871 pass
6972
7073
74+ class DummyModelWithKwargs (BaseModel , ModelHubMixin ):
75+ def __init__ (self , ** kwargs ):
76+ pass
77+
78+
7179@pytest .mark .usefixtures ("fx_cache_dir" )
7280class HubMixinTest (unittest .TestCase ):
7381 cache_dir : Path
@@ -132,6 +140,34 @@ def test_save_pretrained_with_dict_config(self):
132140 model .save_pretrained (self .cache_dir , config = CONFIG_AS_DICT )
133141 self .assert_valid_config_json ()
134142
143+ def test_init_accepts_kwargs_no_config (self ):
144+ """
145+ Test that if `__init__` accepts **kwargs and config file doesn't exist then no 'config' kwargs is passed.
146+
147+ Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058.
148+ """
149+ model = DummyModelWithKwargs ()
150+ model .save_pretrained (self .cache_dir )
151+ with patch .object (
152+ DummyModelWithKwargs , "_from_pretrained" , return_value = DummyModelWithKwargs ()
153+ ) as from_pretrained_mock :
154+ model = DummyModelWithKwargs .from_pretrained (self .cache_dir )
155+ assert "config" not in from_pretrained_mock .call_args_list [0 ].kwargs
156+
157+ def test_init_accepts_kwargs_with_config (self ):
158+ """
159+ Test that if `__init__` accepts **kwargs and config file exists then the 'config' kwargs is passed.
160+
161+ Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058.
162+ """
163+ model = DummyModelWithKwargs ()
164+ model .save_pretrained (self .cache_dir , config = CONFIG_AS_DICT )
165+ with patch .object (
166+ DummyModelWithKwargs , "_from_pretrained" , return_value = DummyModelWithKwargs ()
167+ ) as from_pretrained_mock :
168+ model = DummyModelWithKwargs .from_pretrained (self .cache_dir )
169+ assert "config" in from_pretrained_mock .call_args_list [0 ].kwargs
170+
135171 def test_save_pretrained_with_push_to_hub (self ):
136172 repo_id = repo_name ("save" )
137173 save_directory = self .cache_dir / repo_id
0 commit comments