Skip to content

Commit 91b1e84

Browse files
fix: unit tests for python values (#452)
* fix: unit tests for python values * fix: KTable test * chore: cleanup * chore: use PartialEq impl * fix: move back to value_to_py_object * chore: minor cleanup
1 parent 039d924 commit 91b1e84

File tree

3 files changed

+209
-6
lines changed

3 files changed

+209
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ name = "cocoindex_engine"
1515
crate-type = ["cdylib"]
1616

1717
[dependencies]
18-
pyo3 = { version = "0.25.0", features = ["chrono"] }
18+
pyo3 = { version = "0.25.0", features = ["chrono", "auto-initialize"] }
1919
pythonize = "0.25.0"
2020
pyo3-async-runtimes = { version = "0.25.0", features = ["tokio-runtime"] }
2121

src/base/value.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ impl<'de> Deserialize<'de> for RangeValue {
7373
}
7474

7575
/// Value of key.
76-
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
76+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize)]
7777
pub enum KeyValue {
7878
Bytes(Bytes),
7979
Str(Arc<str>),
@@ -362,7 +362,7 @@ impl KeyValue {
362362
}
363363
}
364364

365-
#[derive(Debug, Clone)]
365+
#[derive(Debug, Clone, PartialEq, Deserialize)]
366366
pub enum BasicValue {
367367
Bytes(Bytes),
368368
Str(Arc<str>),
@@ -543,7 +543,7 @@ impl BasicValue {
543543
}
544544
}
545545

546-
#[derive(Debug, Clone, Default)]
546+
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
547547
pub enum Value<VS = ScopeValue> {
548548
#[default]
549549
Null,
@@ -779,7 +779,7 @@ impl<VS> Value<VS> {
779779
}
780780
}
781781

782-
#[derive(Debug, Clone)]
782+
#[derive(Debug, Clone, PartialEq, Deserialize)]
783783
pub struct FieldValues<VS = ScopeValue> {
784784
pub fields: Vec<Value<VS>>,
785785
}
@@ -853,7 +853,7 @@ where
853853
}
854854
}
855855

856-
#[derive(Debug, Clone, Serialize)]
856+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
857857
pub struct ScopeValue(pub FieldValues);
858858

859859
impl Deref for ScopeValue {

src/py/convert.rs

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::sync::Arc;
1414
use super::IntoPyResult;
1515
use crate::base::{schema, value};
1616

17+
#[derive(Debug)]
1718
pub struct Pythonized<T>(pub T);
1819

1920
impl<'py, T: DeserializeOwned> FromPyObject<'py> for Pythonized<T> {
@@ -261,6 +262,7 @@ fn field_values_from_py_object<'py>(
261262
list.len()
262263
)));
263264
}
265+
264266
Ok(value::FieldValues {
265267
fields: schema
266268
.fields
@@ -291,13 +293,15 @@ pub fn value_from_py_object<'py>(
291293
.into_iter()
292294
.map(|v| field_values_from_py_object(&schema.row, &v))
293295
.collect::<PyResult<Vec<_>>>()?;
296+
294297
match schema.kind {
295298
schema::TableKind::UTable => {
296299
value::Value::UTable(values.into_iter().map(|v| v.into()).collect())
297300
}
298301
schema::TableKind::LTable => {
299302
value::Value::LTable(values.into_iter().map(|v| v.into()).collect())
300303
}
304+
301305
schema::TableKind::KTable => value::Value::KTable(
302306
values
303307
.into_iter()
@@ -319,3 +323,202 @@ pub fn value_from_py_object<'py>(
319323
};
320324
Ok(result)
321325
}
326+
327+
#[cfg(test)]
328+
mod tests {
329+
use super::*;
330+
use crate::base::schema;
331+
use crate::base::value;
332+
use crate::base::value::ScopeValue;
333+
use pyo3::Python;
334+
use std::collections::BTreeMap;
335+
use std::sync::Arc;
336+
337+
fn assert_roundtrip_conversion(original_value: &value::Value, value_type: &schema::ValueType) {
338+
Python::with_gil(|py| {
339+
// Convert Rust value to Python object using value_to_py_object
340+
let py_object = value_to_py_object(py, original_value)
341+
.expect("Failed to convert Rust value to Python object");
342+
343+
println!("Python object: {:?}", py_object);
344+
let roundtripped_value =
345+
value_from_py_object(value_type, &py_object)
346+
.expect("Failed to convert Python object back to Rust value");
347+
348+
println!("Roundtripped value: {:?}", roundtripped_value);
349+
assert_eq!(original_value, &roundtripped_value, "Value mismatch after roundtrip");
350+
});
351+
}
352+
353+
#[test]
354+
fn test_roundtrip_basic_values() {
355+
let values_and_types = vec![
356+
(
357+
value::Value::Basic(value::BasicValue::Int64(42)),
358+
schema::ValueType::Basic(schema::BasicValueType::Int64),
359+
),
360+
(
361+
value::Value::Basic(value::BasicValue::Float64(3.14)),
362+
schema::ValueType::Basic(schema::BasicValueType::Float64),
363+
),
364+
(
365+
value::Value::Basic(value::BasicValue::Str(Arc::from("hello"))),
366+
schema::ValueType::Basic(schema::BasicValueType::Str),
367+
),
368+
(
369+
value::Value::Basic(value::BasicValue::Bool(true)),
370+
schema::ValueType::Basic(schema::BasicValueType::Bool),
371+
),
372+
];
373+
374+
for (val, typ) in values_and_types {
375+
assert_roundtrip_conversion(&val, &typ);
376+
}
377+
}
378+
379+
#[test]
380+
fn test_roundtrip_struct() {
381+
let struct_schema = schema::StructSchema {
382+
description: Some(Arc::from("Test struct description")),
383+
fields: Arc::new(vec![
384+
schema::FieldSchema {
385+
name: "a".to_string(),
386+
value_type: schema::EnrichedValueType {
387+
typ: schema::ValueType::Basic(schema::BasicValueType::Int64),
388+
nullable: false,
389+
attrs: Default::default(),
390+
},
391+
},
392+
schema::FieldSchema {
393+
name: "b".to_string(),
394+
value_type: schema::EnrichedValueType {
395+
typ: schema::ValueType::Basic(schema::BasicValueType::Str),
396+
nullable: false,
397+
attrs: Default::default(),
398+
},
399+
},
400+
]),
401+
};
402+
403+
let struct_val_data = value::FieldValues {
404+
fields: vec![
405+
value::Value::Basic(value::BasicValue::Int64(10)),
406+
value::Value::Basic(value::BasicValue::Str(Arc::from("world"))),
407+
],
408+
};
409+
410+
let struct_val = value::Value::Struct(struct_val_data);
411+
let struct_typ = schema::ValueType::Struct(struct_schema); // No clone needed
412+
413+
assert_roundtrip_conversion(&struct_val, &struct_typ);
414+
}
415+
416+
#[test]
417+
fn test_roundtrip_table_types() {
418+
let row_schema_struct = Arc::new(schema::StructSchema {
419+
description: Some(Arc::from("Test table row description")),
420+
fields: Arc::new(vec![
421+
schema::FieldSchema {
422+
name: "key_col".to_string(), // Will be used as key for KTable implicitly
423+
value_type: schema::EnrichedValueType {
424+
typ: schema::ValueType::Basic(schema::BasicValueType::Int64),
425+
nullable: false,
426+
attrs: Default::default(),
427+
},
428+
},
429+
schema::FieldSchema {
430+
name: "data_col_1".to_string(),
431+
value_type: schema::EnrichedValueType {
432+
typ: schema::ValueType::Basic(schema::BasicValueType::Str),
433+
nullable: false,
434+
attrs: Default::default(),
435+
},
436+
},
437+
schema::FieldSchema {
438+
name: "data_col_2".to_string(),
439+
value_type: schema::EnrichedValueType {
440+
typ: schema::ValueType::Basic(schema::BasicValueType::Bool),
441+
nullable: false,
442+
attrs: Default::default(),
443+
},
444+
},
445+
]),
446+
});
447+
448+
let row1_fields = value::FieldValues {
449+
fields: vec![
450+
value::Value::Basic(value::BasicValue::Int64(1)),
451+
value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))),
452+
value::Value::Basic(value::BasicValue::Bool(true)),
453+
],
454+
};
455+
let row1_scope_val: value::ScopeValue = row1_fields.into();
456+
457+
let row2_fields = value::FieldValues {
458+
fields: vec![
459+
value::Value::Basic(value::BasicValue::Int64(2)),
460+
value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))),
461+
value::Value::Basic(value::BasicValue::Bool(false)),
462+
],
463+
};
464+
let row2_scope_val: value::ScopeValue = row2_fields.into();
465+
466+
// UTable
467+
let utable_schema = schema::TableSchema {
468+
kind: schema::TableKind::UTable,
469+
row: (*row_schema_struct).clone(),
470+
};
471+
let utable_val = value::Value::UTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]);
472+
let utable_typ = schema::ValueType::Table(utable_schema);
473+
assert_roundtrip_conversion(&utable_val, &utable_typ);
474+
475+
// LTable
476+
let ltable_schema = schema::TableSchema {
477+
kind: schema::TableKind::LTable,
478+
row: (*row_schema_struct).clone(),
479+
};
480+
let ltable_val = value::Value::LTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]);
481+
let ltable_typ = schema::ValueType::Table(ltable_schema);
482+
assert_roundtrip_conversion(&ltable_val, &ltable_typ);
483+
484+
// KTable
485+
let ktable_schema = schema::TableSchema {
486+
kind: schema::TableKind::KTable,
487+
row: (*row_schema_struct).clone(),
488+
};
489+
let mut ktable_data = BTreeMap::new();
490+
491+
// Create KTable entries where the ScopeValue doesn't include the key field
492+
// This matches how the Python code will serialize/deserialize
493+
let row1_fields = value::FieldValues {
494+
fields: vec![
495+
value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))),
496+
value::Value::Basic(value::BasicValue::Bool(true)),
497+
],
498+
};
499+
let row1_scope_val: value::ScopeValue = row1_fields.into();
500+
501+
let row2_fields = value::FieldValues {
502+
fields: vec![
503+
value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))),
504+
value::Value::Basic(value::BasicValue::Bool(false)),
505+
],
506+
};
507+
let row2_scope_val: value::ScopeValue = row2_fields.into();
508+
509+
// For KTable, the key is extracted from the first field of ScopeValue based on current serialization
510+
let key1 = value::Value::<ScopeValue>::Basic(value::BasicValue::Int64(1))
511+
.into_key()
512+
.unwrap();
513+
let key2 = value::Value::<ScopeValue>::Basic(value::BasicValue::Int64(2))
514+
.into_key()
515+
.unwrap();
516+
517+
ktable_data.insert(key1, row1_scope_val.clone());
518+
ktable_data.insert(key2, row2_scope_val.clone());
519+
520+
let ktable_val = value::Value::KTable(ktable_data);
521+
let ktable_typ = schema::ValueType::Table(ktable_schema);
522+
assert_roundtrip_conversion(&ktable_val, &ktable_typ);
523+
}
524+
}

0 commit comments

Comments
 (0)