Skip to content

Commit 575b0b0

Browse files
committed
test(value-convert): add validate_full_roundtrip testing
1 parent dbe7be4 commit 575b0b0

File tree

3 files changed

+56
-47
lines changed

3 files changed

+56
-47
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
import datetime
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal
4+
from typing import NamedTuple, Literal, Any, Callable
55
import pytest
66
import cocoindex
77
from cocoindex.typing import encode_enriched_type
@@ -53,7 +53,9 @@ class CustomerNamedTuple(NamedTuple):
5353
tags: list[Tag] | None = None
5454

5555

56-
def build_engine_value_decoder(engine_type_in_py, python_type=None):
56+
def build_engine_value_decoder(
57+
engine_type_in_py: Any, python_type: Any | None = None
58+
) -> Callable[[Any], Any]:
5759
"""
5860
Helper to build a converter for the given engine-side type (as represented in Python).
5961
If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,6 +64,22 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
6264
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
6365

6466

67+
def validate_full_roundtrip(
68+
value: Any, output_type: Any, input_type: Any | None = None
69+
) -> None:
70+
from cocoindex import _engine
71+
72+
encoded_value = encode_engine_value(value)
73+
encoded_output_type = encode_enriched_type(output_type)["type"]
74+
value_from_engine = _engine.testutil.seder_roundtrip(
75+
encoded_value, encoded_output_type
76+
)
77+
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
78+
value_from_engine
79+
)
80+
assert decoded_value == value
81+
82+
6583
def test_encode_engine_value_basic_types():
6684
assert encode_engine_value(123) == 123
6785
assert encode_engine_value(3.14) == 3.14
@@ -434,57 +452,33 @@ def test_field_position_cases(
434452
assert decoder(engine_val) == PythonOrder(**expected_dict)
435453

436454

437-
def test_roundtrip_ltable():
455+
def test_roundtrip_ltable() -> None:
438456
t = list[Order]
439457
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
440-
encoded = encode_engine_value(value)
441-
assert encoded == [
442-
["O1", "item1", 10.0, "default_extra"],
443-
["O2", "item2", 20.0, "default_extra"],
444-
]
445-
decoded = build_engine_value_decoder(t)(encoded)
446-
assert decoded == value
458+
validate_full_roundtrip(value, t)
447459

448460
t_nt = list[OrderNamedTuple]
449461
value_nt = [
450462
OrderNamedTuple("O1", "item1", 10.0),
451463
OrderNamedTuple("O2", "item2", 20.0),
452464
]
453-
encoded = encode_engine_value(value_nt)
454-
assert encoded == [
455-
["O1", "item1", 10.0, "default_extra"],
456-
["O2", "item2", 20.0, "default_extra"],
457-
]
458-
decoded = build_engine_value_decoder(t_nt)(encoded)
459-
assert decoded == value_nt
465+
validate_full_roundtrip(value_nt, t_nt)
460466

461467

462-
def test_roundtrip_ktable_str_key():
468+
def test_roundtrip_ktable_str_key() -> None:
463469
t = dict[str, Order]
464470
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
465-
encoded = encode_engine_value(value)
466-
assert encoded == [
467-
["K1", "O1", "item1", 10.0, "default_extra"],
468-
["K2", "O2", "item2", 20.0, "default_extra"],
469-
]
470-
decoded = build_engine_value_decoder(t)(encoded)
471-
assert decoded == value
471+
validate_full_roundtrip(value, t)
472472

473473
t_nt = dict[str, OrderNamedTuple]
474474
value_nt = {
475475
"K1": OrderNamedTuple("O1", "item1", 10.0),
476476
"K2": OrderNamedTuple("O2", "item2", 20.0),
477477
}
478-
encoded = encode_engine_value(value_nt)
479-
assert encoded == [
480-
["K1", "O1", "item1", 10.0, "default_extra"],
481-
["K2", "O2", "item2", 20.0, "default_extra"],
482-
]
483-
decoded = build_engine_value_decoder(t_nt)(encoded)
484-
assert decoded == value_nt
478+
validate_full_roundtrip(value_nt, t_nt)
485479

486480

487-
def test_roundtrip_ktable_struct_key():
481+
def test_roundtrip_ktable_struct_key() -> None:
488482
@dataclass(frozen=True)
489483
class OrderKey:
490484
shop_id: str
@@ -495,26 +489,14 @@ class OrderKey:
495489
OrderKey("A", 3): Order("O1", "item1", 10.0),
496490
OrderKey("B", 4): Order("O2", "item2", 20.0),
497491
}
498-
encoded = encode_engine_value(value)
499-
assert encoded == [
500-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
501-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
502-
]
503-
decoded = build_engine_value_decoder(t)(encoded)
504-
assert decoded == value
492+
validate_full_roundtrip(value, t)
505493

506494
t_nt = dict[OrderKey, OrderNamedTuple]
507495
value_nt = {
508496
OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
509497
OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
510498
}
511-
encoded = encode_engine_value(value_nt)
512-
assert encoded == [
513-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
514-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
515-
]
516-
decoded = build_engine_value_decoder(t_nt)(encoded)
517-
assert decoded == value_nt
499+
validate_full_roundtrip(value_nt, t_nt)
518500

519501

520502
IntVectorType = cocoindex.Vector[int, Literal[5]]

src/base/value.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,3 +1142,13 @@ impl<'a, I: Iterator<Item = &'a Value> + Clone> Serialize for TypedFieldsValue<'
11421142
map.end()
11431143
}
11441144
}
1145+
1146+
pub mod test_util {
1147+
use super::*;
1148+
1149+
pub fn seder_roundtrip(value: &Value, typ: &ValueType) -> Result<Value> {
1150+
let json_value = serde_json::to_value(value)?;
1151+
let roundtrip_value = Value::from_json(json_value, typ)?;
1152+
Ok(roundtrip_value)
1153+
}
1154+
}

src/py/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,19 @@ fn add_auth_entry(key: String, value: Pythonized<serde_json::Value>) -> PyResult
530530
Ok(())
531531
}
532532

533+
#[pyfunction]
534+
fn seder_roundtrip<'py>(
535+
py: Python<'py>,
536+
value: Bound<'py, PyAny>,
537+
typ: Pythonized<ValueType>,
538+
) -> PyResult<Bound<'py, PyAny>> {
539+
unimplemented!()
540+
let typ = typ.into_inner();
541+
let value = value_from_py_object(&typ, &value)?;
542+
let value = value::test_util::seder_roundtrip(&value, &typ).into_py_result()?;
543+
Ok(value_to_py_object(py, &value)?)
544+
}
545+
533546
/// A Python module implemented in Rust.
534547
#[pymodule]
535548
#[pyo3(name = "_engine")]
@@ -558,5 +571,9 @@ fn cocoindex_engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
558571
m.add_class::<RenderedSpec>()?;
559572
m.add_class::<RenderedSpecLine>()?;
560573

574+
let testutil_module = PyModule::new(m.py(), "testutil")?;
575+
testutil_module.add_function(wrap_pyfunction!(seder_roundtrip, &testutil_module)?)?;
576+
m.add_submodule(&testutil_module)?;
577+
561578
Ok(())
562579
}

0 commit comments

Comments
 (0)