Skip to content

Commit cc13b76

Browse files
committed
test: enhance validate_full_roundtrip()
1 parent 91b1e84 commit cc13b76

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from cocoindex.typing import (
88
encode_enriched_type,
99
Vector,
10+
Float32,
11+
Float64,
1012
)
1113
from cocoindex.convert import (
1214
encode_engine_value,
@@ -74,25 +76,36 @@ def build_engine_value_decoder(
7476

7577

7678
def validate_full_roundtrip(
77-
value: Any, output_type: Any, input_type: Any | None = None
79+
value: Any,
80+
value_type: Any = None,
81+
*other_decoded_values: tuple[Any, Any],
7882
) -> None:
7983
"""
8084
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
8185
82-
If `input_type` is not specified, uses `output_type` as the target.
86+
`other_decoded_values` is a tuple of (value, type) pairs.
87+
If provided, also validate the value can be decoded to the other types.
8388
"""
8489
from cocoindex import _engine
8590

8691
encoded_value = encode_engine_value(value)
87-
encoded_output_type = encode_enriched_type(output_type)["type"]
92+
value_type = value_type or type(value)
93+
encoded_output_type = encode_enriched_type(value_type)["type"]
8894
value_from_engine = _engine.testutil.seder_roundtrip(
8995
encoded_value, encoded_output_type
9096
)
91-
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
97+
decoded_value = build_engine_value_decoder(value_type, value_type)(
9298
value_from_engine
9399
)
94100
np.testing.assert_array_equal(decoded_value, value)
95101

102+
if other_decoded_values is not None:
103+
for other_value, other_type in other_decoded_values:
104+
other_decoded_value = build_engine_value_decoder(other_type, other_type)(
105+
value_from_engine
106+
)
107+
np.testing.assert_array_equal(other_decoded_value, other_value)
108+
96109

97110
def test_encode_engine_value_basic_types():
98111
assert encode_engine_value(123) == 123
@@ -185,16 +198,14 @@ def test_encode_engine_value_none():
185198
assert encode_engine_value(None) is None
186199

187200

188-
def test_make_engine_value_decoder_basic_types():
189-
for engine_type_in_py, value in [
190-
(int, 42),
191-
(float, 3.14),
192-
(str, "hello"),
193-
(bool, True),
194-
# (type(None), None), # Removed unsupported NoneType
195-
]:
196-
decoder = build_engine_value_decoder(engine_type_in_py)
197-
assert decoder(value) == value
201+
def test_make_engine_value_decoder_basic_types() -> None:
202+
validate_full_roundtrip(42, int)
203+
validate_full_roundtrip(3.25, float, (3.25, Float64))
204+
validate_full_roundtrip(3.25, Float64, (3.25, float))
205+
validate_full_roundtrip(3.25, Float32)
206+
validate_full_roundtrip("hello", str)
207+
validate_full_roundtrip(True, bool)
208+
validate_full_roundtrip(False, bool)
198209

199210

200211
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)