|
1 | 1 | import json |
| 2 | +import sys |
| 3 | +from typing import Optional, Type, Union |
2 | 4 |
|
3 | 5 | import pytest |
4 | 6 |
|
5 | | -from huggingface_hub.utils._typing import is_jsonable |
| 7 | +from huggingface_hub.utils._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type |
6 | 8 |
|
7 | 9 |
|
8 | 10 | class NotSerializableClass: |
9 | 11 | pass |
10 | 12 |
|
11 | 13 |
|
| 14 | +class CustomType: |
| 15 | + pass |
| 16 | + |
| 17 | + |
12 | 18 | OBJ_WITH_CIRCULAR_REF = {"hello": "world"} |
13 | 19 | OBJ_WITH_CIRCULAR_REF["recursive"] = OBJ_WITH_CIRCULAR_REF |
14 | 20 |
|
@@ -47,3 +53,76 @@ def test_is_jsonable_failure(data): |
47 | 53 | assert not is_jsonable(data) |
48 | 54 | with pytest.raises((TypeError, ValueError)): |
49 | 55 | json.dumps(data) |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.parametrize( |
| 59 | + "type_, is_optional", |
| 60 | + [ |
| 61 | + (Optional[int], True), |
| 62 | + (Union[None, int], True), |
| 63 | + (Union[int, None], True), |
| 64 | + (Optional[CustomType], True), |
| 65 | + (Union[None, CustomType], True), |
| 66 | + (Union[CustomType, None], True), |
| 67 | + (int, False), |
| 68 | + (None, False), |
| 69 | + (Union[int, float, None], False), |
| 70 | + (Union[Union[int, float], None], False), |
| 71 | + (Optional[Union[int, float]], False), |
| 72 | + ], |
| 73 | +) |
| 74 | +def test_is_simple_optional_type(type_: Type, is_optional: bool): |
| 75 | + assert is_simple_optional_type(type_) is is_optional |
| 76 | + |
| 77 | + |
| 78 | +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") |
| 79 | +@pytest.mark.parametrize( |
| 80 | + "type_, is_optional", |
| 81 | + [ |
| 82 | + ("int | None", True), |
| 83 | + ("None | int", True), |
| 84 | + ("CustomType | None", True), |
| 85 | + ("None | CustomType", True), |
| 86 | + ("int | float", False), |
| 87 | + ("int | float | None", False), |
| 88 | + ("(int | float) | None", False), |
| 89 | + ("Union[int, float] | None", False), |
| 90 | + ], |
| 91 | +) |
| 92 | +def test_is_simple_optional_type_pipe(type_: str, is_optional: bool): |
| 93 | + assert is_simple_optional_type(eval(type_)) is is_optional |
| 94 | + |
| 95 | + |
| 96 | +@pytest.mark.parametrize( |
| 97 | + "optional_type, inner_type", |
| 98 | + [ |
| 99 | + (Optional[int], int), |
| 100 | + (Union[int, None], int), |
| 101 | + (Union[None, int], int), |
| 102 | + (Optional[CustomType], CustomType), |
| 103 | + (Union[CustomType, None], CustomType), |
| 104 | + (Union[None, CustomType], CustomType), |
| 105 | + ], |
| 106 | +) |
| 107 | +def test_unwrap_simple_optional_type(optional_type: Type, inner_type: Type): |
| 108 | + assert unwrap_simple_optional_type(optional_type) is inner_type |
| 109 | + |
| 110 | + |
| 111 | +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") |
| 112 | +@pytest.mark.parametrize( |
| 113 | + "optional_type, inner_type", |
| 114 | + [ |
| 115 | + ("None | int", int), |
| 116 | + ("int | None", int), |
| 117 | + ("None | CustomType", CustomType), |
| 118 | + ("CustomType | None", CustomType), |
| 119 | + ], |
| 120 | +) |
| 121 | +def test_unwrap_simple_optional_type_pipe(optional_type: str, inner_type: Type): |
| 122 | + assert unwrap_simple_optional_type(eval(optional_type)) is inner_type |
| 123 | + |
| 124 | + |
| 125 | +@pytest.mark.parametrize("non_optional_type", [int, None, CustomType]) |
| 126 | +def test_unwrap_simple_optional_type_fail(non_optional_type: Type): |
| 127 | + with pytest.raises(ValueError): |
| 128 | + unwrap_simple_optional_type(non_optional_type) |
0 commit comments