Skip to content

Commit c91a51f

Browse files
committed
fix(convert): make type hints required in engine value encoding
1 parent c25dbbe commit c91a51f

File tree

2 files changed

+50
-42
lines changed

2 files changed

+50
-42
lines changed

python/cocoindex/convert.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _encode_engine_value_core(
177177
return value
178178

179179

180-
def encode_engine_value(value: Any, type_hint: Type[Any] | str | None = None) -> Any:
180+
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
181181
"""
182182
Encode a Python value to an engine value.
183183
@@ -188,9 +188,6 @@ def encode_engine_value(value: Any, type_hint: Type[Any] | str | None = None) ->
188188
Returns:
189189
The encoded engine value
190190
"""
191-
if type_hint is None:
192-
return _encode_engine_value_core(value)
193-
194191
# Analyze type once and reuse the result
195192
type_info = _get_cached_type_info(type_hint, {})
196193
if isinstance(type_info.variant, AnalyzedUnknownType):

python/cocoindex/tests/test_convert.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -133,24 +133,24 @@ def validate_full_roundtrip(
133133

134134

135135
def test_encode_engine_value_basic_types() -> None:
136-
assert encode_engine_value(123) == 123
137-
assert encode_engine_value(3.14) == 3.14
138-
assert encode_engine_value("hello") == "hello"
139-
assert encode_engine_value(True) is True
136+
assert encode_engine_value(123, int) == 123
137+
assert encode_engine_value(3.14, float) == 3.14
138+
assert encode_engine_value("hello", str) == "hello"
139+
assert encode_engine_value(True, bool) is True
140140

141141

142142
def test_encode_engine_value_uuid() -> None:
143143
u = uuid.uuid4()
144-
assert encode_engine_value(u) == u
144+
assert encode_engine_value(u, uuid.UUID) == u
145145

146146

147147
def test_encode_engine_value_date_time_types() -> None:
148148
d = datetime.date(2024, 1, 1)
149-
assert encode_engine_value(d) == d
149+
assert encode_engine_value(d, datetime.date) == d
150150
t = datetime.time(12, 30)
151-
assert encode_engine_value(t) == t
151+
assert encode_engine_value(t, datetime.time) == t
152152
dt = datetime.datetime(2024, 1, 1, 12, 30)
153-
assert encode_engine_value(dt) == dt
153+
assert encode_engine_value(dt, datetime.datetime) == dt
154154

155155

156156
def test_encode_scalar_numpy_values() -> None:
@@ -161,17 +161,22 @@ def test_encode_scalar_numpy_values() -> None:
161161
(np.float64(2.718), pytest.approx(2.718)),
162162
]
163163
for np_value, expected in test_cases:
164-
encoded = encode_engine_value(np_value)
164+
encoded = encode_engine_value(np_value, type(np_value))
165165
assert encoded == expected
166166
assert isinstance(encoded, (int, float))
167167

168168

169169
def test_encode_engine_value_struct() -> None:
170170
order = Order(order_id="O123", name="mixed nuts", price=25.0)
171-
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
171+
assert encode_engine_value(order, Order) == [
172+
"O123",
173+
"mixed nuts",
174+
25.0,
175+
"default_extra",
176+
]
172177

173178
order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
174-
assert encode_engine_value(order_nt) == [
179+
assert encode_engine_value(order_nt, OrderNamedTuple) == [
175180
"O123",
176181
"mixed nuts",
177182
25.0,
@@ -181,7 +186,7 @@ def test_encode_engine_value_struct() -> None:
181186

182187
def test_encode_engine_value_list_of_structs() -> None:
183188
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
184-
assert encode_engine_value(orders) == [
189+
assert encode_engine_value(orders, list[Order]) == [
185190
["O1", "item1", 10.0, "default_extra"],
186191
["O2", "item2", 20.0, "default_extra"],
187192
]
@@ -190,20 +195,20 @@ def test_encode_engine_value_list_of_structs() -> None:
190195
OrderNamedTuple("O1", "item1", 10.0),
191196
OrderNamedTuple("O2", "item2", 20.0),
192197
]
193-
assert encode_engine_value(orders_nt) == [
198+
assert encode_engine_value(orders_nt, list[OrderNamedTuple]) == [
194199
["O1", "item1", 10.0, "default_extra"],
195200
["O2", "item2", 20.0, "default_extra"],
196201
]
197202

198203

199204
def test_encode_engine_value_struct_with_list() -> None:
200205
basket = Basket(items=["apple", "banana"])
201-
assert encode_engine_value(basket) == [["apple", "banana"]]
206+
assert encode_engine_value(basket, Basket) == [["apple", "banana"]]
202207

203208

204209
def test_encode_engine_value_nested_struct() -> None:
205210
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
206-
assert encode_engine_value(customer) == [
211+
assert encode_engine_value(customer, Customer) == [
207212
"Alice",
208213
["O1", "item1", 10.0, "default_extra"],
209214
None,
@@ -212,28 +217,28 @@ def test_encode_engine_value_nested_struct() -> None:
212217
customer_nt = CustomerNamedTuple(
213218
name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)
214219
)
215-
assert encode_engine_value(customer_nt) == [
220+
assert encode_engine_value(customer_nt, CustomerNamedTuple) == [
216221
"Alice",
217222
["O1", "item1", 10.0, "default_extra"],
218223
None,
219224
]
220225

221226

222227
def test_encode_engine_value_empty_list() -> None:
223-
assert encode_engine_value([]) == []
224-
assert encode_engine_value([[]]) == [[]]
228+
assert encode_engine_value([], list) == []
229+
assert encode_engine_value([[]], list[list]) == [[]]
225230

226231

227232
def test_encode_engine_value_tuple() -> None:
228-
assert encode_engine_value(()) == []
229-
assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
230-
assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
231-
assert encode_engine_value(([],)) == [[]]
232-
assert encode_engine_value(((),)) == [[]]
233+
assert encode_engine_value((), Any) == []
234+
assert encode_engine_value((1, 2, 3), Any) == [1, 2, 3]
235+
assert encode_engine_value(((1, 2), (3, 4)), Any) == [[1, 2], [3, 4]]
236+
assert encode_engine_value(([],), Any) == [[]]
237+
assert encode_engine_value(((),), Any) == [[]]
233238

234239

235240
def test_encode_engine_value_none() -> None:
236-
assert encode_engine_value(None) is None
241+
assert encode_engine_value(None, int | None) is None
237242

238243

239244
def test_roundtrip_basic_types() -> None:
@@ -743,7 +748,7 @@ class OrderKey:
743748

744749
def test_vector_as_vector() -> None:
745750
value = np.array([1, 2, 3, 4, 5], dtype=np.int64)
746-
encoded = encode_engine_value(value)
751+
encoded = encode_engine_value(value, IntVectorType)
747752
assert np.array_equal(encoded, value)
748753
decoded = build_engine_value_decoder(IntVectorType)(encoded)
749754
assert np.array_equal(decoded, value)
@@ -754,7 +759,7 @@ def test_vector_as_vector() -> None:
754759

755760
def test_vector_as_list() -> None:
756761
value: ListIntType = [1, 2, 3, 4, 5]
757-
encoded = encode_engine_value(value)
762+
encoded = encode_engine_value(value, ListIntType)
758763
assert encoded == [1, 2, 3, 4, 5]
759764
decoded = build_engine_value_decoder(ListIntType)(encoded)
760765
assert np.array_equal(decoded, value)
@@ -772,13 +777,19 @@ def test_vector_as_list() -> None:
772777
def test_encode_engine_value_ndarray() -> None:
773778
"""Test encoding NDArray vectors to lists for the Rust engine."""
774779
vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
775-
assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
780+
assert np.array_equal(
781+
encode_engine_value(vec_f32, Float32VectorType), [1.0, 2.0, 3.0]
782+
)
776783
vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
777-
assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
784+
assert np.array_equal(
785+
encode_engine_value(vec_f64, Float64VectorType), [1.0, 2.0, 3.0]
786+
)
778787
vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
779-
assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
788+
assert np.array_equal(encode_engine_value(vec_i64, Int64VectorType), [1, 2, 3])
780789
vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
781-
assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
790+
assert np.array_equal(
791+
encode_engine_value(vec_nd_f32, NDArrayFloat32Type), [1.0, 2.0, 3.0]
792+
)
782793

783794

784795
def test_make_engine_value_decoder_ndarray() -> None:
@@ -808,21 +819,21 @@ def test_make_engine_value_decoder_ndarray() -> None:
808819
def test_roundtrip_ndarray_vector() -> None:
809820
"""Test roundtrip encoding and decoding of NDArray vectors."""
810821
value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
811-
encoded_f32 = encode_engine_value(value_f32)
822+
encoded_f32 = encode_engine_value(value_f32, Float32VectorType)
812823
np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
813824
decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
814825
assert isinstance(decoded_f32, np.ndarray)
815826
assert decoded_f32.dtype == np.float32
816827
assert np.array_equal(decoded_f32, value_f32)
817828
value_i64 = np.array([1, 2, 3], dtype=np.int64)
818-
encoded_i64 = encode_engine_value(value_i64)
829+
encoded_i64 = encode_engine_value(value_i64, Int64VectorType)
819830
assert np.array_equal(encoded_i64, [1, 2, 3])
820831
decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
821832
assert isinstance(decoded_i64, np.ndarray)
822833
assert decoded_i64.dtype == np.int64
823834
assert np.array_equal(decoded_i64, value_i64)
824835
value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
825-
encoded_nd_f64 = encode_engine_value(value_nd_f64)
836+
encoded_nd_f64 = encode_engine_value(value_nd_f64, NDArrayFloat64Type)
826837
assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
827838
decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
828839
assert isinstance(decoded_nd_f64, np.ndarray)
@@ -833,7 +844,7 @@ def test_roundtrip_ndarray_vector() -> None:
833844
def test_ndarray_dimension_mismatch() -> None:
834845
"""Test dimension enforcement for Vector with specified dimension."""
835846
value = np.array([1.0, 2.0], dtype=np.float32)
836-
encoded = encode_engine_value(value)
847+
encoded = encode_engine_value(value, NDArray[np.float32])
837848
assert np.array_equal(encoded, [1.0, 2.0])
838849
with pytest.raises(ValueError, match="Vector dimension mismatch"):
839850
build_engine_value_decoder(Float32VectorType)(encoded)
@@ -842,14 +853,14 @@ def test_ndarray_dimension_mismatch() -> None:
842853
def test_list_vector_backward_compatibility() -> None:
843854
"""Test that list-based vectors still work for backward compatibility."""
844855
value = [1, 2, 3, 4, 5]
845-
encoded = encode_engine_value(value)
856+
encoded = encode_engine_value(value, list[int])
846857
assert encoded == [1, 2, 3, 4, 5]
847858
decoded = build_engine_value_decoder(IntVectorType)(encoded)
848859
assert isinstance(decoded, np.ndarray)
849860
assert decoded.dtype == np.int64
850861
assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
851862
value_list: ListIntType = [1, 2, 3, 4, 5]
852-
encoded = encode_engine_value(value_list)
863+
encoded = encode_engine_value(value_list, ListIntType)
853864
assert np.array_equal(encoded, [1, 2, 3, 4, 5])
854865
decoded = build_engine_value_decoder(ListIntType)(encoded)
855866
assert np.array_equal(decoded, [1, 2, 3, 4, 5])
@@ -867,7 +878,7 @@ class MyStructWithNDArray:
867878
original = MyStructWithNDArray(
868879
name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
869880
)
870-
encoded = encode_engine_value(original)
881+
encoded = encode_engine_value(original, MyStructWithNDArray)
871882

872883
assert encoded[0] == original.name
873884
assert np.array_equal(encoded[1], original.data)
@@ -1026,7 +1037,7 @@ def test_full_roundtrip_vector_of_vector() -> None:
10261037
),
10271038
(
10281039
value_f32,
1029-
np.typing.NDArray[np.typing.NDArray[np.float32]],
1040+
np.typing.NDArray[np.float32],
10301041
),
10311042
)
10321043

0 commit comments

Comments
 (0)