Skip to content

Commit 52f5f84

Browse files
committed
refactor: eliminate type errors in pytests and enforce mypy for tests
1 parent 08f549e commit 52f5f84

File tree

6 files changed

+155
-100
lines changed

6 files changed

+155
-100
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,3 @@ dev = ["ruff"]
3535
python_version = "3.11"
3636
strict = true
3737
files = "python/cocoindex"
38-
exclude = "python/cocoindex/tests"

python/cocoindex/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,54 @@
1515
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
1616
from .setting import get_app_namespace
1717
from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json
18+
19+
__all__ = [
20+
# Submodules
21+
"_engine",
22+
"functions",
23+
"sources",
24+
"storages",
25+
"cli",
26+
"utils",
27+
# Auth registry
28+
"AuthEntryReference",
29+
"add_auth_entry",
30+
"ref_auth_entry",
31+
# Flow
32+
"FlowBuilder",
33+
"DataScope",
34+
"DataSlice",
35+
"Flow",
36+
"transform_flow",
37+
"flow_def",
38+
"EvaluateAndDumpOptions",
39+
"GeneratedField",
40+
"update_all_flows_async",
41+
"FlowLiveUpdater",
42+
"FlowLiveUpdaterOptions",
43+
# Lib
44+
"init",
45+
"start_server",
46+
"stop",
47+
"main_fn",
48+
# LLM
49+
"LlmSpec",
50+
"LlmApiType",
51+
# Index
52+
"VectorSimilarityMetric",
53+
"VectorIndexDef",
54+
"IndexOptions",
55+
# Settings
56+
"DatabaseConnectionSpec",
57+
"Settings",
58+
"ServerSettings",
59+
"get_app_namespace",
60+
# Typing
61+
"Float32",
62+
"Float64",
63+
"LocalDateTime",
64+
"OffsetDateTime",
65+
"Range",
66+
"Vector",
67+
"Json",
68+
]

python/cocoindex/tests/test_convert.py

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Tag:
3434

3535
@dataclass
3636
class Basket:
37-
items: list
37+
items: list[str]
3838

3939

4040
@dataclass
@@ -86,7 +86,7 @@ def validate_full_roundtrip(
8686
`other_decoded_values` is a tuple of (value, type) pairs.
8787
If provided, also validate the value can be decoded to the other types.
8888
"""
89-
from cocoindex import _engine
89+
from cocoindex import _engine # type: ignore
9090

9191
encoded_value = encode_engine_value(value)
9292
value_type = value_type or type(value)
@@ -107,19 +107,19 @@ def validate_full_roundtrip(
107107
np.testing.assert_array_equal(other_decoded_value, other_value)
108108

109109

110-
def test_encode_engine_value_basic_types():
110+
def test_encode_engine_value_basic_types() -> None:
111111
assert encode_engine_value(123) == 123
112112
assert encode_engine_value(3.14) == 3.14
113113
assert encode_engine_value("hello") == "hello"
114114
assert encode_engine_value(True) is True
115115

116116

117-
def test_encode_engine_value_uuid():
117+
def test_encode_engine_value_uuid() -> None:
118118
u = uuid.uuid4()
119119
assert encode_engine_value(u) == u.bytes
120120

121121

122-
def test_encode_engine_value_date_time_types():
122+
def test_encode_engine_value_date_time_types() -> None:
123123
d = datetime.date(2024, 1, 1)
124124
assert encode_engine_value(d) == d
125125
t = datetime.time(12, 30)
@@ -128,7 +128,7 @@ def test_encode_engine_value_date_time_types():
128128
assert encode_engine_value(dt) == dt
129129

130130

131-
def test_encode_engine_value_struct():
131+
def test_encode_engine_value_struct() -> None:
132132
order = Order(order_id="O123", name="mixed nuts", price=25.0)
133133
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
134134

@@ -141,7 +141,7 @@ def test_encode_engine_value_struct():
141141
]
142142

143143

144-
def test_encode_engine_value_list_of_structs():
144+
def test_encode_engine_value_list_of_structs() -> None:
145145
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
146146
assert encode_engine_value(orders) == [
147147
["O1", "item1", 10.0, "default_extra"],
@@ -158,12 +158,12 @@ def test_encode_engine_value_list_of_structs():
158158
]
159159

160160

161-
def test_encode_engine_value_struct_with_list():
161+
def test_encode_engine_value_struct_with_list() -> None:
162162
basket = Basket(items=["apple", "banana"])
163163
assert encode_engine_value(basket) == [["apple", "banana"]]
164164

165165

166-
def test_encode_engine_value_nested_struct():
166+
def test_encode_engine_value_nested_struct() -> None:
167167
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
168168
assert encode_engine_value(customer) == [
169169
"Alice",
@@ -181,20 +181,20 @@ def test_encode_engine_value_nested_struct():
181181
]
182182

183183

184-
def test_encode_engine_value_empty_list():
184+
def test_encode_engine_value_empty_list() -> None:
185185
assert encode_engine_value([]) == []
186186
assert encode_engine_value([[]]) == [[]]
187187

188188

189-
def test_encode_engine_value_tuple():
189+
def test_encode_engine_value_tuple() -> None:
190190
assert encode_engine_value(()) == []
191191
assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
192192
assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
193193
assert encode_engine_value(([],)) == [[]]
194194
assert encode_engine_value(((),)) == [[]]
195195

196196

197-
def test_encode_engine_value_none():
197+
def test_encode_engine_value_none() -> None:
198198
assert encode_engine_value(None) is None
199199

200200

@@ -323,18 +323,18 @@ def test_make_engine_value_decoder_basic_types() -> None:
323323
),
324324
],
325325
)
326-
def test_struct_decoder_cases(data_type, engine_val, expected):
326+
def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
327327
decoder = build_engine_value_decoder(data_type)
328328
assert decoder(engine_val) == expected
329329

330330

331-
def test_make_engine_value_decoder_collections():
331+
def test_make_engine_value_decoder_list_of_struct() -> None:
332332
# List of structs (dataclass)
333-
decoder = build_engine_value_decoder(list[Order])
334333
engine_val = [
335334
["O1", "item1", 10.0, "default_extra"],
336335
["O2", "item2", 20.0, "default_extra"],
337336
]
337+
decoder = build_engine_value_decoder(list[Order])
338338
assert decoder(engine_val) == [
339339
Order("O1", "item1", 10.0, "default_extra"),
340340
Order("O2", "item2", 20.0, "default_extra"),
@@ -347,13 +347,15 @@ def test_make_engine_value_decoder_collections():
347347
OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
348348
]
349349

350+
351+
def test_make_engine_value_decoder_struct_of_list() -> None:
350352
# Struct with list field
351-
decoder = build_engine_value_decoder(Customer)
352353
engine_val = [
353354
"Alice",
354355
["O1", "item1", 10.0, "default_extra"],
355356
[["vip"], ["premium"]],
356357
]
358+
decoder = build_engine_value_decoder(Customer)
357359
assert decoder(engine_val) == Customer(
358360
"Alice",
359361
Order("O1", "item1", 10.0, "default_extra"),
@@ -368,8 +370,9 @@ def test_make_engine_value_decoder_collections():
368370
[Tag("vip"), Tag("premium")],
369371
)
370372

373+
374+
def test_make_engine_value_decoder_struct_of_struct() -> None:
371375
# Struct with struct field
372-
decoder = build_engine_value_decoder(NestedStruct)
373376
engine_val = [
374377
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
375378
[
@@ -378,6 +381,7 @@ def test_make_engine_value_decoder_collections():
378381
],
379382
2,
380383
]
384+
decoder = build_engine_value_decoder(NestedStruct)
381385
assert decoder(engine_val) == NestedStruct(
382386
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
383387
[
@@ -388,11 +392,13 @@ def test_make_engine_value_decoder_collections():
388392
)
389393

390394

391-
def make_engine_order(fields):
395+
def make_engine_order(fields: list[tuple[str, type]]) -> type:
392396
return make_dataclass("EngineOrder", fields)
393397

394398

395-
def make_python_order(fields, defaults=None):
399+
def make_python_order(
400+
fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
401+
) -> type:
396402
if defaults is None:
397403
defaults = {}
398404
# Move all fields with defaults to the end (Python dataclass requirement)
@@ -466,8 +472,12 @@ def make_python_order(fields, defaults=None):
466472
],
467473
)
468474
def test_field_position_cases(
469-
engine_fields, python_fields, python_defaults, engine_val, expected_python_val
470-
):
475+
engine_fields: list[tuple[str, type]],
476+
python_fields: list[tuple[str, type]],
477+
python_defaults: dict[str, Any],
478+
engine_val: list[Any],
479+
expected_python_val: tuple[Any, ...],
480+
) -> None:
471481
EngineOrder = make_engine_order(engine_fields)
472482
PythonOrder = make_python_order(python_fields, python_defaults)
473483
decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
@@ -528,9 +538,9 @@ class OrderKey:
528538

529539

530540
def test_vector_as_vector() -> None:
531-
value: IntVectorType = [1, 2, 3, 4, 5]
541+
value = np.array([1, 2, 3, 4, 5], dtype=np.int64)
532542
encoded = encode_engine_value(value)
533-
assert encoded == [1, 2, 3, 4, 5]
543+
assert np.array_equal(encoded, value)
534544
decoded = build_engine_value_decoder(IntVectorType)(encoded)
535545
assert np.array_equal(decoded, value)
536546

@@ -561,7 +571,7 @@ def test_vector_as_list() -> None:
561571
NDArrayInt64Type = NDArray[np.int64]
562572

563573

564-
def test_encode_engine_value_ndarray():
574+
def test_encode_engine_value_ndarray() -> None:
565575
"""Test encoding NDArray vectors to lists for the Rust engine."""
566576
vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
567577
assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
@@ -573,7 +583,7 @@ def test_encode_engine_value_ndarray():
573583
assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
574584

575585

576-
def test_make_engine_value_decoder_ndarray():
586+
def test_make_engine_value_decoder_ndarray() -> None:
577587
"""Test decoding engine lists to NDArray vectors."""
578588
decoder_f32 = build_engine_value_decoder(Float32VectorType)
579589
result_f32 = decoder_f32([1.0, 2.0, 3.0])
@@ -597,16 +607,16 @@ def test_make_engine_value_decoder_ndarray():
597607
assert np.array_equal(result_nd_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
598608

599609

600-
def test_roundtrip_ndarray_vector():
610+
def test_roundtrip_ndarray_vector() -> None:
601611
"""Test roundtrip encoding and decoding of NDArray vectors."""
602-
value_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
612+
value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
603613
encoded_f32 = encode_engine_value(value_f32)
604614
np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
605615
decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
606616
assert isinstance(decoded_f32, np.ndarray)
607617
assert decoded_f32.dtype == np.float32
608618
assert np.array_equal(decoded_f32, value_f32)
609-
value_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
619+
value_i64 = np.array([1, 2, 3], dtype=np.int64)
610620
encoded_i64 = encode_engine_value(value_i64)
611621
assert np.array_equal(encoded_i64, [1, 2, 3])
612622
decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
@@ -622,18 +632,18 @@ def test_roundtrip_ndarray_vector():
622632
assert np.array_equal(decoded_nd_f64, value_nd_f64)
623633

624634

625-
def test_ndarray_dimension_mismatch():
635+
def test_ndarray_dimension_mismatch() -> None:
626636
"""Test dimension enforcement for Vector with specified dimension."""
627-
value: Float32VectorType = np.array([1.0, 2.0], dtype=np.float32)
637+
value = np.array([1.0, 2.0], dtype=np.float32)
628638
encoded = encode_engine_value(value)
629639
assert np.array_equal(encoded, [1.0, 2.0])
630640
with pytest.raises(ValueError, match="Vector dimension mismatch"):
631641
build_engine_value_decoder(Float32VectorType)(encoded)
632642

633643

634-
def test_list_vector_backward_compatibility():
644+
def test_list_vector_backward_compatibility() -> None:
635645
"""Test that list-based vectors still work for backward compatibility."""
636-
value: IntVectorType = [1, 2, 3, 4, 5]
646+
value = [1, 2, 3, 4, 5]
637647
encoded = encode_engine_value(value)
638648
assert encoded == [1, 2, 3, 4, 5]
639649
decoded = build_engine_value_decoder(IntVectorType)(encoded)
@@ -647,7 +657,7 @@ def test_list_vector_backward_compatibility():
647657
assert np.array_equal(decoded, [1, 2, 3, 4, 5])
648658

649659

650-
def test_encode_complex_structure_with_ndarray():
660+
def test_encode_complex_structure_with_ndarray() -> None:
651661
"""Test encoding a complex structure that includes an NDArray."""
652662

653663
@dataclass
@@ -660,17 +670,13 @@ class MyStructWithNDArray:
660670
name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
661671
)
662672
encoded = encode_engine_value(original)
663-
expected = [
664-
"test_np",
665-
[1.0, 0.5],
666-
100,
667-
]
668-
assert encoded[0] == expected[0]
669-
assert np.array_equal(encoded[1], expected[1])
670-
assert encoded[2] == expected[2]
673+
674+
assert encoded[0] == original.name
675+
assert np.array_equal(encoded[1], original.data)
676+
assert encoded[2] == original.value
671677

672678

673-
def test_decode_nullable_ndarray_none_or_value_input():
679+
def test_decode_nullable_ndarray_none_or_value_input() -> None:
674680
"""Test decoding a nullable NDArray with None or value inputs."""
675681
src_type_dict = {
676682
"kind": "Vector",
@@ -694,7 +700,7 @@ def test_decode_nullable_ndarray_none_or_value_input():
694700
)
695701

696702

697-
def test_decode_vector_string():
703+
def test_decode_vector_string() -> None:
698704
"""Test decoding a vector of strings works for Python native list type."""
699705
src_type_dict = {
700706
"kind": "Vector",
@@ -705,7 +711,7 @@ def test_decode_vector_string():
705711
assert decoder(["hello", "world"]) == ["hello", "world"]
706712

707713

708-
def test_decode_error_non_nullable_or_non_list_vector():
714+
def test_decode_error_non_nullable_or_non_list_vector() -> None:
709715
"""Test decoding errors for non-nullable vectors or non-list inputs."""
710716
src_type_dict = {
711717
"kind": "Vector",
@@ -719,7 +725,7 @@ def test_decode_error_non_nullable_or_non_list_vector():
719725
decoder("not a list")
720726

721727

722-
def test_dump_vector_type_annotation_with_dim():
728+
def test_dump_vector_type_annotation_with_dim() -> None:
723729
"""Test dumping a vector type annotation with a specified dimension."""
724730
expected_dump = {
725731
"type": {
@@ -731,7 +737,7 @@ def test_dump_vector_type_annotation_with_dim():
731737
assert dump_engine_object(Float32VectorType) == expected_dump
732738

733739

734-
def test_dump_vector_type_annotation_no_dim():
740+
def test_dump_vector_type_annotation_no_dim() -> None:
735741
"""Test dumping a vector type annotation with no dimension."""
736742
expected_dump_no_dim = {
737743
"type": {

0 commit comments

Comments
 (0)