Skip to content

Commit 2c69ed6

Browse files
authored
fix(convert): stop treating None annotation as Any (#827)
1 parent 3139e18 commit 2c69ed6

File tree

4 files changed

+36
-27
lines changed

4 files changed

+36
-27
lines changed

python/cocoindex/convert.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def decode(value: Any) -> Any | None:
241241
vec_elem_decoder = make_engine_value_decoder(
242242
field_path + ["[*]"],
243243
src_type["element_type"],
244-
analyze_type_info(dst_type_variant and dst_type_variant.elem_type),
244+
analyze_type_info(
245+
dst_type_variant.elem_type if dst_type_variant else Any
246+
),
245247
)
246248

247249
def decode_vector(value: Any) -> Any | None:

python/cocoindex/op.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def create_export_context(
505505
self._mutatation_type.value_type,
506506
)
507507
if self._mutatation_type is not None
508-
else (None, None)
508+
else (Any, Any)
509509
)
510510

511511
key_type_info = analyze_type_info(key_annotation)
@@ -519,10 +519,11 @@ def create_export_context(
519519
["(key)"],
520520
key_fields_schema[0]["type"],
521521
key_type_info,
522+
for_key=True,
522523
)
523524
else:
524525
key_decoder = make_engine_struct_decoder(
525-
["(key)"], key_fields_schema, key_type_info
526+
["(key)"], key_fields_schema, key_type_info, for_key=True
526527
)
527528

528529
value_decoder = make_engine_struct_decoder(

python/cocoindex/tests/test_convert.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import inspect
23
import uuid
34
from dataclasses import dataclass, make_dataclass, field
45
from typing import Annotated, Any, Callable, Literal, NamedTuple
@@ -236,19 +237,24 @@ def test_encode_engine_value_none() -> None:
236237

237238

238239
def test_roundtrip_basic_types() -> None:
239-
validate_full_roundtrip(b"hello world", bytes, (b"hello world", None))
240+
validate_full_roundtrip(
241+
b"hello world",
242+
bytes,
243+
(b"hello world", inspect.Parameter.empty),
244+
(b"hello world", Any),
245+
)
240246
validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)
241-
validate_full_roundtrip("hello", str, ("hello", None))
242-
validate_full_roundtrip(True, bool, (True, None))
243-
validate_full_roundtrip(False, bool, (False, None))
247+
validate_full_roundtrip("hello", str, ("hello", Any))
248+
validate_full_roundtrip(True, bool, (True, Any))
249+
validate_full_roundtrip(False, bool, (False, Any))
244250
validate_full_roundtrip(
245-
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None)
251+
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, Any)
246252
)
247253
validate_full_roundtrip(42, int, (42, cocoindex.Int64))
248254
validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))
249255

250256
validate_full_roundtrip(
251-
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, None)
257+
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, Any)
252258
)
253259
validate_full_roundtrip(3.25, float, (3.25, Float64))
254260
validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))
@@ -260,35 +266,35 @@ def test_roundtrip_basic_types() -> None:
260266
(np.float32(3.25), np.float32),
261267
(np.float64(3.25), np.float64),
262268
(3.25, Float64),
263-
(3.25, None),
269+
(3.25, Any),
264270
)
265271
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
266272

267273

268274
def test_roundtrip_uuid() -> None:
269275
uuid_value = uuid.uuid4()
270-
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
276+
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, Any))
271277

272278

273279
def test_roundtrip_range() -> None:
274280
r1 = (0, 100)
275-
validate_full_roundtrip(r1, cocoindex.Range, (r1, None))
281+
validate_full_roundtrip(r1, cocoindex.Range, (r1, Any))
276282
r2 = (50, 50)
277-
validate_full_roundtrip(r2, cocoindex.Range, (r2, None))
283+
validate_full_roundtrip(r2, cocoindex.Range, (r2, Any))
278284
r3 = (0, 1_000_000_000)
279-
validate_full_roundtrip(r3, cocoindex.Range, (r3, None))
285+
validate_full_roundtrip(r3, cocoindex.Range, (r3, Any))
280286

281287

282288
def test_roundtrip_time() -> None:
283289
t1 = datetime.time(10, 30, 50, 123456)
284-
validate_full_roundtrip(t1, datetime.time, (t1, None))
290+
validate_full_roundtrip(t1, datetime.time, (t1, Any))
285291
t2 = datetime.time(23, 59, 59)
286-
validate_full_roundtrip(t2, datetime.time, (t2, None))
292+
validate_full_roundtrip(t2, datetime.time, (t2, Any))
287293
t3 = datetime.time(0, 0, 0)
288-
validate_full_roundtrip(t3, datetime.time, (t3, None))
294+
validate_full_roundtrip(t3, datetime.time, (t3, Any))
289295

290296
validate_full_roundtrip(
291-
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
297+
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), Any)
292298
)
293299

294300
validate_full_roundtrip(
@@ -333,11 +339,11 @@ def test_roundtrip_timedelta() -> None:
333339
td1 = datetime.timedelta(
334340
days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
335341
)
336-
validate_full_roundtrip(td1, datetime.timedelta, (td1, None))
342+
validate_full_roundtrip(td1, datetime.timedelta, (td1, Any))
337343
td2 = datetime.timedelta(days=-5, hours=-2)
338-
validate_full_roundtrip(td2, datetime.timedelta, (td2, None))
344+
validate_full_roundtrip(td2, datetime.timedelta, (td2, Any))
339345
td3 = datetime.timedelta(0)
340-
validate_full_roundtrip(td3, datetime.timedelta, (td3, None))
346+
validate_full_roundtrip(td3, datetime.timedelta, (td3, Any))
341347

342348

343349
def test_roundtrip_json() -> None:
@@ -1251,8 +1257,8 @@ class Config:
12511257
instance = Config("localhost", 8080, True)
12521258
expected_dict = {"host": "localhost", "port": 8080, "debug": True}
12531259

1254-
# Test None annotation (should be treated as Any)
1255-
validate_full_roundtrip(instance, Config, (expected_dict, None))
1260+
# Test empty annotation (should be treated as Any)
1261+
validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
12561262

12571263

12581264
def test_roundtrip_struct_to_dict_nested() -> None:

python/cocoindex/typing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,23 +262,23 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
262262

263263
if kind is not None:
264264
variant = AnalyzedBasicType(kind=kind)
265-
elif base_type is None or base_type is Any or base_type is inspect.Parameter.empty:
265+
elif base_type is Any or base_type is inspect.Parameter.empty:
266266
variant = AnalyzedAnyType()
267267
elif is_struct_type(base_type):
268268
variant = AnalyzedStructType(struct_type=t)
269269
elif is_numpy_number_type(t):
270270
kind = DtypeRegistry.validate_dtype_and_get_kind(t)
271271
variant = AnalyzedBasicType(kind=kind)
272272
elif base_type is collections.abc.Sequence or base_type is list:
273-
elem_type = type_args[0] if len(type_args) > 0 else None
273+
elem_type = type_args[0] if len(type_args) > 0 else Any
274274
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
275275
elif base_type is np.ndarray:
276276
np_number_type = t
277277
elem_type = extract_ndarray_elem_dtype(np_number_type)
278278
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
279279
elif base_type is collections.abc.Mapping or base_type is dict or t is dict:
280-
key_type = type_args[0] if len(type_args) > 0 else None
281-
elem_type = type_args[1] if len(type_args) > 1 else None
280+
key_type = type_args[0] if len(type_args) > 0 else Any
281+
elem_type = type_args[1] if len(type_args) > 1 else Any
282282
variant = AnalyzedDictType(key_type=key_type, value_type=elem_type)
283283
elif base_type in (types.UnionType, typing.Union):
284284
non_none_types = [arg for arg in type_args if arg not in (None, types.NoneType)]

0 commit comments

Comments
 (0)