Skip to content

Commit eb8fae5

Browse files
authored
Support ... | None (or Optional[...]) when analyzing type. #19 (#70)
Support `... | None` (or `Optional[...]`) when analyzing type.
1 parent 7b2b20e commit eb8fae5

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

python/cocoindex/typing.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing
22
import collections
33
import dataclasses
4+
import types
45
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING
56

67
class Vector(NamedTuple):
@@ -59,16 +60,31 @@ class AnalyzedTypeInfo:
5960
elem_type: type | None
6061
struct_fields: tuple[dataclasses.Field, ...] | None
6162
attrs: dict[str, Any] | None
63+
nullable: bool = False
6264

6365
def analyze_type_info(t) -> AnalyzedTypeInfo:
6466
"""
6567
Analyze a Python type and return the analyzed info.
6668
"""
6769
annotations: tuple[Annotation, ...] = ()
68-
if typing.get_origin(t) is Annotated:
69-
annotations = t.__metadata__
70-
t = t.__origin__
71-
base_type = typing.get_origin(t)
70+
base_type = None
71+
nullable = False
72+
while True:
73+
base_type = typing.get_origin(t)
74+
if base_type is Annotated:
75+
annotations = t.__metadata__
76+
t = t.__origin__
77+
elif base_type is types.UnionType:
78+
possible_types = typing.get_args(t)
79+
non_none_types = [arg for arg in possible_types if arg not in (None, types.NoneType)]
80+
if len(non_none_types) != 1:
81+
raise ValueError(
82+
f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}")
83+
t = non_none_types[0]
84+
if len(possible_types) > 1:
85+
nullable = True
86+
else:
87+
break
7288

7389
attrs = None
7490
vector_info = None
@@ -118,7 +134,7 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
118134
raise ValueError(f"type unsupported yet: {base_type}")
119135

120136
return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
121-
struct_fields=struct_fields, attrs=attrs)
137+
struct_fields=struct_fields, attrs=attrs, nullable=nullable)
122138

123139
def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
124140
encoded_type: dict[str, Any] = { 'kind': type_info.kind }
@@ -150,9 +166,15 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
150166

151167
def _encode_enriched_type(t) -> dict[str, Any]:
152168
enriched_type_info = analyze_type_info(t)
153-
encoded = {'type': _encode_type(enriched_type_info)}
169+
170+
encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)}
171+
154172
if enriched_type_info.attrs is not None:
155173
encoded['attrs'] = enriched_type_info.attrs
174+
175+
if enriched_type_info.nullable:
176+
encoded['nullable'] = True
177+
156178
return encoded
157179

158180

0 commit comments

Comments
 (0)