diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 0c0f9f045..342df98b8 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -9,7 +9,7 @@ import inspect import warnings from enum import Enum -from typing import Any, Callable, Mapping, Type, get_origin +from typing import Any, Callable, Mapping, Sequence, Type, get_origin import numpy as np @@ -170,6 +170,37 @@ def encode_basic_value(value: Any) -> Any: return encode_basic_value +def make_engine_key_decoder( + field_path: list[str], + key_fields_schema: list[dict[str, Any]], + dst_type_info: AnalyzedTypeInfo, +) -> Callable[[Any], Any]: + """ + Create a decoder closure for a key type. + """ + if len(key_fields_schema) == 1 and isinstance( + dst_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType) + ): + single_key_decoder = make_engine_value_decoder( + field_path, + key_fields_schema[0]["type"], + dst_type_info, + for_key=True, + ) + + def key_decoder(value: list[Any]) -> Any: + return single_key_decoder(value[0]) + + return key_decoder + + return make_engine_struct_decoder( + field_path, + key_fields_schema, + dst_type_info, + for_key=True, + ) + + def make_engine_value_decoder( field_path: list[str], src_type: dict[str, Any], @@ -244,31 +275,11 @@ def decode(value: Any) -> Any | None: ) num_key_parts = src_type.get("num_key_parts", 1) - key_type_info = analyze_type_info(key_type) - key_decoder: Callable[..., Any] | None = None - if ( - isinstance( - key_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType) - ) - and num_key_parts == 1 - ): - single_key_decoder = make_engine_value_decoder( - field_path, - engine_fields_schema[0]["type"], - key_type_info, - for_key=True, - ) - - def key_decoder(value: list[Any]) -> Any: - return single_key_decoder(value[0]) - - else: - key_decoder = make_engine_struct_decoder( - field_path, - engine_fields_schema[0:num_key_parts], - key_type_info, - for_key=True, - ) + key_decoder = make_engine_key_decoder( + field_path, 
engine_fields_schema[0:num_key_parts], + analyze_type_info(key_type), + ) value_decoder = make_engine_struct_decoder( field_path, engine_fields_schema[num_key_parts:], diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 32afe993c..4bf65aa17 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -21,6 +21,7 @@ from .convert import ( make_engine_value_encoder, make_engine_value_decoder, + make_engine_key_decoder, make_engine_struct_decoder, ) from .typing import ( @@ -29,7 +30,6 @@ resolve_forward_ref, analyze_type_info, AnalyzedAnyType, - AnalyzedBasicType, AnalyzedDictType, ) @@ -532,24 +532,9 @@ def create_export_context( else (Any, Any) ) - key_type_info = analyze_type_info(key_annotation) - if ( - len(key_fields_schema) == 1 - and key_fields_schema[0]["type"]["kind"] != "Struct" - and isinstance(key_type_info.variant, (AnalyzedAnyType, AnalyzedBasicType)) - ): - # Special case for ease of use: single key column can be mapped to a basic type without the wrapper struct. - key_decoder = make_engine_value_decoder( - ["(key)"], - key_fields_schema[0]["type"], - key_type_info, - for_key=True, - ) - else: - key_decoder = make_engine_struct_decoder( - ["(key)"], key_fields_schema, key_type_info, for_key=True - ) - + key_decoder = make_engine_key_decoder( + ["(key)"], key_fields_schema, analyze_type_info(key_annotation) + ) value_decoder = make_engine_struct_decoder( ["(value)"], value_fields_schema, analyze_type_info(value_annotation) ) diff --git a/src/base/value.rs b/src/base/value.rs index 4930b2655..2fad98c5a 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -1,6 +1,5 @@ use super::schema::*; use crate::base::duration::parse_duration; -use crate::prelude::invariance_violation; use crate::{api_bail, api_error}; use anyhow::Result; use base64::prelude::*; @@ -82,7 +81,7 @@ impl<'de> Deserialize<'de> for RangeValue { /// Value of key. 
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize)] -pub enum KeyValue { +pub enum KeyPart { Bytes(Bytes), Str(Arc), Bool(bool), @@ -90,86 +89,86 @@ pub enum KeyValue { Range(RangeValue), Uuid(uuid::Uuid), Date(chrono::NaiveDate), - Struct(Vec), + Struct(Vec), } -impl From for KeyValue { +impl From for KeyPart { fn from(value: Bytes) -> Self { - KeyValue::Bytes(value) + KeyPart::Bytes(value) } } -impl From> for KeyValue { +impl From> for KeyPart { fn from(value: Vec) -> Self { - KeyValue::Bytes(Bytes::from(value)) + KeyPart::Bytes(Bytes::from(value)) } } -impl From> for KeyValue { +impl From> for KeyPart { fn from(value: Arc) -> Self { - KeyValue::Str(value) + KeyPart::Str(value) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: String) -> Self { - KeyValue::Str(Arc::from(value)) + KeyPart::Str(Arc::from(value)) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: bool) -> Self { - KeyValue::Bool(value) + KeyPart::Bool(value) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: i64) -> Self { - KeyValue::Int64(value) + KeyPart::Int64(value) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: RangeValue) -> Self { - KeyValue::Range(value) + KeyPart::Range(value) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: uuid::Uuid) -> Self { - KeyValue::Uuid(value) + KeyPart::Uuid(value) } } -impl From for KeyValue { +impl From for KeyPart { fn from(value: chrono::NaiveDate) -> Self { - KeyValue::Date(value) + KeyPart::Date(value) } } -impl From> for KeyValue { - fn from(value: Vec) -> Self { - KeyValue::Struct(value) +impl From> for KeyPart { + fn from(value: Vec) -> Self { + KeyPart::Struct(value) } } -impl serde::Serialize for KeyValue { +impl serde::Serialize for KeyPart { fn serialize(&self, serializer: S) -> Result { Value::from(self.clone()).serialize(serializer) } } -impl std::fmt::Display for KeyValue { +impl std::fmt::Display for KeyPart { fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - KeyValue::Bytes(v) => write!(f, "{}", BASE64_STANDARD.encode(v)), - KeyValue::Str(v) => write!(f, "\"{}\"", v.escape_default()), - KeyValue::Bool(v) => write!(f, "{v}"), - KeyValue::Int64(v) => write!(f, "{v}"), - KeyValue::Range(v) => write!(f, "[{}, {})", v.start, v.end), - KeyValue::Uuid(v) => write!(f, "{v}"), - KeyValue::Date(v) => write!(f, "{v}"), - KeyValue::Struct(v) => { + KeyPart::Bytes(v) => write!(f, "{}", BASE64_STANDARD.encode(v)), + KeyPart::Str(v) => write!(f, "\"{}\"", v.escape_default()), + KeyPart::Bool(v) => write!(f, "{v}"), + KeyPart::Int64(v) => write!(f, "{v}"), + KeyPart::Range(v) => write!(f, "[{}, {})", v.start, v.end), + KeyPart::Uuid(v) => write!(f, "{v}"), + KeyPart::Date(v) => write!(f, "{v}"), + KeyPart::Struct(v) => { write!( f, "[{}]", @@ -183,50 +182,7 @@ impl std::fmt::Display for KeyValue { } } -impl KeyValue { - /// For export purpose only for now. Will remove after switching export to using FullKeyValue. - pub fn from_json_for_export( - value: serde_json::Value, - fields_schema: &[FieldSchema], - ) -> Result { - let value = if fields_schema.len() == 1 { - Value::from_json(value, &fields_schema[0].value_type.typ)? - } else { - let field_values: FieldValues = FieldValues::from_json(value, fields_schema)?; - Value::Struct(field_values) - }; - value.as_key() - } - - /// For export purpose only for now. Will remove after switching export to using FullKeyValue. - pub fn from_values_for_export<'a>( - values: impl ExactSizeIterator, - ) -> Result { - let key = if values.len() == 1 { - let mut values = values; - values.next().ok_or_else(invariance_violation)?.as_key()? - } else { - KeyValue::Struct(values.map(|v| v.as_key()).collect::>>()?) - }; - Ok(key) - } - - /// For export purpose only for now. Will remove after switching export to using FullKeyValue. 
- pub fn fields_iter_for_export( - &self, - num_fields: usize, - ) -> Result> { - let slice = if num_fields == 1 { - std::slice::from_ref(self) - } else { - match self { - KeyValue::Struct(v) => v, - _ => api_bail!("Invalid key value type"), - } - }; - Ok(slice.iter()) - } - +impl KeyPart { fn parts_from_str( values_iter: &mut impl Iterator, schema: &ValueType, @@ -238,29 +194,29 @@ impl KeyValue { .ok_or_else(|| api_error!("Key parts less than expected"))?; match basic_type { BasicValueType::Bytes => { - KeyValue::Bytes(Bytes::from(BASE64_STANDARD.decode(v)?)) + KeyPart::Bytes(Bytes::from(BASE64_STANDARD.decode(v)?)) } - BasicValueType::Str => KeyValue::Str(Arc::from(v)), - BasicValueType::Bool => KeyValue::Bool(v.parse()?), - BasicValueType::Int64 => KeyValue::Int64(v.parse()?), + BasicValueType::Str => KeyPart::Str(Arc::from(v)), + BasicValueType::Bool => KeyPart::Bool(v.parse()?), + BasicValueType::Int64 => KeyPart::Int64(v.parse()?), BasicValueType::Range => { let v2 = values_iter .next() .ok_or_else(|| api_error!("Key parts less than expected"))?; - KeyValue::Range(RangeValue { + KeyPart::Range(RangeValue { start: v.parse()?, end: v2.parse()?, }) } - BasicValueType::Uuid => KeyValue::Uuid(v.parse()?), - BasicValueType::Date => KeyValue::Date(v.parse()?), + BasicValueType::Uuid => KeyPart::Uuid(v.parse()?), + BasicValueType::Date => KeyPart::Date(v.parse()?), schema => api_bail!("Invalid key type {schema}"), } } - ValueType::Struct(s) => KeyValue::Struct( + ValueType::Struct(s) => KeyPart::Struct( s.fields .iter() - .map(|f| KeyValue::parts_from_str(values_iter, &f.value_type.typ)) + .map(|f| KeyPart::parts_from_str(values_iter, &f.value_type.typ)) .collect::>>()?, ), _ => api_bail!("Invalid key type {schema}"), @@ -270,17 +226,17 @@ impl KeyValue { fn parts_to_strs(&self, output: &mut Vec) { match self { - KeyValue::Bytes(v) => output.push(BASE64_STANDARD.encode(v)), - KeyValue::Str(v) => output.push(v.to_string()), - KeyValue::Bool(v) => 
output.push(v.to_string()), - KeyValue::Int64(v) => output.push(v.to_string()), - KeyValue::Range(v) => { + KeyPart::Bytes(v) => output.push(BASE64_STANDARD.encode(v)), + KeyPart::Str(v) => output.push(v.to_string()), + KeyPart::Bool(v) => output.push(v.to_string()), + KeyPart::Int64(v) => output.push(v.to_string()), + KeyPart::Range(v) => { output.push(v.start.to_string()); output.push(v.end.to_string()); } - KeyValue::Uuid(v) => output.push(v.to_string()), - KeyValue::Date(v) => output.push(v.to_string()), - KeyValue::Struct(v) => { + KeyPart::Uuid(v) => output.push(v.to_string()), + KeyPart::Date(v) => output.push(v.to_string()), + KeyPart::Struct(v) => { for part in v { part.parts_to_strs(output); } @@ -305,136 +261,136 @@ impl KeyValue { pub fn kind_str(&self) -> &'static str { match self { - KeyValue::Bytes(_) => "bytes", - KeyValue::Str(_) => "str", - KeyValue::Bool(_) => "bool", - KeyValue::Int64(_) => "int64", - KeyValue::Range { .. } => "range", - KeyValue::Uuid(_) => "uuid", - KeyValue::Date(_) => "date", - KeyValue::Struct(_) => "struct", + KeyPart::Bytes(_) => "bytes", + KeyPart::Str(_) => "str", + KeyPart::Bool(_) => "bool", + KeyPart::Int64(_) => "int64", + KeyPart::Range { .. 
} => "range", + KeyPart::Uuid(_) => "uuid", + KeyPart::Date(_) => "date", + KeyPart::Struct(_) => "struct", } } pub fn bytes_value(&self) -> Result<&Bytes> { match self { - KeyValue::Bytes(v) => Ok(v), + KeyPart::Bytes(v) => Ok(v), _ => anyhow::bail!("expected bytes value, but got {}", self.kind_str()), } } pub fn str_value(&self) -> Result<&Arc> { match self { - KeyValue::Str(v) => Ok(v), + KeyPart::Str(v) => Ok(v), _ => anyhow::bail!("expected str value, but got {}", self.kind_str()), } } pub fn bool_value(&self) -> Result { match self { - KeyValue::Bool(v) => Ok(*v), + KeyPart::Bool(v) => Ok(*v), _ => anyhow::bail!("expected bool value, but got {}", self.kind_str()), } } pub fn int64_value(&self) -> Result { match self { - KeyValue::Int64(v) => Ok(*v), + KeyPart::Int64(v) => Ok(*v), _ => anyhow::bail!("expected int64 value, but got {}", self.kind_str()), } } pub fn range_value(&self) -> Result { match self { - KeyValue::Range(v) => Ok(*v), + KeyPart::Range(v) => Ok(*v), _ => anyhow::bail!("expected range value, but got {}", self.kind_str()), } } pub fn uuid_value(&self) -> Result { match self { - KeyValue::Uuid(v) => Ok(*v), + KeyPart::Uuid(v) => Ok(*v), _ => anyhow::bail!("expected uuid value, but got {}", self.kind_str()), } } pub fn date_value(&self) -> Result { match self { - KeyValue::Date(v) => Ok(*v), + KeyPart::Date(v) => Ok(*v), _ => anyhow::bail!("expected date value, but got {}", self.kind_str()), } } - pub fn struct_value(&self) -> Result<&Vec> { + pub fn struct_value(&self) -> Result<&Vec> { match self { - KeyValue::Struct(v) => Ok(v), + KeyPart::Struct(v) => Ok(v), _ => anyhow::bail!("expected struct value, but got {}", self.kind_str()), } } pub fn num_parts(&self) -> usize { match self { - KeyValue::Range(_) => 2, - KeyValue::Struct(v) => v.iter().map(|v| v.num_parts()).sum(), + KeyPart::Range(_) => 2, + KeyPart::Struct(v) => v.iter().map(|v| v.num_parts()).sum(), _ => 1, } } fn estimated_detached_byte_size(&self) -> usize { match self { - 
KeyValue::Bytes(v) => v.len(), - KeyValue::Str(v) => v.len(), - KeyValue::Struct(v) => { + KeyPart::Bytes(v) => v.len(), + KeyPart::Str(v) => v.len(), + KeyPart::Struct(v) => { v.iter() - .map(KeyValue::estimated_detached_byte_size) + .map(KeyPart::estimated_detached_byte_size) .sum::() - + v.len() * std::mem::size_of::() + + v.len() * std::mem::size_of::() } - KeyValue::Bool(_) - | KeyValue::Int64(_) - | KeyValue::Range(_) - | KeyValue::Uuid(_) - | KeyValue::Date(_) => 0, + KeyPart::Bool(_) + | KeyPart::Int64(_) + | KeyPart::Range(_) + | KeyPart::Uuid(_) + | KeyPart::Date(_) => 0, } } } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct FullKeyValue(pub Box<[KeyValue]>); +pub struct KeyValue(pub Box<[KeyPart]>); -impl>> From for FullKeyValue { +impl>> From for KeyValue { fn from(value: T) -> Self { - FullKeyValue(value.into()) + KeyValue(value.into()) } } -impl IntoIterator for FullKeyValue { - type Item = KeyValue; - type IntoIter = std::vec::IntoIter; +impl IntoIterator for KeyValue { + type Item = KeyPart; + type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl<'a> IntoIterator for &'a FullKeyValue { - type Item = &'a KeyValue; - type IntoIter = std::slice::Iter<'a, KeyValue>; +impl<'a> IntoIterator for &'a KeyValue { + type Item = &'a KeyPart; + type IntoIter = std::slice::Iter<'a, KeyPart>; fn into_iter(self) -> Self::IntoIter { self.0.iter() } } -impl Deref for FullKeyValue { - type Target = [KeyValue]; +impl Deref for KeyValue { + type Target = [KeyPart]; fn deref(&self) -> &Self::Target { &self.0 } } -impl std::fmt::Display for FullKeyValue { +impl std::fmt::Display for KeyValue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -448,9 +404,9 @@ impl std::fmt::Display for FullKeyValue { } } -impl Serialize for FullKeyValue { +impl Serialize for KeyValue { fn serialize(&self, serializer: S) -> Result { - if self.0.len() == 1 && !matches!(self.0[0], 
KeyValue::Struct(_)) { + if self.0.len() == 1 && !matches!(self.0[0], KeyPart::Struct(_)) { self.0[0].serialize(serializer) } else { self.0.serialize(serializer) @@ -458,12 +414,12 @@ impl Serialize for FullKeyValue { } } -impl FullKeyValue { - pub fn from_single_part>(value: V) -> Self { +impl KeyValue { + pub fn from_single_part>(value: V) -> Self { Self(Box::new([value.into()])) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.0.iter() } @@ -471,7 +427,8 @@ impl FullKeyValue { let field_values = if schema.len() == 1 && matches!(schema[0].value_type.typ, ValueType::Basic(_)) { - Box::from([KeyValue::from_json_for_export(value, schema)?]) + let val = Value::::from_json(value, &schema[0].value_type.typ)?; + Box::from([val.into_key()?]) } else { match value { serde_json::Value::Array(arr) => std::iter::zip(arr.into_iter(), schema) @@ -497,9 +454,9 @@ impl FullKeyValue { schema: &[FieldSchema], ) -> Result { let mut values_iter = value.into_iter(); - let keys: Box<[KeyValue]> = schema + let keys: Box<[KeyPart]> = schema .iter() - .map(|f| KeyValue::parts_from_str(&mut values_iter, &f.value_type.typ)) + .map(|f| KeyPart::parts_from_str(&mut values_iter, &f.value_type.typ)) .collect::>>()?; if values_iter.next().is_some() { api_bail!("Key parts more than expected"); @@ -507,11 +464,11 @@ impl FullKeyValue { Ok(Self(keys)) } - pub fn to_values(&self) -> Vec { + pub fn to_values(&self) -> Box<[Value]> { self.0.iter().map(|v| v.into()).collect() } - pub fn single_part(&self) -> Result<&KeyValue> { + pub fn single_part(&self) -> Result<&KeyPart> { if self.0.len() != 1 { api_bail!("expected single value, but got {}", self.0.len()); } @@ -641,15 +598,15 @@ impl> From> for BasicValue { } impl BasicValue { - pub fn into_key(self) -> Result { + pub fn into_key(self) -> Result { let result = match self { - BasicValue::Bytes(v) => KeyValue::Bytes(v), - BasicValue::Str(v) => KeyValue::Str(v), - BasicValue::Bool(v) => KeyValue::Bool(v), - 
BasicValue::Int64(v) => KeyValue::Int64(v), - BasicValue::Range(v) => KeyValue::Range(v), - BasicValue::Uuid(v) => KeyValue::Uuid(v), - BasicValue::Date(v) => KeyValue::Date(v), + BasicValue::Bytes(v) => KeyPart::Bytes(v), + BasicValue::Str(v) => KeyPart::Str(v), + BasicValue::Bool(v) => KeyPart::Bool(v), + BasicValue::Int64(v) => KeyPart::Int64(v), + BasicValue::Range(v) => KeyPart::Range(v), + BasicValue::Uuid(v) => KeyPart::Uuid(v), + BasicValue::Date(v) => KeyPart::Date(v), BasicValue::Float32(_) | BasicValue::Float64(_) | BasicValue::Time(_) @@ -663,15 +620,15 @@ impl BasicValue { Ok(result) } - pub fn as_key(&self) -> Result { + pub fn as_key(&self) -> Result { let result = match self { - BasicValue::Bytes(v) => KeyValue::Bytes(v.clone()), - BasicValue::Str(v) => KeyValue::Str(v.clone()), - BasicValue::Bool(v) => KeyValue::Bool(*v), - BasicValue::Int64(v) => KeyValue::Int64(*v), - BasicValue::Range(v) => KeyValue::Range(*v), - BasicValue::Uuid(v) => KeyValue::Uuid(*v), - BasicValue::Date(v) => KeyValue::Date(*v), + BasicValue::Bytes(v) => KeyPart::Bytes(v.clone()), + BasicValue::Str(v) => KeyPart::Str(v.clone()), + BasicValue::Bool(v) => KeyPart::Bool(*v), + BasicValue::Int64(v) => KeyPart::Int64(*v), + BasicValue::Range(v) => KeyPart::Range(*v), + BasicValue::Uuid(v) => KeyPart::Uuid(*v), + BasicValue::Date(v) => KeyPart::Date(*v), BasicValue::Float32(_) | BasicValue::Float64(_) | BasicValue::Time(_) @@ -765,7 +722,7 @@ pub enum Value { Basic(BasicValue), Struct(FieldValues), UTable(Vec), - KTable(BTreeMap), + KTable(BTreeMap), LTable(Vec), } @@ -775,34 +732,34 @@ impl> From for Value { } } -impl From for Value { - fn from(value: KeyValue) -> Self { +impl From for Value { + fn from(value: KeyPart) -> Self { match value { - KeyValue::Bytes(v) => Value::Basic(BasicValue::Bytes(v)), - KeyValue::Str(v) => Value::Basic(BasicValue::Str(v)), - KeyValue::Bool(v) => Value::Basic(BasicValue::Bool(v)), - KeyValue::Int64(v) => Value::Basic(BasicValue::Int64(v)), - 
KeyValue::Range(v) => Value::Basic(BasicValue::Range(v)), - KeyValue::Uuid(v) => Value::Basic(BasicValue::Uuid(v)), - KeyValue::Date(v) => Value::Basic(BasicValue::Date(v)), - KeyValue::Struct(v) => Value::Struct(FieldValues { + KeyPart::Bytes(v) => Value::Basic(BasicValue::Bytes(v)), + KeyPart::Str(v) => Value::Basic(BasicValue::Str(v)), + KeyPart::Bool(v) => Value::Basic(BasicValue::Bool(v)), + KeyPart::Int64(v) => Value::Basic(BasicValue::Int64(v)), + KeyPart::Range(v) => Value::Basic(BasicValue::Range(v)), + KeyPart::Uuid(v) => Value::Basic(BasicValue::Uuid(v)), + KeyPart::Date(v) => Value::Basic(BasicValue::Date(v)), + KeyPart::Struct(v) => Value::Struct(FieldValues { fields: v.into_iter().map(Value::from).collect(), }), } } } -impl From<&KeyValue> for Value { - fn from(value: &KeyValue) -> Self { +impl From<&KeyPart> for Value { + fn from(value: &KeyPart) -> Self { match value { - KeyValue::Bytes(v) => Value::Basic(BasicValue::Bytes(v.clone())), - KeyValue::Str(v) => Value::Basic(BasicValue::Str(v.clone())), - KeyValue::Bool(v) => Value::Basic(BasicValue::Bool(*v)), - KeyValue::Int64(v) => Value::Basic(BasicValue::Int64(*v)), - KeyValue::Range(v) => Value::Basic(BasicValue::Range(*v)), - KeyValue::Uuid(v) => Value::Basic(BasicValue::Uuid(*v)), - KeyValue::Date(v) => Value::Basic(BasicValue::Date(*v)), - KeyValue::Struct(v) => Value::Struct(FieldValues { + KeyPart::Bytes(v) => Value::Basic(BasicValue::Bytes(v.clone())), + KeyPart::Str(v) => Value::Basic(BasicValue::Str(v.clone())), + KeyPart::Bool(v) => Value::Basic(BasicValue::Bool(*v)), + KeyPart::Int64(v) => Value::Basic(BasicValue::Int64(*v)), + KeyPart::Range(v) => Value::Basic(BasicValue::Range(*v)), + KeyPart::Uuid(v) => Value::Basic(BasicValue::Uuid(*v)), + KeyPart::Date(v) => Value::Basic(BasicValue::Date(*v)), + KeyPart::Struct(v) => Value::Struct(FieldValues { fields: v.iter().map(Value::from).collect(), }), } @@ -871,10 +828,10 @@ impl Value { matches!(self, Value::Null) } - pub fn into_key(self) 
-> Result { + pub fn into_key(self) -> Result { let result = match self { Value::Basic(v) => v.into_key()?, - Value::Struct(v) => KeyValue::Struct( + Value::Struct(v) => KeyPart::Struct( v.fields .into_iter() .map(|v| v.into_key()) @@ -887,10 +844,10 @@ impl Value { Ok(result) } - pub fn as_key(&self) -> Result { + pub fn as_key(&self) -> Result { let result = match self { Value::Basic(v) => v.as_key()?, - Value::Struct(v) => KeyValue::Struct( + Value::Struct(v) => KeyPart::Struct( v.fields .iter() .map(|v| v.as_key()) @@ -1259,7 +1216,7 @@ impl BasicValue { } } -struct TableEntry<'a>(&'a [KeyValue], &'a ScopeValue); +struct TableEntry<'a>(&'a [KeyPart], &'a ScopeValue); impl serde::Serialize for Value { fn serialize(&self, serializer: S) -> Result { @@ -1330,7 +1287,7 @@ where } let mut field_vals_iter = v.into_iter(); - let keys: Box<[KeyValue]> = (0..num_key_parts) + let keys: Box<[KeyPart]> = (0..num_key_parts) .map(|_| { Self::from_json( field_vals_iter.next().unwrap(), @@ -1343,10 +1300,10 @@ where let values = FieldValues::from_json_values( std::iter::zip(fields_iter, field_vals_iter), )?; - Ok((FullKeyValue(keys), values.into())) + Ok((KeyValue(keys), values.into())) } serde_json::Value::Object(mut v) => { - let keys: Box<[KeyValue]> = (0..num_key_parts).map(|_| { + let keys: Box<[KeyPart]> = (0..num_key_parts).map(|_| { let f = fields_iter.next().unwrap(); Self::from_json( std::mem::take(v.get_mut(&f.name).ok_or_else( @@ -1360,7 +1317,7 @@ where &f.value_type.typ)?.into_key() }).collect::>()?; let values = FieldValues::from_json_object(v, fields_iter)?; - Ok((FullKeyValue(keys), values.into())) + Ok((KeyValue(keys), values.into())) } _ => api_bail!("Table value must be a JSON array or object"), } @@ -1617,7 +1574,7 @@ mod tests { fn test_estimated_byte_size_ktable() { let mut map = BTreeMap::new(); map.insert( - FullKeyValue(Box::from([KeyValue::Str(Arc::from("key1"))])), + KeyValue(Box::from([KeyPart::Str(Arc::from("key1"))])), ScopeValue(FieldValues { 
fields: vec![Value::::Basic(BasicValue::Str(Arc::from( "value1", @@ -1625,7 +1582,7 @@ mod tests { }), ); map.insert( - FullKeyValue(Box::from([KeyValue::Str(Arc::from("key2"))])), + KeyValue(Box::from([KeyPart::Str(Arc::from("key2"))])), ScopeValue(FieldValues { fields: vec![Value::::Basic(BasicValue::Str(Arc::from( "value2", diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 82283a5f0..a79b146bf 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -643,7 +643,7 @@ fn add_collector( struct ExportDataFieldsInfo { local_collector_ref: AnalyzedLocalCollectorReference, primary_key_def: AnalyzedPrimaryKeyDef, - primary_key_type: ValueType, + primary_key_schema: Vec, value_fields_idx: Vec, value_stable: bool, } @@ -835,7 +835,7 @@ impl AnalyzerContext { .lock() .unwrap() .consume_collector(&export_op.spec.collector_name)?; - let (key_fields_schema, value_fields_schema, data_collection_info) = + let (value_fields_schema, data_collection_info) = match &export_op.spec.index_options.primary_key_fields { Some(fields) => { let pk_fields_idx = fields @@ -849,18 +849,10 @@ impl AnalyzerContext { }) .collect::>>()?; - let key_fields_schema = pk_fields_idx + let primary_key_schema = pk_fields_idx .iter() .map(|idx| collector_schema.fields[*idx].clone()) .collect::>(); - let primary_key_type = if pk_fields_idx.len() == 1 { - key_fields_schema[0].value_type.typ.clone() - } else { - ValueType::Struct(StructSchema { - fields: Arc::from(key_fields_schema.clone()), - description: None, - }) - }; let mut value_fields_schema: Vec = vec![]; let mut value_fields_idx = vec![]; for (idx, field) in collector_schema.fields.iter().enumerate() { @@ -875,12 +867,11 @@ impl AnalyzerContext { .map(|uuid_idx| pk_fields_idx.contains(uuid_idx)) .unwrap_or(false); ( - key_fields_schema, value_fields_schema, ExportDataFieldsInfo { local_collector_ref, primary_key_def: AnalyzedPrimaryKeyDef::Fields(pk_fields_idx), - primary_key_type, + primary_key_schema, 
value_fields_idx, value_stable, }, @@ -894,7 +885,7 @@ impl AnalyzerContext { collection_specs.push(interface::ExportDataCollectionSpec { name: export_op.name.clone(), spec: serde_json::Value::Object(export_op.spec.target.spec.clone()), - key_fields_schema, + key_fields_schema: data_collection_info.primary_key_schema.clone(), value_fields_schema, index_options: export_op.spec.index_options.clone(), }); @@ -936,7 +927,7 @@ impl AnalyzerContext { export_target_factory, export_context, primary_key_def: data_fields_info.primary_key_def, - primary_key_type: data_fields_info.primary_key_type, + primary_key_schema: data_fields_info.primary_key_schema, value_fields: data_fields_info.value_fields_idx, value_stable: data_fields_info.value_stable, }) diff --git a/src/builder/plan.rs b/src/builder/plan.rs index 7ac6fa8bd..5216373e6 100644 --- a/src/builder/plan.rs +++ b/src/builder/plan.rs @@ -105,7 +105,7 @@ pub struct AnalyzedExportOp { pub export_target_factory: Arc, pub export_context: Arc, pub primary_key_def: AnalyzedPrimaryKeyDef, - pub primary_key_type: schema::ValueType, + pub primary_key_schema: Vec, /// idx for value fields - excluding the primary key field. pub value_fields: Vec, /// If true, value is never changed on the same primary key. 
diff --git a/src/execution/dumper.rs b/src/execution/dumper.rs index 5e9853cc3..8903d4f8b 100644 --- a/src/execution/dumper.rs +++ b/src/execution/dumper.rs @@ -69,7 +69,7 @@ impl<'a> Dumper<'a> { &'a self, import_op_idx: usize, import_op: &'a AnalyzedImportOp, - key: &value::FullKeyValue, + key: &value::KeyValue, key_aux_info: &serde_json::Value, collected_values_buffer: &'b mut Vec>, ) -> Result>>> @@ -135,7 +135,7 @@ impl<'a> Dumper<'a> { &self, import_op_idx: usize, import_op: &AnalyzedImportOp, - key: value::FullKeyValue, + key: value::KeyValue, key_aux_info: serde_json::Value, file_path: PathBuf, ) -> Result<()> { @@ -188,7 +188,7 @@ impl<'a> Dumper<'a> { ) -> Result<()> { let mut keys_by_filename_prefix: IndexMap< String, - Vec<(value::FullKeyValue, serde_json::Value)>, + Vec<(value::KeyValue, serde_json::Value)>, > = IndexMap::new(); let mut rows_stream = import_op diff --git a/src/execution/evaluator.rs b/src/execution/evaluator.rs index d9c26f1c5..740f1c76d 100644 --- a/src/execution/evaluator.rs +++ b/src/execution/evaluator.rs @@ -120,18 +120,18 @@ enum ScopeKey<'a> { /// For root struct and UTable. None, /// For KTable row. - MapKey(&'a value::FullKeyValue), + MapKey(&'a value::KeyValue), /// For LTable row. 
ListIndex(usize), } impl<'a> ScopeKey<'a> { - pub fn key(&self) -> Option> { + pub fn key(&self) -> Option> { match self { ScopeKey::None => None, ScopeKey::MapKey(k) => Some(Cow::Borrowed(&k)), ScopeKey::ListIndex(i) => { - Some(Cow::Owned(value::FullKeyValue::from_single_part(*i as i64))) + Some(Cow::Owned(value::KeyValue::from_single_part(*i as i64))) } } } @@ -199,12 +199,12 @@ impl<'a> ScopeEntry<'a> { } fn get_local_key_field<'b>( - key_val: &'b value::KeyValue, + key_val: &'b value::KeyPart, indices: &'_ [u32], - ) -> &'b value::KeyValue { + ) -> &'b value::KeyPart { if indices.is_empty() { key_val - } else if let value::KeyValue::Struct(fields) = key_val { + } else if let value::KeyPart::Struct(fields) = key_val { Self::get_local_key_field(&fields[indices[0] as usize], &indices[1..]) } else { panic!("Only struct can be accessed by sub field"); @@ -494,7 +494,7 @@ pub struct SourceRowEvaluationContext<'a> { pub plan: &'a ExecutionPlan, pub import_op: &'a AnalyzedImportOp, pub schema: &'a schema::FlowSchema, - pub key: &'a value::FullKeyValue, + pub key: &'a value::KeyValue, pub import_op_idx: usize, } diff --git a/src/execution/row_indexer.rs b/src/execution/row_indexer.rs index eef1d4b21..4cd633397 100644 --- a/src/execution/row_indexer.rs +++ b/src/execution/row_indexer.rs @@ -27,7 +27,11 @@ pub fn extract_primary_key_for_export( ) -> Result { match primary_key_def { AnalyzedPrimaryKeyDef::Fields(fields) => { - KeyValue::from_values_for_export(fields.iter().map(|field| &record.fields[*field])) + let key_parts: Box<[value::KeyPart]> = fields + .iter() + .map(|field| record.fields[*field].as_key()) + .collect::>>()?; + Ok(KeyValue(key_parts)) } } } @@ -662,7 +666,7 @@ impl<'a> RowIndexer<'a> { let mut new_staging_target_keys = db_tracking::TrackedTargetKeyForSource::default(); let mut target_mutations = HashMap::with_capacity(export_ops.len()); for (target_id, target_tracking_info) in tracking_info_for_targets.into_iter() { - let legacy_keys: HashSet = 
target_tracking_info + let previous_keys: HashSet = target_tracking_info .existing_keys_info .into_keys() .chain(target_tracking_info.existing_staging_keys_info.into_keys()) @@ -670,7 +674,7 @@ impl<'a> RowIndexer<'a> { let mut new_staging_keys_info = target_tracking_info.new_staging_keys_info; // add deletions - new_staging_keys_info.extend(legacy_keys.iter().map(|key| TrackedTargetKeyInfo { + new_staging_keys_info.extend(previous_keys.iter().map(|key| TrackedTargetKeyInfo { key: key.key.clone(), additional_key: key.additional_key.clone(), process_ordinal, @@ -680,16 +684,11 @@ impl<'a> RowIndexer<'a> { if let Some(export_op) = target_tracking_info.export_op { let mut mutation = target_tracking_info.mutation; - mutation.deletes.reserve(legacy_keys.len()); - for legacy_key in legacy_keys.into_iter() { - let key = value::Value::::from_json( - legacy_key.key, - &export_op.primary_key_type, - )? - .as_key()?; + mutation.deletes.reserve(previous_keys.len()); + for previous_key in previous_keys.into_iter() { mutation.deletes.push(interface::ExportTargetDeleteEntry { - key, - additional_key: legacy_key.additional_key, + key: KeyValue::from_json(previous_key.key, &export_op.primary_key_schema)?, + additional_key: previous_key.additional_key, }); } target_mutations.insert(target_id, mutation); diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs index 3cf50e5d9..d574380cc 100644 --- a/src/execution/source_indexer.rs +++ b/src/execution/source_indexer.rs @@ -39,7 +39,7 @@ impl Default for SourceRowIndexingState { } struct SourceIndexingState { - rows: HashMap, + rows: HashMap, scan_generation: usize, } @@ -55,7 +55,7 @@ pub struct SourceIndexingContext { pub const NO_ACK: Option Ready>> = None; struct LocalSourceRowStateOperator<'a> { - key: &'a value::FullKeyValue, + key: &'a value::KeyValue, indexing_state: &'a Mutex, update_stats: &'a Arc, @@ -75,7 +75,7 @@ enum RowStateAdvanceOutcome { impl<'a> LocalSourceRowStateOperator<'a> { fn new( - key: 
&'a value::FullKeyValue, + key: &'a value::KeyValue, indexing_state: &'a Mutex, update_stats: &'a Arc, ) -> Self { @@ -167,7 +167,7 @@ impl<'a> LocalSourceRowStateOperator<'a> { } pub struct ProcessSourceRowInput { - pub key: value::FullKeyValue, + pub key: value::KeyValue, /// `key_aux_info` is not available for deletions. It must be provided if `data.value` is `None`. pub key_aux_info: Option, pub data: interface::PartialSourceRowData, @@ -193,7 +193,7 @@ impl SourceIndexingContext { ); while let Some(key_metadata) = key_metadata_stream.next().await { let key_metadata = key_metadata?; - let source_pk = value::FullKeyValue::from_json( + let source_pk = value::KeyValue::from_json( key_metadata.source_key, &import_op.primary_key_schema, )?; diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index f3d3f993f..966babcbb 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -932,7 +932,7 @@ impl SimpleFunctionExecutor for Executor { let output_start = chunk_output.start_pos.output.unwrap(); let output_end = chunk_output.end_pos.output.unwrap(); ( - FullKeyValue::from_single_part(RangeValue::new( + KeyValue::from_single_part(RangeValue::new( output_start.char_offset, output_end.char_offset, )), @@ -1153,7 +1153,7 @@ mod tests { ]; for (range, expected_text) in expected_chunks { - let key = FullKeyValue::from_single_part(range); + let key = KeyValue::from_single_part(range); match table.get(&key) { Some(scope_value_ref) => { let chunk_text = diff --git a/src/ops/interface.rs b/src/ops/interface.rs index ab4478770..1902025ed 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -70,7 +70,7 @@ pub struct PartialSourceRowData { } pub struct PartialSourceRow { - pub key: FullKeyValue, + pub key: KeyValue, /// Auxiliary information for the source row, to be used when reading the content. /// e.g. it can be used to uniquely identify version of the row. 
/// Use serde_json::Value::Null to represent no auxiliary information. @@ -100,7 +100,7 @@ impl SourceValue { } pub struct SourceChange { - pub key: FullKeyValue, + pub key: KeyValue, /// Auxiliary information for the source row, to be used when reading the content. /// e.g. it can be used to uniquely identify version of the row. pub key_aux_info: serde_json::Value, @@ -144,7 +144,7 @@ pub trait SourceExecutor: Send + Sync { // Get the value for the given key. async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result; diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index d02f811a3..7278d8ab0 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -464,13 +464,13 @@ impl interface::TargetFactory for PyExportTargetFactory { ); for upsert in mutation.mutation.upserts.into_iter() { flattened_mutations.push(( - py::value_to_py_object(py, &upsert.key.into())?, + py::key_to_py_object(py, &upsert.key)?, py::field_values_to_py_object(py, upsert.value.fields.iter())?, )); } for delete in mutation.mutation.deletes.into_iter() { flattened_mutations.push(( - py::value_to_py_object(py, &delete.key.into())?, + py::key_to_py_object(py, &delete.key)?, py.None().into_bound(py), )); } diff --git a/src/ops/shared/postgres.rs b/src/ops/shared/postgres.rs index f35613532..28e9daf0b 100644 --- a/src/ops/shared/postgres.rs +++ b/src/ops/shared/postgres.rs @@ -22,51 +22,36 @@ pub async fn get_db_pool( Ok(db_pool) } -pub fn key_value_fields_iter<'a>( - key_fields_schema: impl ExactSizeIterator, - key_value: &'a KeyValue, -) -> Result<&'a [KeyValue]> { - let slice = if key_fields_schema.into_iter().count() == 1 { - std::slice::from_ref(key_value) - } else { - match key_value { - KeyValue::Struct(fields) => fields, - _ => bail!("expect struct key value"), - } - }; - Ok(slice) -} - pub fn bind_key_field<'arg>( builder: &mut sqlx::QueryBuilder<'arg, sqlx::Postgres>, - key_value: 
&'arg KeyValue, + key_value: &'arg KeyPart, ) -> Result<()> { match key_value { - KeyValue::Bytes(v) => { + KeyPart::Bytes(v) => { builder.push_bind(&**v); } - KeyValue::Str(v) => { + KeyPart::Str(v) => { builder.push_bind(&**v); } - KeyValue::Bool(v) => { + KeyPart::Bool(v) => { builder.push_bind(v); } - KeyValue::Int64(v) => { + KeyPart::Int64(v) => { builder.push_bind(v); } - KeyValue::Range(v) => { + KeyPart::Range(v) => { builder.push_bind(PgRange { start: Bound::Included(v.start as i64), end: Bound::Excluded(v.end as i64), }); } - KeyValue::Uuid(v) => { + KeyPart::Uuid(v) => { builder.push_bind(v); } - KeyValue::Date(v) => { + KeyPart::Date(v) => { builder.push_bind(v); } - KeyValue::Struct(fields) => { + KeyPart::Struct(fields) => { builder.push_bind(sqlx::types::Json(fields)); } } diff --git a/src/ops/sources/amazon_s3.rs b/src/ops/sources/amazon_s3.rs index 466d14557..8c1763557 100644 --- a/src/ops/sources/amazon_s3.rs +++ b/src/ops/sources/amazon_s3.rs @@ -86,7 +86,7 @@ impl SourceExecutor for Executor { if key.ends_with('/') { continue; } if self.pattern_matcher.is_file_included(key) { batch.push(PartialSourceRow { - key: FullKeyValue::from_single_part(key.to_string()), + key: KeyValue::from_single_part(key.to_string()), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData { ordinal: obj.last_modified().map(datetime_to_ordinal), @@ -113,7 +113,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result { @@ -264,7 +264,7 @@ impl Executor { { let decoded_key = decode_form_encoded_url(&s3.object.key)?; changes.push(SourceChange { - key: FullKeyValue::from_single_part(decoded_key), + key: KeyValue::from_single_part(decoded_key), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData::default(), }); diff --git a/src/ops/sources/azure_blob.rs b/src/ops/sources/azure_blob.rs index 1ab4e9200..053fc2dc7 100644 
--- a/src/ops/sources/azure_blob.rs +++ b/src/ops/sources/azure_blob.rs @@ -76,7 +76,7 @@ impl SourceExecutor for Executor { if self.pattern_matcher.is_file_included(key) { let ordinal = Some(datetime_to_ordinal(&blob.properties.last_modified)); batch.push(PartialSourceRow { - key: FullKeyValue::from_single_part(key.clone()), + key: KeyValue::from_single_part(key.clone()), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData { ordinal, @@ -102,7 +102,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result { diff --git a/src/ops/sources/google_drive.rs b/src/ops/sources/google_drive.rs index b96f5515c..7560f3a71 100644 --- a/src/ops/sources/google_drive.rs +++ b/src/ops/sources/google_drive.rs @@ -134,7 +134,7 @@ impl Executor { None } else if is_supported_file_type(&mime_type) { Some(PartialSourceRow { - key: FullKeyValue::from_single_part(id), + key: KeyValue::from_single_part(id), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData { ordinal: file.modified_time.map(|t| t.try_into()).transpose()?, @@ -214,7 +214,7 @@ impl Executor { let file_id = file.id.ok_or_else(|| anyhow!("File has no id"))?; if self.is_file_covered(&file_id).await? 
{ changes.push(SourceChange { - key: FullKeyValue::from_single_part(file_id), + key: KeyValue::from_single_part(file_id), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData::default(), }); @@ -328,7 +328,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result { diff --git a/src/ops/sources/local_file.rs b/src/ops/sources/local_file.rs index 514386bdb..29e233ded 100644 --- a/src/ops/sources/local_file.rs +++ b/src/ops/sources/local_file.rs @@ -55,7 +55,7 @@ impl SourceExecutor for Executor { None }; yield vec![PartialSourceRow { - key: FullKeyValue::from_single_part(relative_path.to_string()), + key: KeyValue::from_single_part(relative_path.to_string()), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData { ordinal, @@ -73,7 +73,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result { diff --git a/src/ops/sources/postgres.rs b/src/ops/sources/postgres.rs index 6fe32275d..d5ba4b88e 100644 --- a/src/ops/sources/postgres.rs +++ b/src/ops/sources/postgres.rs @@ -408,8 +408,8 @@ impl SourceExecutor for Executor { .iter() .enumerate() .map(|(i, info)| (info.decoder)(&row, i)?.into_key()) - .collect::>>()?; - let key = FullKeyValue(parts); + .collect::>>()?; + let key = KeyValue(parts); // Decode value and ordinal let data = self.decode_row_data(&row, options, ordinal_col_index, pk_count)?; @@ -426,7 +426,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &FullKeyValue, + key: &KeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorReadOptions, ) -> Result { diff --git a/src/ops/targets/kuzu.rs b/src/ops/targets/kuzu.rs index a650accd7..4e8bd106f 100644 --- a/src/ops/targets/kuzu.rs +++ b/src/ops/targets/kuzu.rs @@ -528,7 +528,7 @@ fn 
append_upsert_node( &data_coll.schema.key_fields, upsert_entry .key - .fields_iter_for_export(data_coll.schema.key_fields.len())? + .iter() .map(|f| Cow::Owned(value::Value::from(f))), )?; write!(cypher.query_mut(), ")")?; @@ -607,7 +607,7 @@ fn append_upsert_rel( &data_coll.schema.key_fields, upsert_entry .key - .fields_iter_for_export(data_coll.schema.key_fields.len())? + .iter() .map(|f| Cow::Owned(value::Value::from(f))), )?; write!(cypher.query_mut(), "]->({TGT_NODE_VAR_NAME})")?; @@ -635,8 +635,7 @@ fn append_delete_node( append_key_pattern( cypher, &data_coll.schema.key_fields, - key.fields_iter_for_export(data_coll.schema.key_fields.len())? - .map(|f| Cow::Owned(value::Value::from(f))), + key.iter().map(|f| Cow::Owned(value::Value::from(f))), )?; writeln!(cypher.query_mut(), ")")?; writeln!( @@ -673,7 +672,7 @@ fn append_delete_rel( cypher, src_key_schema, src_node_key - .fields_iter_for_export(src_key_schema.len())? + .iter() .map(|k| Cow::Owned(value::Value::from(k))), )?; @@ -682,8 +681,7 @@ fn append_delete_rel( append_key_pattern( cypher, key_schema, - key.fields_iter_for_export(key_schema.len())? - .map(|k| Cow::Owned(value::Value::from(k))), + key.iter().map(|k| Cow::Owned(value::Value::from(k))), )?; write!( @@ -696,7 +694,7 @@ fn append_delete_rel( cypher, tgt_key_schema, tgt_node_key - .fields_iter_for_export(tgt_key_schema.len())? + .iter() .map(|k| Cow::Owned(value::Value::from(k))), )?; write!(cypher.query_mut(), ") DELETE {REL_VAR_NAME}")?; @@ -715,8 +713,7 @@ fn append_maybe_gc_node( append_key_pattern( cypher, &schema.key_fields, - key.fields_iter_for_export(schema.key_fields.len())? 
- .map(|f| Cow::Owned(value::Value::from(f))), + key.iter().map(|f| Cow::Owned(value::Value::from(f))), )?; writeln!(cypher.query_mut(), ")")?; write!( @@ -975,11 +972,11 @@ impl TargetFactoryBase for Factory { delete.additional_key ); } - let src_key = KeyValue::from_json_for_export( + let src_key = KeyValue::from_json( additional_keys[0].take(), &rel.source.schema.key_fields, )?; - let tgt_key = KeyValue::from_json_for_export( + let tgt_key = KeyValue::from_json( additional_keys[1].take(), &rel.target.schema.key_fields, )?; diff --git a/src/ops/targets/neo4j.rs b/src/ops/targets/neo4j.rs index 73db8b446..a7f2532b2 100644 --- a/src/ops/targets/neo4j.rs +++ b/src/ops/targets/neo4j.rs @@ -145,7 +145,7 @@ fn json_value_to_bolt_value(value: &serde_json::Value) -> Result { Ok(bolt_value) } -fn key_to_bolt(key: &KeyValue, schema: &schema::ValueType) -> Result { +fn key_to_bolt(key: &KeyPart, schema: &schema::ValueType) -> Result { value_to_bolt(&key.into(), schema) } @@ -456,10 +456,7 @@ impl ExportContext { val: &KeyValue, ) -> Result { let mut query = query; - for (i, val) in val - .fields_iter_for_export(self.analyzed_data_coll.schema.key_fields.len())? 
- .enumerate() - { + for (i, val) in val.iter().enumerate() { query = query.param( &self.key_field_params[i], key_to_bolt( diff --git a/src/ops/targets/postgres.rs b/src/ops/targets/postgres.rs index 741e062b7..7acd69f94 100644 --- a/src/ops/targets/postgres.rs +++ b/src/ops/targets/postgres.rs @@ -4,7 +4,7 @@ use super::shared::table_columns::{ TableColumnsSchema, TableMainSetupAction, TableUpsertionAction, check_table_compatibility, }; use crate::base::spec::{self, *}; -use crate::ops::shared::postgres::{bind_key_field, get_db_pool, key_value_fields_iter}; +use crate::ops::shared::postgres::{bind_key_field, get_db_pool}; use crate::settings::DatabaseConnectionSpec; use async_trait::async_trait; use indexmap::{IndexMap, IndexSet}; @@ -192,11 +192,7 @@ impl ExportContext { query_builder.push(","); } query_builder.push(" ("); - for (j, key_value) in - key_value_fields_iter(self.key_fields_schema.iter(), &upsert.key)? - .iter() - .enumerate() - { + for (j, key_value) in upsert.key.iter().enumerate() { if j > 0 { query_builder.push(", "); } @@ -234,11 +230,8 @@ impl ExportContext { for deletion in deletions.iter() { let mut query_builder = sqlx::QueryBuilder::new(""); query_builder.push(&self.delete_sql_prefix); - for (i, (schema, value)) in self - .key_fields_schema - .iter() - .zip(key_value_fields_iter(self.key_fields_schema.iter(), &deletion.key)?.iter()) - .enumerate() + for (i, (schema, value)) in + std::iter::zip(&self.key_fields_schema, &deletion.key).enumerate() { if i > 0 { query_builder.push(" AND "); diff --git a/src/ops/targets/qdrant.rs b/src/ops/targets/qdrant.rs index dd0bfc96e..f58e03206 100644 --- a/src/ops/targets/qdrant.rs +++ b/src/ops/targets/qdrant.rs @@ -290,10 +290,11 @@ impl ExportContext { } } fn key_to_point_id(key_value: &KeyValue) -> Result { - let point_id = match key_value { - KeyValue::Str(v) => PointId::from(v.to_string()), - KeyValue::Int64(v) => PointId::from(*v as u64), - KeyValue::Uuid(v) => PointId::from(v.to_string()), + let 
key_part = key_value.single_part()?; + let point_id = match key_part { + KeyPart::Str(v) => PointId::from(v.to_string()), + KeyPart::Int64(v) => PointId::from(*v as u64), + KeyPart::Uuid(v) => PointId::from(v.to_string()), e => bail!("Invalid Qdrant point ID: {e}"), }; @@ -389,7 +390,7 @@ impl TargetFactoryBase for Factory { .map(|d| { if d.key_fields_schema.len() != 1 { api_bail!( - "Expected one primary key field for the point ID. Got {}.", + "Expected exactly one primary key field for the point ID. Got {}.", d.key_fields_schema.len() ) } diff --git a/src/ops/targets/shared/property_graph.rs b/src/ops/targets/shared/property_graph.rs index d0079b9a3..25a48e8b0 100644 --- a/src/ops/targets/shared/property_graph.rs +++ b/src/ops/targets/shared/property_graph.rs @@ -123,7 +123,9 @@ pub struct GraphElementInputFieldsIdx { impl GraphElementInputFieldsIdx { pub fn extract_key(&self, fields: &[value::Value]) -> Result { - value::KeyValue::from_values_for_export(self.key.iter().map(|idx| &fields[*idx])) + let key_parts: Result> = + self.key.iter().map(|idx| fields[*idx].as_key()).collect(); + Ok(value::KeyValue(key_parts?)) } } diff --git a/src/py/convert.rs b/src/py/convert.rs index 5ed2b4a0d..782b2ddd0 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -1,4 +1,4 @@ -use crate::base::value::FullKeyValue; +use crate::base::value::KeyValue; use crate::prelude::*; use bytes::Bytes; @@ -93,6 +93,33 @@ pub fn field_values_to_py_object<'py, 'a>( Ok(PyTuple::new(py, fields)?.into_any()) } +pub fn key_to_py_object<'py, 'a>( + py: Python<'py>, + key: impl IntoIterator, +) -> PyResult> { + fn key_part_to_py_object<'py>( + py: Python<'py>, + part: &value::KeyPart, + ) -> PyResult> { + let result = match part { + value::KeyPart::Bytes(v) => v.into_bound_py_any(py)?, + value::KeyPart::Str(v) => v.into_bound_py_any(py)?, + value::KeyPart::Bool(v) => v.into_bound_py_any(py)?, + value::KeyPart::Int64(v) => v.into_bound_py_any(py)?, + value::KeyPart::Range(v) => pythonize(py, 
v).into_py_result()?, + value::KeyPart::Uuid(v) => v.into_bound_py_any(py)?, + value::KeyPart::Date(v) => v.into_bound_py_any(py)?, + value::KeyPart::Struct(v) => key_to_py_object(py, v)?, + }; + Ok(result) + } + let fields = key + .into_iter() + .map(|part| key_part_to_py_object(py, part)) + .collect::>>()?; + Ok(PyTuple::new(py, fields)?.into_any()) +} + pub fn value_to_py_object<'py>(py: Python<'py>, v: &value::Value) -> PyResult> { let result = match v { value::Value::Null => py.None().into_bound(py), @@ -347,13 +374,13 @@ pub fn value_from_py_object<'py>( iter.len() ); } - let keys: Box<[value::KeyValue]> = (0..num_key_parts) + let keys: Box<[value::KeyPart]> = (0..num_key_parts) .map(|_| iter.next().unwrap().into_key()) .collect::>()?; let values = value::FieldValues { fields: iter.collect::>(), }; - Ok((FullKeyValue(keys), values.into())) + Ok((KeyValue(keys), values.into())) }) .collect::>>() .into_py_result()?, @@ -558,8 +585,8 @@ mod tests { .into_key() .unwrap(); - ktable_data.insert(FullKeyValue(Box::from([key1])), row1_scope_val.clone()); - ktable_data.insert(FullKeyValue(Box::from([key2])), row2_scope_val.clone()); + ktable_data.insert(KeyValue(Box::from([key1])), row1_scope_val.clone()); + ktable_data.insert(KeyValue(Box::from([key2])), row2_scope_val.clone()); let ktable_val = value::Value::KTable(ktable_data); let ktable_typ = schema::ValueType::Table(ktable_schema); diff --git a/src/service/flows.rs b/src/service/flows.rs index 4e58ee2e4..04cd2cfab 100644 --- a/src/service/flows.rs +++ b/src/service/flows.rs @@ -61,7 +61,7 @@ pub struct GetKeysParam { #[derive(Serialize)] pub struct GetKeysResponse { key_schema: Vec, - keys: Vec<(value::FullKeyValue, serde_json::Value)>, + keys: Vec<(value::KeyValue, serde_json::Value)>, } pub async fn get_keys( @@ -134,7 +134,7 @@ struct SourceRowKeyContextHolder<'a> { plan: Arc, import_op_idx: usize, schema: &'a FlowSchema, - key: value::FullKeyValue, + key: value::KeyValue, key_aux_info: serde_json::Value, } @@ 
-161,7 +161,7 @@ impl<'a> SourceRowKeyContextHolder<'a> { _ => api_bail!("field is not a table: {}", source_row_key.field), }; let key_schema = table_schema.key_schema(); - let key = value::FullKeyValue::decode_from_strs(source_row_key.key, key_schema)?; + let key = value::KeyValue::decode_from_strs(source_row_key.key, key_schema)?; let key_aux_info = source_row_key .key_aux .map(|s| serde_json::from_str(&s))