Skip to content

Commit 2f4ac25

Browse files
committed
fix: reorder imports and annotate return types
1 parent 525695b commit 2f4ac25

File tree

2 files changed

+97
-105
lines changed

2 files changed

+97
-105
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
1-
import uuid
21
import datetime
2+
import uuid
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal, Any, Annotated, Callable
4+
from typing import Annotated, Any, Callable, Literal, NamedTuple
5+
6+
import numpy as np
57
import pytest
8+
from numpy.typing import NDArray
9+
610
import cocoindex
11+
from cocoindex.convert import (
12+
dump_engine_object,
13+
encode_engine_value,
14+
make_engine_value_decoder,
15+
)
716
from cocoindex.typing import (
8-
encode_enriched_type,
917
Float32,
18+
Float64,
1019
TypeKind,
1120
Vector,
12-
Float32,
13-
Float64,
14-
)
15-
from cocoindex.convert import (
16-
encode_engine_value,
17-
make_engine_value_decoder,
18-
dump_engine_object,
21+
encode_enriched_type,
1922
)
20-
import numpy as np
21-
from numpy.typing import NDArray
2223

2324

2425
@dataclass
@@ -130,7 +131,7 @@ def test_encode_engine_value_date_time_types() -> None:
130131
assert encode_engine_value(dt) == dt
131132

132133

133-
def test_encode_scalar_numpy_values():
134+
def test_encode_scalar_numpy_values() -> None:
134135
"""Test encoding scalar NumPy values to engine-compatible values."""
135136
test_cases = [
136137
(np.int64(42), 42),
@@ -228,7 +229,7 @@ def test_roundtrip_basic_types() -> None:
228229
)
229230

230231

231-
def test_decode_scalar_numpy_values():
232+
def test_decode_scalar_numpy_values() -> None:
232233
test_cases = [
233234
({"kind": "Int64"}, np.int64, 42, np.int64(42)),
234235
({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
@@ -241,29 +242,29 @@ def test_decode_scalar_numpy_values():
241242
assert result == expected
242243

243244

244-
def test_non_ndarray_vector_decoding():
245+
def test_non_ndarray_vector_decoding() -> None:
245246
# Test list[np.float64]
246247
src_type = {
247248
"kind": "Vector",
248249
"element_type": {"kind": "Float64"},
249250
"dimension": None,
250251
}
251-
dst_type = list[np.float64]
252-
decoder = make_engine_value_decoder(["field"], src_type, dst_type)
253-
input_value = [1.0, 2.0, 3.0]
254-
result = decoder(input_value)
252+
dst_type_float = list[np.float64]
253+
decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
254+
input_numbers = [1.0, 2.0, 3.0]
255+
result = decoder(input_numbers)
255256
assert isinstance(result, list)
256257
assert all(isinstance(x, np.float64) for x in result)
257258
assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
258259

259260
# Test list[Uuid]
260261
src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
261-
dst_type = list[uuid.UUID]
262-
decoder = make_engine_value_decoder(["field"], src_type, dst_type)
262+
dst_type_uuid = list[uuid.UUID]
263+
decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
263264
uuid1 = uuid.uuid4()
264265
uuid2 = uuid.uuid4()
265-
input_value = [uuid1.bytes, uuid2.bytes]
266-
result = decoder(input_value)
266+
input_bytes = [uuid1.bytes, uuid2.bytes]
267+
result = decoder(input_bytes)
267268
assert isinstance(result, list)
268269
assert all(isinstance(x, uuid.UUID) for x in result)
269270
assert result == [uuid1, uuid2]
@@ -946,6 +947,6 @@ class MixedStruct:
946947
python_float=3.14,
947948
string="hello, world",
948949
annotated_int=np.int64(42),
949-
annotated_float=np.float32(3.14),
950+
annotated_float=2.0,
950951
)
951952
validate_full_roundtrip(instance, MixedStruct)

python/cocoindex/tests/test_typing.py

Lines changed: 72 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,23 @@
11
import dataclasses
22
import datetime
33
import uuid
4-
from typing import (
5-
Annotated,
6-
List,
7-
Dict,
8-
Literal,
9-
Any,
10-
get_args,
11-
NamedTuple,
12-
)
13-
from collections.abc import Sequence, Mapping
14-
import pytest
4+
from collections.abc import Mapping, Sequence
5+
from typing import Annotated, Any, Dict, List, Literal, NamedTuple, get_args, get_origin
6+
157
import numpy as np
8+
import pytest
169
from numpy.typing import NDArray
1710

1811
from cocoindex.typing import (
19-
analyze_type_info,
20-
Vector,
21-
VectorInfo,
22-
TypeKind,
23-
TypeAttr,
12+
AnalyzedTypeInfo,
2413
Float32,
2514
Float64,
15+
TypeAttr,
16+
TypeKind,
17+
Vector,
18+
VectorInfo,
19+
analyze_type_info,
2620
encode_enriched_type,
27-
AnalyzedTypeInfo,
2821
)
2922

3023

@@ -42,61 +35,57 @@ class SimpleNamedTuple(NamedTuple):
4235
def test_ndarray_float32_no_dim() -> None:
4336
typ = NDArray[np.float32]
4437
result = analyze_type_info(typ)
45-
assert result == AnalyzedTypeInfo(
46-
kind="Vector",
47-
vector_info=VectorInfo(dim=None),
48-
elem_type=Float32,
49-
key_type=None,
50-
struct_type=None,
51-
np_number_type=np.ndarray[tuple[int, ...], np.dtype[np.float32]],
52-
attrs=None,
53-
nullable=False,
54-
)
38+
assert result.kind == "Vector"
39+
assert result.vector_info == VectorInfo(dim=None)
40+
assert result.elem_type == Float32
41+
assert result.key_type is None
42+
assert result.struct_type is None
43+
assert result.nullable is False
44+
assert result.np_number_type is not None
45+
assert get_origin(result.np_number_type) == np.ndarray
46+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
5547

5648

5749
def test_vector_float32_no_dim() -> None:
5850
typ = Vector[np.float32]
5951
result = analyze_type_info(typ)
60-
assert result == AnalyzedTypeInfo(
61-
kind="Vector",
62-
vector_info=VectorInfo(dim=None),
63-
elem_type=Float32,
64-
key_type=None,
65-
struct_type=None,
66-
np_number_type=np.ndarray[tuple[int, ...], np.dtype[np.float32]],
67-
attrs=None,
68-
nullable=False,
69-
)
52+
assert result.kind == "Vector"
53+
assert result.vector_info == VectorInfo(dim=None)
54+
assert result.elem_type == Float32
55+
assert result.key_type is None
56+
assert result.struct_type is None
57+
assert result.nullable is False
58+
assert result.np_number_type is not None
59+
assert get_origin(result.np_number_type) == np.ndarray
60+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
7061

7162

7263
def test_ndarray_float64_with_dim() -> None:
7364
typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
7465
result = analyze_type_info(typ)
75-
assert result == AnalyzedTypeInfo(
76-
kind="Vector",
77-
vector_info=VectorInfo(dim=128),
78-
elem_type=Float64,
79-
key_type=None,
80-
struct_type=None,
81-
np_number_type=np.ndarray[tuple[int, ...], np.dtype[np.float64]],
82-
attrs=None,
83-
nullable=False,
84-
)
66+
assert result.kind == "Vector"
67+
assert result.vector_info == VectorInfo(dim=128)
68+
assert result.elem_type == Float64
69+
assert result.key_type is None
70+
assert result.struct_type is None
71+
assert result.nullable is False
72+
assert result.np_number_type is not None
73+
assert get_origin(result.np_number_type) == np.ndarray
74+
assert get_args(result.np_number_type)[1] == np.dtype[np.float64]
8575

8676

8777
def test_vector_float32_with_dim() -> None:
8878
typ = Vector[np.float32, Literal[384]]
8979
result = analyze_type_info(typ)
90-
assert result == AnalyzedTypeInfo(
91-
kind="Vector",
92-
vector_info=VectorInfo(dim=384),
93-
elem_type=Float32,
94-
key_type=None,
95-
struct_type=None,
96-
np_number_type=np.ndarray[tuple[int, ...], np.dtype[np.float32]],
97-
attrs=None,
98-
nullable=False,
99-
)
80+
assert result.kind == "Vector"
81+
assert result.vector_info == VectorInfo(dim=384)
82+
assert result.elem_type == Float32
83+
assert result.key_type is None
84+
assert result.struct_type is None
85+
assert result.nullable is False
86+
assert result.np_number_type is not None
87+
assert get_origin(result.np_number_type) == np.ndarray
88+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
10089

10190

10291
def test_ndarray_int64_no_dim() -> None:
@@ -105,37 +94,39 @@ def test_ndarray_int64_no_dim() -> None:
10594
assert result.kind == "Vector"
10695
assert result.vector_info == VectorInfo(dim=None)
10796
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
108-
assert not result.nullable
97+
assert result.nullable is False
98+
assert result.np_number_type is not None
99+
assert get_origin(result.np_number_type) == np.ndarray
100+
assert get_args(result.np_number_type)[1] == np.dtype[np.int64]
109101

110102

111103
def test_nullable_ndarray() -> None:
112104
typ = NDArray[np.float32] | None
113105
result = analyze_type_info(typ)
114-
assert result == AnalyzedTypeInfo(
115-
kind="Vector",
116-
vector_info=VectorInfo(dim=None),
117-
elem_type=Float32,
118-
key_type=None,
119-
struct_type=None,
120-
np_number_type=np.ndarray[tuple[int, ...], np.dtype[np.float32]],
121-
attrs=None,
122-
nullable=True,
123-
)
106+
assert result.kind == "Vector"
107+
assert result.vector_info == VectorInfo(dim=None)
108+
assert result.elem_type == Float32
109+
assert result.key_type is None
110+
assert result.struct_type is None
111+
assert result.nullable is True
112+
assert result.np_number_type is not None
113+
assert get_origin(result.np_number_type) == np.ndarray
114+
assert get_args(result.np_number_type)[1] == np.dtype[np.float32]
124115

125116

126-
def test_scalar_numpy_types():
117+
def test_scalar_numpy_types() -> None:
127118
for np_type, expected_kind in [
128119
(np.int64, "Int64"),
129120
(np.float32, "Float32"),
130121
(np.float64, "Float64"),
131122
]:
132123
type_info = analyze_type_info(np_type)
133-
assert (
134-
type_info.kind == expected_kind
135-
), f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
136-
assert (
137-
type_info.np_number_type == np_type
138-
), f"Expected {np_type}, got {type_info.np_number_type}"
124+
assert type_info.kind == expected_kind, (
125+
f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
126+
)
127+
assert type_info.np_number_type == np_type, (
128+
f"Expected {np_type}, got {type_info.np_number_type}"
129+
)
139130
assert type_info.elem_type is None
140131
assert type_info.vector_info is None
141132

@@ -144,7 +135,7 @@ def test_vector_str() -> None:
144135
typ = Vector[str]
145136
result = analyze_type_info(typ)
146137
assert result.kind == "Vector"
147-
assert result.elem_type == str
138+
assert result.elem_type is str
148139
assert result.vector_info == VectorInfo(dim=None)
149140

150141

@@ -160,7 +151,7 @@ def test_non_numpy_vector() -> None:
160151
typ = Vector[float, Literal[3]]
161152
result = analyze_type_info(typ)
162153
assert result.kind == "Vector"
163-
assert result.elem_type == float
154+
assert result.elem_type is float
164155
assert result.vector_info == VectorInfo(dim=3)
165156

166157

@@ -504,16 +495,16 @@ def test_encode_enriched_type_nullable() -> None:
504495
assert result["nullable"] is True
505496

506497

507-
def test_encode_scalar_numpy_types_schema():
498+
def test_encode_scalar_numpy_types_schema() -> None:
508499
for np_type, expected_kind in [
509500
(np.int64, "Int64"),
510501
(np.float32, "Float32"),
511502
(np.float64, "Float64"),
512503
]:
513504
schema = encode_enriched_type(np_type)
514-
assert (
515-
schema["type"]["kind"] == expected_kind
516-
), f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
505+
assert schema["type"]["kind"] == expected_kind, (
506+
f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
507+
)
517508
assert not schema.get("nullable", False)
518509

519510

0 commit comments

Comments
 (0)