Skip to content

Commit 48a0331

Browse files
authored
test(value-convert): add validate_full_roundtrip testing (#594)
* test(value-convert): add `validate_full_roundtrip` testing * add more comments
1 parent dbe7be4 commit 48a0331

File tree

3 files changed

+60
-47
lines changed

3 files changed

+60
-47
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
import datetime
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal
4+
from typing import NamedTuple, Literal, Any, Callable
55
import pytest
66
import cocoindex
77
from cocoindex.typing import encode_enriched_type
@@ -53,7 +53,9 @@ class CustomerNamedTuple(NamedTuple):
5353
tags: list[Tag] | None = None
5454

5555

56-
def build_engine_value_decoder(engine_type_in_py, python_type=None):
56+
def build_engine_value_decoder(
57+
engine_type_in_py: Any, python_type: Any | None = None
58+
) -> Callable[[Any], Any]:
5759
"""
5860
Helper to build a converter for the given engine-side type (as represented in Python).
5961
If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,6 +64,27 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
6264
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
6365

6466

67+
def validate_full_roundtrip(
68+
value: Any, output_type: Any, input_type: Any | None = None
69+
) -> None:
70+
"""
71+
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
72+
73+
If `input_type` is not specified, uses `output_type` as the target.
74+
"""
75+
from cocoindex import _engine
76+
77+
encoded_value = encode_engine_value(value)
78+
encoded_output_type = encode_enriched_type(output_type)["type"]
79+
value_from_engine = _engine.testutil.seder_roundtrip(
80+
encoded_value, encoded_output_type
81+
)
82+
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
83+
value_from_engine
84+
)
85+
assert decoded_value == value
86+
87+
6588
def test_encode_engine_value_basic_types():
6689
assert encode_engine_value(123) == 123
6790
assert encode_engine_value(3.14) == 3.14
@@ -434,57 +457,33 @@ def test_field_position_cases(
434457
assert decoder(engine_val) == PythonOrder(**expected_dict)
435458

436459

437-
def test_roundtrip_ltable():
460+
def test_roundtrip_ltable() -> None:
438461
t = list[Order]
439462
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
440-
encoded = encode_engine_value(value)
441-
assert encoded == [
442-
["O1", "item1", 10.0, "default_extra"],
443-
["O2", "item2", 20.0, "default_extra"],
444-
]
445-
decoded = build_engine_value_decoder(t)(encoded)
446-
assert decoded == value
463+
validate_full_roundtrip(value, t)
447464

448465
t_nt = list[OrderNamedTuple]
449466
value_nt = [
450467
OrderNamedTuple("O1", "item1", 10.0),
451468
OrderNamedTuple("O2", "item2", 20.0),
452469
]
453-
encoded = encode_engine_value(value_nt)
454-
assert encoded == [
455-
["O1", "item1", 10.0, "default_extra"],
456-
["O2", "item2", 20.0, "default_extra"],
457-
]
458-
decoded = build_engine_value_decoder(t_nt)(encoded)
459-
assert decoded == value_nt
470+
validate_full_roundtrip(value_nt, t_nt)
460471

461472

462-
def test_roundtrip_ktable_str_key():
473+
def test_roundtrip_ktable_str_key() -> None:
463474
t = dict[str, Order]
464475
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
465-
encoded = encode_engine_value(value)
466-
assert encoded == [
467-
["K1", "O1", "item1", 10.0, "default_extra"],
468-
["K2", "O2", "item2", 20.0, "default_extra"],
469-
]
470-
decoded = build_engine_value_decoder(t)(encoded)
471-
assert decoded == value
476+
validate_full_roundtrip(value, t)
472477

473478
t_nt = dict[str, OrderNamedTuple]
474479
value_nt = {
475480
"K1": OrderNamedTuple("O1", "item1", 10.0),
476481
"K2": OrderNamedTuple("O2", "item2", 20.0),
477482
}
478-
encoded = encode_engine_value(value_nt)
479-
assert encoded == [
480-
["K1", "O1", "item1", 10.0, "default_extra"],
481-
["K2", "O2", "item2", 20.0, "default_extra"],
482-
]
483-
decoded = build_engine_value_decoder(t_nt)(encoded)
484-
assert decoded == value_nt
483+
validate_full_roundtrip(value_nt, t_nt)
485484

486485

487-
def test_roundtrip_ktable_struct_key():
486+
def test_roundtrip_ktable_struct_key() -> None:
488487
@dataclass(frozen=True)
489488
class OrderKey:
490489
shop_id: str
@@ -495,26 +494,14 @@ class OrderKey:
495494
OrderKey("A", 3): Order("O1", "item1", 10.0),
496495
OrderKey("B", 4): Order("O2", "item2", 20.0),
497496
}
498-
encoded = encode_engine_value(value)
499-
assert encoded == [
500-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
501-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
502-
]
503-
decoded = build_engine_value_decoder(t)(encoded)
504-
assert decoded == value
497+
validate_full_roundtrip(value, t)
505498

506499
t_nt = dict[OrderKey, OrderNamedTuple]
507500
value_nt = {
508501
OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
509502
OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
510503
}
511-
encoded = encode_engine_value(value_nt)
512-
assert encoded == [
513-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
514-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
515-
]
516-
decoded = build_engine_value_decoder(t_nt)(encoded)
517-
assert decoded == value_nt
504+
validate_full_roundtrip(value_nt, t_nt)
518505

519506

520507
IntVectorType = cocoindex.Vector[int, Literal[5]]

src/base/value.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,3 +1142,13 @@ impl<'a, I: Iterator<Item = &'a Value> + Clone> Serialize for TypedFieldsValue<'
11421142
map.end()
11431143
}
11441144
}
1145+
1146+
pub mod test_util {
1147+
use super::*;
1148+
1149+
pub fn seder_roundtrip(value: &Value, typ: &ValueType) -> Result<Value> {
1150+
let json_value = serde_json::to_value(value)?;
1151+
let roundtrip_value = Value::from_json(json_value, typ)?;
1152+
Ok(roundtrip_value)
1153+
}
1154+
}

src/py/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,18 @@ fn add_auth_entry(key: String, value: Pythonized<serde_json::Value>) -> PyResult
530530
Ok(())
531531
}
532532

533+
#[pyfunction]
534+
fn seder_roundtrip<'py>(
535+
py: Python<'py>,
536+
value: Bound<'py, PyAny>,
537+
typ: Pythonized<ValueType>,
538+
) -> PyResult<Bound<'py, PyAny>> {
539+
let typ = typ.into_inner();
540+
let value = value_from_py_object(&typ, &value)?;
541+
let value = value::test_util::seder_roundtrip(&value, &typ).into_py_result()?;
542+
Ok(value_to_py_object(py, &value)?)
543+
}
544+
533545
/// A Python module implemented in Rust.
534546
#[pymodule]
535547
#[pyo3(name = "_engine")]
@@ -558,5 +570,9 @@ fn cocoindex_engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
558570
m.add_class::<RenderedSpec>()?;
559571
m.add_class::<RenderedSpecLine>()?;
560572

573+
let testutil_module = PyModule::new(m.py(), "testutil")?;
574+
testutil_module.add_function(wrap_pyfunction!(seder_roundtrip, &testutil_module)?)?;
575+
m.add_submodule(&testutil_module)?;
576+
561577
Ok(())
562578
}

0 commit comments

Comments
 (0)