|
7 | 7 | from cocoindex.typing import ( |
8 | 8 | encode_enriched_type, |
9 | 9 | Vector, |
| 10 | + Float32, |
| 11 | + Float64, |
10 | 12 | ) |
11 | 13 | from cocoindex.convert import ( |
12 | 14 | encode_engine_value, |
@@ -74,25 +76,36 @@ def build_engine_value_decoder( |
74 | 76 |
|
75 | 77 |
|
76 | 78 | def validate_full_roundtrip( |
77 | | - value: Any, output_type: Any, input_type: Any | None = None |
| 79 | + value: Any, |
| 80 | + value_type: Any = None, |
| 81 | + *other_decoded_values: tuple[Any, Any], |
78 | 82 | ) -> None: |
79 | 83 | """ |
80 | 84 | Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type). |
81 | 85 |
|
82 | | - If `input_type` is not specified, uses `output_type` as the target. |
| 86 | + `other_decoded_values` is a tuple of (value, type) pairs. |
| 87 | + If provided, also validate the value can be decoded to the other types. |
83 | 88 | """ |
84 | 89 | from cocoindex import _engine |
85 | 90 |
|
86 | 91 | encoded_value = encode_engine_value(value) |
87 | | - encoded_output_type = encode_enriched_type(output_type)["type"] |
| 92 | + value_type = value_type or type(value) |
| 93 | + encoded_output_type = encode_enriched_type(value_type)["type"] |
88 | 94 | value_from_engine = _engine.testutil.seder_roundtrip( |
89 | 95 | encoded_value, encoded_output_type |
90 | 96 | ) |
91 | | - decoded_value = build_engine_value_decoder(input_type or output_type, output_type)( |
| 97 | + decoded_value = build_engine_value_decoder(value_type, value_type)( |
92 | 98 | value_from_engine |
93 | 99 | ) |
94 | 100 | np.testing.assert_array_equal(decoded_value, value) |
95 | 101 |
|
| 102 | + if other_decoded_values is not None: |
| 103 | + for other_value, other_type in other_decoded_values: |
| 104 | + other_decoded_value = build_engine_value_decoder(other_type, other_type)( |
| 105 | + value_from_engine |
| 106 | + ) |
| 107 | + np.testing.assert_array_equal(other_decoded_value, other_value) |
| 108 | + |
96 | 109 |
|
97 | 110 | def test_encode_engine_value_basic_types(): |
98 | 111 | assert encode_engine_value(123) == 123 |
@@ -185,16 +198,14 @@ def test_encode_engine_value_none(): |
185 | 198 | assert encode_engine_value(None) is None |
186 | 199 |
|
187 | 200 |
|
188 | | -def test_make_engine_value_decoder_basic_types(): |
189 | | - for engine_type_in_py, value in [ |
190 | | - (int, 42), |
191 | | - (float, 3.14), |
192 | | - (str, "hello"), |
193 | | - (bool, True), |
194 | | - # (type(None), None), # Removed unsupported NoneType |
195 | | - ]: |
196 | | - decoder = build_engine_value_decoder(engine_type_in_py) |
197 | | - assert decoder(value) == value |
| 201 | +def test_make_engine_value_decoder_basic_types() -> None: |
| 202 | + validate_full_roundtrip(42, int) |
| 203 | + validate_full_roundtrip(3.25, float, (3.25, Float64)) |
| 204 | + validate_full_roundtrip(3.25, Float64, (3.25, float)) |
| 205 | + validate_full_roundtrip(3.25, Float32) |
| 206 | + validate_full_roundtrip("hello", str) |
| 207 | + validate_full_roundtrip(True, bool) |
| 208 | + validate_full_roundtrip(False, bool) |
198 | 209 |
|
199 | 210 |
|
200 | 211 | @pytest.mark.parametrize( |
|
0 commit comments