Skip to content

Commit 6f91049

Browse files
goroldWauplin
andcommitted
Fix ModelHubMixin coders (#2291)
* update HubMixinTest with union and optional custom type * enable ModelHubMixin to handle union and optional custom type * add docstring for _is_optional_type helper function * Refactor helper functions and add independent unit tests. Refactor decode_arg * Restrict UnionType check to python3.10 and above. Minor style updates. * Only branch to pipe tests when version >= python3.10 * Change pipe operator tests to str + eval --------- Co-authored-by: Lucain <[email protected]>
1 parent 30e5192 commit 6f91049

File tree

5 files changed

+156
-8
lines changed

5 files changed

+156
-8
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,19 @@
44
import warnings
55
from dataclasses import asdict, dataclass, is_dataclass
66
from pathlib import Path
7-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Callable,
11+
Dict,
12+
List,
13+
Optional,
14+
Tuple,
15+
Type,
16+
TypeVar,
17+
Union,
18+
get_args,
19+
)
820

921
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
1022
from .file_download import hf_hub_download
@@ -16,8 +28,10 @@
1628
SoftTemporaryDirectory,
1729
is_jsonable,
1830
is_safetensors_available,
31+
is_simple_optional_type,
1932
is_torch_available,
2033
logging,
34+
unwrap_simple_optional_type,
2135
validate_hf_hub_args,
2236
)
2337

@@ -336,14 +350,20 @@ def _encode_arg(cls, arg: Any) -> Any:
336350
"""Encode an argument into a JSON serializable format."""
337351
for type_, (encoder, _) in cls._hub_mixin_coders.items():
338352
if isinstance(arg, type_):
353+
if arg is None:
354+
return None
339355
return encoder(arg)
340356
return arg
341357

342358
@classmethod
343-
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
359+
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
344360
"""Decode a JSON serializable value into an argument."""
361+
if is_simple_optional_type(expected_type):
362+
if value is None:
363+
return None
364+
expected_type = unwrap_simple_optional_type(expected_type)
345365
for type_, (_, decoder) in cls._hub_mixin_coders.items():
346-
if issubclass(expected_type, type_):
366+
if inspect.isclass(expected_type) and issubclass(expected_type, type_):
347367
return decoder(value)
348368
return value
349369

src/huggingface_hub/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
114114
from ._telemetry import send_telemetry
115115
from ._token import get_token
116-
from ._typing import is_jsonable
116+
from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
117117
from ._validators import (
118118
smoothly_deprecate_use_auth_token,
119119
validate_hf_hub_args,

src/huggingface_hub/utils/_typing.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414
# limitations under the License.
1515
"""Handle typing imports based on system compatibility."""
1616

17-
from typing import Any, Callable, Literal, TypeVar
17+
import sys
18+
from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin
19+
20+
21+
UNION_TYPES: List[Any] = [Union]
22+
if sys.version_info >= (3, 10):
23+
from types import UnionType
24+
25+
UNION_TYPES += [UnionType]
1826

1927

2028
HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
@@ -48,3 +56,20 @@ def is_jsonable(obj: Any) -> bool:
4856
return False
4957
except RecursionError:
5058
return False
59+
60+
61+
def is_simple_optional_type(type_: Type) -> bool:
62+
"""Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type."""
63+
if get_origin(type_) in UNION_TYPES:
64+
union_args = get_args(type_)
65+
if len(union_args) == 2 and type(None) in union_args:
66+
return True
67+
return False
68+
69+
70+
def unwrap_simple_optional_type(optional_type: Type) -> Type:
71+
"""Unwraps a simple optional type, i.e. returns Type from Optional[Type]."""
72+
for arg in get_args(optional_type):
73+
if arg is not type(None):
74+
return arg
75+
raise ValueError(f"'{optional_type}' is not an optional type")

tests/test_hub_mixin.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,22 @@ class DummyModelWithCustomTypes(
144144
},
145145
):
146146
def __init__(
147-
self, foo: int, bar: str, custom: CustomType, custom_default: CustomType = CustomType("default"), **kwargs
147+
self,
148+
foo: int,
149+
bar: str,
150+
baz: Union[int, str],
151+
custom: CustomType,
152+
optional_custom_1: Optional[CustomType],
153+
optional_custom_2: Optional[CustomType],
154+
custom_default: CustomType = CustomType("default"),
155+
**kwargs,
148156
):
149157
self.foo = foo
150158
self.bar = bar
159+
self.baz = baz
151160
self.custom = custom
161+
self.optional_custom_1 = optional_custom_1
162+
self.optional_custom_2 = optional_custom_2
152163
self.custom_default = custom_default
153164

154165
@classmethod
@@ -406,21 +417,34 @@ def test_from_pretrained_when_cls_is_a_dataclass(self):
406417
assert not hasattr(model, "other")
407418

408419
def test_from_cls_with_custom_type(self):
409-
model = DummyModelWithCustomTypes(1, bar="bar", custom=CustomType("custom"))
420+
model = DummyModelWithCustomTypes(
421+
1,
422+
bar="bar",
423+
baz=1.0,
424+
custom=CustomType("custom"),
425+
optional_custom_1=CustomType("optional"),
426+
optional_custom_2=None,
427+
)
410428
model.save_pretrained(self.cache_dir)
411429

412430
config = json.loads((self.cache_dir / "config.json").read_text())
413431
assert config == {
414432
"foo": 1,
415433
"bar": "bar",
434+
"baz": 1.0,
416435
"custom": {"value": "custom"},
436+
"optional_custom_1": {"value": "optional"},
437+
"optional_custom_2": None,
417438
"custom_default": {"value": "default"},
418439
}
419440

420441
model_reloaded = DummyModelWithCustomTypes.from_pretrained(self.cache_dir)
421442
assert model_reloaded.foo == 1
422443
assert model_reloaded.bar == "bar"
444+
assert model_reloaded.baz == 1.0
423445
assert model_reloaded.custom.value == "custom"
446+
assert model_reloaded.optional_custom_1 is not None and model_reloaded.optional_custom_1.value == "optional"
447+
assert model_reloaded.optional_custom_2 is None
424448
assert model_reloaded.custom_default.value == "default"
425449

426450
def test_inherited_class(self):

tests/test_utils_typing.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
import json
2+
import sys
3+
from typing import Optional, Type, Union
24

35
import pytest
46

5-
from huggingface_hub.utils._typing import is_jsonable
7+
from huggingface_hub.utils._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
68

79

810
class NotSerializableClass:
911
pass
1012

1113

14+
class CustomType:
15+
pass
16+
17+
1218
OBJ_WITH_CIRCULAR_REF = {"hello": "world"}
1319
OBJ_WITH_CIRCULAR_REF["recursive"] = OBJ_WITH_CIRCULAR_REF
1420

@@ -47,3 +53,76 @@ def test_is_jsonable_failure(data):
4753
assert not is_jsonable(data)
4854
with pytest.raises((TypeError, ValueError)):
4955
json.dumps(data)
56+
57+
58+
@pytest.mark.parametrize(
59+
"type_, is_optional",
60+
[
61+
(Optional[int], True),
62+
(Union[None, int], True),
63+
(Union[int, None], True),
64+
(Optional[CustomType], True),
65+
(Union[None, CustomType], True),
66+
(Union[CustomType, None], True),
67+
(int, False),
68+
(None, False),
69+
(Union[int, float, None], False),
70+
(Union[Union[int, float], None], False),
71+
(Optional[Union[int, float]], False),
72+
],
73+
)
74+
def test_is_simple_optional_type(type_: Type, is_optional: bool):
75+
assert is_simple_optional_type(type_) is is_optional
76+
77+
78+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
79+
@pytest.mark.parametrize(
80+
"type_, is_optional",
81+
[
82+
("int | None", True),
83+
("None | int", True),
84+
("CustomType | None", True),
85+
("None | CustomType", True),
86+
("int | float", False),
87+
("int | float | None", False),
88+
("(int | float) | None", False),
89+
("Union[int, float] | None", False),
90+
],
91+
)
92+
def test_is_simple_optional_type_pipe(type_: str, is_optional: bool):
93+
assert is_simple_optional_type(eval(type_)) is is_optional
94+
95+
96+
@pytest.mark.parametrize(
97+
"optional_type, inner_type",
98+
[
99+
(Optional[int], int),
100+
(Union[int, None], int),
101+
(Union[None, int], int),
102+
(Optional[CustomType], CustomType),
103+
(Union[CustomType, None], CustomType),
104+
(Union[None, CustomType], CustomType),
105+
],
106+
)
107+
def test_unwrap_simple_optional_type(optional_type: Type, inner_type: Type):
108+
assert unwrap_simple_optional_type(optional_type) is inner_type
109+
110+
111+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
112+
@pytest.mark.parametrize(
113+
"optional_type, inner_type",
114+
[
115+
("None | int", int),
116+
("int | None", int),
117+
("None | CustomType", CustomType),
118+
("CustomType | None", CustomType),
119+
],
120+
)
121+
def test_unwrap_simple_optional_type_pipe(optional_type: str, inner_type: Type):
122+
assert unwrap_simple_optional_type(eval(optional_type)) is inner_type
123+
124+
125+
@pytest.mark.parametrize("non_optional_type", [int, None, CustomType])
126+
def test_unwrap_simple_optional_type_fail(non_optional_type: Type):
127+
with pytest.raises(ValueError):
128+
unwrap_simple_optional_type(non_optional_type)

0 commit comments

Comments
 (0)