Skip to content

Commit 8b641c0

Browse files
committed
Fix ModelHubMixin: pass config when __init__ accepts **kwargs (#2058)
* Fix ModelHubMixin: pass config when __init__ accepts **kwargs * comment
1 parent fd19b32 commit 8b641c0

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ def from_pretrained(
270270

271271
# Forward config to model initialization
272272
model_kwargs["config"] = config
273+
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_parameters.values()):
274+
# If __init__ accepts **kwargs, let's forward the config as well (as a dict)
275+
model_kwargs["config"] = config
273276

274277
instance = cls._from_pretrained(
275278
model_id=str(model_id),

tests/test_hub_mixin.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@ def _from_pretrained(
3838
**kwargs,
3939
) -> "BaseModel":
4040
# Little hack but in practice NO-ONE is creating 5 inherited classes for their framework :D
41-
if inspect.signature(cls.__init__).parameters.get("config"):
41+
init_parameters = inspect.signature(cls.__init__).parameters
42+
if init_parameters.get("config"):
4243
return cls(config=kwargs.get("config"))
44+
if init_parameters.get("kwargs"):
45+
return cls(**kwargs)
4346
return cls()
4447

4548

@@ -68,6 +71,11 @@ def __init__(self, config: Optional[Dict] = None):
6871
pass
6972

7073

74+
class DummyModelWithKwargs(BaseModel, ModelHubMixin):
75+
def __init__(self, **kwargs):
76+
pass
77+
78+
7179
@pytest.mark.usefixtures("fx_cache_dir")
7280
class HubMixinTest(unittest.TestCase):
7381
cache_dir: Path
@@ -132,6 +140,34 @@ def test_save_pretrained_with_dict_config(self):
132140
model.save_pretrained(self.cache_dir, config=CONFIG_AS_DICT)
133141
self.assert_valid_config_json()
134142

143+
def test_init_accepts_kwargs_no_config(self):
144+
"""
145+
Test that if `__init__` accepts **kwargs and config file doesn't exist then no 'config' kwargs is passed.
146+
147+
Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058.
148+
"""
149+
model = DummyModelWithKwargs()
150+
model.save_pretrained(self.cache_dir)
151+
with patch.object(
152+
DummyModelWithKwargs, "_from_pretrained", return_value=DummyModelWithKwargs()
153+
) as from_pretrained_mock:
154+
model = DummyModelWithKwargs.from_pretrained(self.cache_dir)
155+
assert "config" not in from_pretrained_mock.call_args_list[0].kwargs
156+
157+
def test_init_accepts_kwargs_with_config(self):
158+
"""
159+
Test that if `__init__` accepts **kwargs and config file exists then the 'config' kwargs is passed.
160+
161+
Regression test. See https://github.com/huggingface/huggingface_hub/pull/2058.
162+
"""
163+
model = DummyModelWithKwargs()
164+
model.save_pretrained(self.cache_dir, config=CONFIG_AS_DICT)
165+
with patch.object(
166+
DummyModelWithKwargs, "_from_pretrained", return_value=DummyModelWithKwargs()
167+
) as from_pretrained_mock:
168+
model = DummyModelWithKwargs.from_pretrained(self.cache_dir)
169+
assert "config" in from_pretrained_mock.call_args_list[0].kwargs
170+
135171
def test_save_pretrained_with_push_to_hub(self):
136172
repo_id = repo_name("save")
137173
save_directory = self.cache_dir / repo_id

0 commit comments

Comments
 (0)