Skip to content

Commit 54b4514

Browse files
Wauplin and not-lain committed
Fix: correctly encode/decode config in ModelHubMixin if custom coders (#2337)
* Fix: correctly encode/decode config in ModelHubMixin if custom coders
* make style
* make quality
* Update tests/test_hub_mixin_pytorch.py

Co-authored-by: Hafedh <[email protected]>

---------

Co-authored-by: Hafedh <[email protected]>
1 parent 9bce62f commit 54b4514

File tree

3 files changed

+45
-19
lines changed

3 files changed

+45
-19
lines changed

docs/source/ko/guides/integrations.md

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -365,9 +365,9 @@ from argparse import Namespace
365365

366366
class VoiceCraft(
367367
nn.Module,
368-
PytorchModelHubMixin, # 믹스인을 상속합니다.
369-
coders: {
370-
Namespace = (
368+
PyTorchModelHubMixin, # 믹스인을 상속합니다.
369+
coders={
370+
Namespace: (
371371
lambda x: vars(x), # Encoder: `Namespace`를 유효한 JSON 형태로 변환하는 방법은 무엇인가요?
372372
lambda data: Namespace(**data), # Decoder: 딕셔너리에서 Namespace를 재구성하는 방법은 무엇인가요?
373373
)

src/huggingface_hub/hub_mixin.py

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
Type,
1616
TypeVar,
1717
Union,
18-
get_args,
1918
)
2019

2120
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
@@ -326,12 +325,11 @@ def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
326325
if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder
327326
},
328327
}
329-
init_config.pop("config", {})
328+
passed_config = init_config.pop("config", {})
330329

331330
# Populate `init_config` with provided config
332-
provided_config = passed_values.get("config")
333-
if isinstance(provided_config, dict):
334-
init_config.update(provided_config)
331+
if isinstance(passed_config, dict):
332+
init_config.update(passed_config)
335333

336334
# Set `config` attribute and return
337335
if init_config != {}:
@@ -362,9 +360,14 @@ def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T
362360
if value is None:
363361
return None
364362
expected_type = unwrap_simple_optional_type(expected_type)
363+
# Dataclass => handle it
364+
if is_dataclass(expected_type):
365+
return _load_dataclass(expected_type, value) # type: ignore[return-value]
366+
# Otherwise => check custom decoders
365367
for type_, (_, decoder) in cls._hub_mixin_coders.items():
366368
if inspect.isclass(expected_type) and issubclass(expected_type, type_):
367369
return decoder(value)
370+
# Otherwise => don't decode
368371
return value
369372

370373
def save_pretrained(
@@ -531,18 +534,9 @@ def from_pretrained(
531534

532535
# Check if `config` argument was passed at init
533536
if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
534-
# Check if `config` argument is a dataclass
537+
# Decode `config` argument if it was passed
535538
config_annotation = cls._hub_mixin_init_parameters["config"].annotation
536-
if config_annotation is inspect.Parameter.empty:
537-
pass # no annotation
538-
elif is_dataclass(config_annotation):
539-
config = _load_dataclass(config_annotation, config)
540-
else:
541-
# if Optional/Union annotation => check if a dataclass is in the Union
542-
for _sub_annotation in get_args(config_annotation):
543-
if is_dataclass(_sub_annotation):
544-
config = _load_dataclass(_sub_annotation, config)
545-
break
539+
config = cls._decode_arg(config_annotation, config)
546540

547541
# Forward config to model initialization
548542
model_kwargs["config"] = config

tests/test_hub_mixin_pytorch.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import struct
44
import unittest
5+
from argparse import Namespace
56
from pathlib import Path
67
from typing import Any, Dict, Optional, TypeVar
78
from unittest.mock import Mock, patch
@@ -95,6 +96,21 @@ class DummyModelWithModelCardAndCustomKwargs(
9596
def __init__(self, linear_layer: int = 4):
9697
super().__init__()
9798

99+
class DummyModelWithEncodedConfig(
100+
nn.Module,
101+
PyTorchModelHubMixin,
102+
coders={
103+
Namespace: (
104+
lambda x: vars(x),
105+
lambda data: Namespace(**data),
106+
)
107+
},
108+
):
109+
# Regression test for https://github.com/huggingface/huggingface_hub/issues/2334
110+
def __init__(self, config: Namespace):
111+
super().__init__()
112+
self.config = config
113+
98114
else:
99115
DummyModel = None
100116
DummyModelWithModelCard = None
@@ -419,3 +435,19 @@ def test_model_card_with_custom_kwargs(self):
419435
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
420436
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
421437
assert str(card) == str(card_reloaded)
438+
439+
def test_config_with_custom_coders(self):
440+
"""
441+
Regression test for #2334. When `config` is encoded with custom coders, it should be decoded correctly.
442+
443+
See https://github.com/huggingface/huggingface_hub/issues/2334.
444+
"""
445+
model = DummyModelWithEncodedConfig(Namespace(a=1, b=2))
446+
model.save_pretrained(self.cache_dir)
447+
assert model._hub_mixin_config["a"] == 1
448+
assert model._hub_mixin_config["b"] == 2
449+
450+
reloaded = DummyModelWithEncodedConfig.from_pretrained(self.cache_dir)
451+
assert isinstance(reloaded.config, Namespace)
452+
assert reloaded.config.a == 1
453+
assert reloaded.config.b == 2

0 commit comments

Comments
 (0)