Skip to content

Commit 3252e27

Browse files
authored
Fix ModelHubMixin when kwargs and config are both passed (#2138)
1 parent 5ad7855 commit 3252e27

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def from_pretrained(
398398
# Forward config to model initialization
399399
model_kwargs["config"] = config
400400

401-
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
401+
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
402402
for key, value in config.items():
403403
if key not in model_kwargs:
404404
model_kwargs[key] = value

tests/test_hub_mixin_pytorch.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import struct
44
import unittest
55
from pathlib import Path
6-
from typing import Any, TypeVar
6+
from typing import Any, Dict, Optional, TypeVar
77
from unittest.mock import Mock, patch
88

99
import pytest
@@ -53,10 +53,16 @@ def __init__(
5353
self.num_classes = num_classes
5454
self.state = state
5555
self.not_jsonable = not_jsonable
56+
57+
class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
58+
def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
59+
super().__init__()
60+
5661
else:
5762
DummyModel = None
5863
DummyModelWithTags = None
5964
DummyModelNoConfig = None
65+
DummyModelWithConfigAndKwargs = None
6066

6167

6268
@requires("torch")
@@ -346,3 +352,12 @@ def forward(self, x):
346352
b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
347353
assert a_weight_ptr == b_weight_ptr
348354
assert a_bias_ptr == b_bias_ptr
355+
356+
def test_save_pretrained_when_config_and_kwargs_are_passed(self):
357+
# Test creating model with config and kwargs => all values are saved together in config.json
358+
model = DummyModelWithConfigAndKwargs(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3)
359+
model.save_pretrained(self.cache_dir)
360+
assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3}
361+
362+
reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
363+
assert reloaded._hub_mixin_config == model._hub_mixin_config

0 commit comments

Comments (0)