|
1 | 1 | import typing |
2 | 2 | import collections |
3 | 3 | import dataclasses |
| 4 | +import types |
4 | 5 | from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING |
5 | 6 |
|
6 | 7 | class Vector(NamedTuple): |
@@ -59,16 +60,31 @@ class AnalyzedTypeInfo: |
59 | 60 | elem_type: type | None |
60 | 61 | struct_fields: tuple[dataclasses.Field, ...] | None |
61 | 62 | attrs: dict[str, Any] | None |
| 63 | + nullable: bool = False |
62 | 64 |
|
63 | 65 | def analyze_type_info(t) -> AnalyzedTypeInfo: |
64 | 66 | """ |
65 | 67 | Analyze a Python type and return the analyzed info. |
66 | 68 | """ |
67 | 69 | annotations: tuple[Annotation, ...] = () |
68 | | - if typing.get_origin(t) is Annotated: |
69 | | - annotations = t.__metadata__ |
70 | | - t = t.__origin__ |
71 | | - base_type = typing.get_origin(t) |
| 70 | + base_type = None |
| 71 | + nullable = False |
| 72 | + while True: |
| 73 | + base_type = typing.get_origin(t) |
| 74 | + if base_type is Annotated: |
| 75 | + annotations = t.__metadata__ |
| 76 | + t = t.__origin__ |
| 77 | + elif base_type is types.UnionType: |
| 78 | + possible_types = typing.get_args(t) |
| 79 | + non_none_types = [arg for arg in possible_types if arg not in (None, types.NoneType)] |
| 80 | + if len(non_none_types) != 1: |
| 81 | + raise ValueError( |
| 82 | + f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}") |
| 83 | + t = non_none_types[0] |
| 84 | + if len(possible_types) > 1: |
| 85 | + nullable = True |
| 86 | + else: |
| 87 | + break |
72 | 88 |
|
73 | 89 | attrs = None |
74 | 90 | vector_info = None |
@@ -118,7 +134,7 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: |
118 | 134 | raise ValueError(f"type unsupported yet: {base_type}") |
119 | 135 |
|
120 | 136 | return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type, |
121 | | - struct_fields=struct_fields, attrs=attrs) |
| 137 | + struct_fields=struct_fields, attrs=attrs, nullable=nullable) |
122 | 138 |
|
123 | 139 | def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: |
124 | 140 | encoded_type: dict[str, Any] = { 'kind': type_info.kind } |
@@ -150,9 +166,15 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: |
150 | 166 |
|
151 | 167 | def _encode_enriched_type(t) -> dict[str, Any]: |
152 | 168 | enriched_type_info = analyze_type_info(t) |
153 | | - encoded = {'type': _encode_type(enriched_type_info)} |
| 169 | + |
| 170 | + encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)} |
| 171 | + |
154 | 172 | if enriched_type_info.attrs is not None: |
155 | 173 | encoded['attrs'] = enriched_type_info.attrs |
| 174 | + |
| 175 | + if enriched_type_info.nullable: |
| 176 | + encoded['nullable'] = True |
| 177 | + |
156 | 178 | return encoded |
157 | 179 |
|
158 | 180 |
|
|
0 commit comments