Skip to content

Commit 9662bdc

Browse files
authored
[HubMixin] handle dataclasses in all args, not only 'config' (#2928)
1 parent 4028062 commit 9662bdc

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,13 +343,17 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:
343343
@classmethod
344344
def _is_jsonable(cls, value: Any) -> bool:
345345
"""Check if a value is JSON serializable."""
346+
if is_dataclass(value):
347+
return True
346348
if isinstance(value, cls._hub_mixin_jsonable_custom_types):
347349
return True
348350
return is_jsonable(value)
349351

350352
@classmethod
351353
def _encode_arg(cls, arg: Any) -> Any:
352354
"""Encode an argument into a JSON serializable format."""
355+
if is_dataclass(arg):
356+
return asdict(arg)
353357
for type_, (encoder, _) in cls._hub_mixin_coders.items():
354358
if isinstance(arg, type_):
355359
if arg is None:

tests/test_hub_mixin.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,26 @@ def _save_pretrained(cls, save_directory: Path):
177177
return
178178

179179

180+
@dataclass
181+
class DummyDataclass:
182+
foo: int
183+
bar: str
184+
185+
186+
class DummyWithDataclassInputs(ModelHubMixin):
187+
def __init__(self, arg1: DummyDataclass, arg2: DummyDataclass):
188+
self.arg1 = arg1
189+
self.arg2 = arg2
190+
191+
@classmethod
192+
def _from_pretrained(cls, **kwargs):
193+
return cls(arg1=kwargs["arg1"], arg2=kwargs["arg2"])
194+
195+
@classmethod
196+
def _save_pretrained(cls, save_directory: Path):
197+
return
198+
199+
180200
@pytest.mark.usefixtures("fx_cache_dir")
181201
class HubMixinTest(unittest.TestCase):
182202
cache_dir: Path
@@ -501,3 +521,22 @@ def method_with_hints(self, x: int) -> str:
501521
# Test method type hints on instance
502522
model = ModelWithHints()
503523
assert get_type_hints(model.method_with_hints) == {"x": int, "return": str}
524+
525+
def test_with_dataclass_inputs(self):
526+
model = DummyWithDataclassInputs(
527+
arg1=DummyDataclass(foo=1, bar="1"),
528+
arg2=DummyDataclass(foo=2, bar="2"),
529+
)
530+
model.save_pretrained(self.cache_dir)
531+
532+
config = json.loads((self.cache_dir / "config.json").read_text())
533+
assert config == {
534+
"arg1": {"foo": 1, "bar": "1"},
535+
"arg2": {"foo": 2, "bar": "2"},
536+
}
537+
538+
model_reloaded = DummyWithDataclassInputs.from_pretrained(self.cache_dir)
539+
assert model_reloaded.arg1.foo == 1
540+
assert model_reloaded.arg1.bar == "1"
541+
assert model_reloaded.arg2.foo == 2
542+
assert model_reloaded.arg2.bar == "2"

0 commit comments

Comments
 (0)