Skip to content

Commit ee9ef67

Browse files
committed
함수 스키마 변경: 유니온 및 리터럴 타입 처리 수정
1 parent a685e87 commit ee9ef67

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

function_schema/core.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
py_310 = packaging.version.parse("3.10")
99

1010
if current_version >= py_310:
11-
import types
1211
from types import UnionType
1312
else:
1413
UnionType = typing.Union # type: ignore
@@ -154,13 +153,17 @@ def guess_type(
154153

155154
origin = typing.get_origin(T)
156155

156+
if origin is typing.Annotated:
157+
return guess_type(typing.get_args(T)[0])
158+
157159
# hacking around typing modules, `typing.Union` and `types.UnitonType`
158-
if origin is typing.Union or origin is UnionType:
160+
if origin in [typing.Union, UnionType]:
159161
union_types = [t for t in typing.get_args(T) if t is not type(None)]
160-
_types = []
161-
for union_type in union_types:
162-
_types.append(guess_type(union_type))
163-
_types = [t for t in _types if t is not None] # exclude None
162+
_types = [
163+
guess_type(union_type)
164+
for union_type in union_types
165+
if guess_type(union_type) is not None
166+
]
164167

165168
# number contains integer in JSON schema
166169
if "number" in _types and "integer" in _types:
@@ -171,7 +174,8 @@ def guess_type(
171174
return _types
172175

173176
if origin is typing.Literal:
174-
return guess_type(typing.Union[tuple(type(arg) for arg in typing.get_args(T))])
177+
type_args = typing.Union[tuple(type(arg) for arg in typing.get_args(T))]
178+
return guess_type(type_args)
175179
elif origin is list or origin is tuple:
176180
return "array"
177181
elif origin is dict:

test/test_guess_type.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,21 @@ def test_union_type():
9898

9999
def test_literal_type():
100100
"""Test literal type"""
101-
assert guess_type(typing.Literal["a"]) == "string"
102-
assert guess_type(typing.Literal[1]) == "integer"
101+
assert guess_type(typing.Literal["a"]) == "string", "should be string"
102+
assert guess_type(typing.Literal[1]) == "integer", "should be integer"
103+
assert guess_type(typing.Literal[1.0]) == "number", "should be number"
103104

104-
assert set(guess_type(typing.Literal["a", 1, None])) == {"string", "integer"}
105+
assert set(guess_type(typing.Literal["a", 1, None])) == {
106+
"string",
107+
"integer",
108+
}, "should be string or integer, omit None"
105109

106110
assert set(guess_type(typing.Literal["a", 1])) == {"string", "integer"}
107-
assert set(guess_type(typing.Literal["a", 1.0])) == {"string", "integer"}
111+
assert set(guess_type(typing.Literal["a", 1.0])) == {"string", "number"}
108112
assert set(guess_type(typing.Literal["a", 1.1])) == {"string", "number"}
109113
assert set(guess_type(typing.Literal["a", 1, 1.0])) == {
110114
"string",
111115
"number",
112-
} # XXX should be ["string", "integer", "number"] ?
116+
}, "should omit integer if number is present"
113117

114118
assert set(guess_type(typing.Literal["a", 1, 1.0, None])) == {"string", "number"}

0 commit comments

Comments
 (0)