diff --git a/docs/docs/core/basics.md b/docs/docs/core/basics.md index 2155cac9d..47cd2c178 100644 --- a/docs/docs/core/basics.md +++ b/docs/docs/core/basics.md @@ -23,7 +23,7 @@ Each piece of data has a **data type**, falling into one of the following catego * *Basic type*. * *Struct type*: a collection of **fields**, each with a name and a type. -* *Table type*: a collection of **rows**, each of which is a struct with specified schema. A table type can be a *KTable* (which has a key field) or a *LTable* (ordered but without key field). +* *Table type*: a collection of **rows**, each of which is a struct with specified schema. A table type can be a *KTable* (with key columns that uniquely identify each row) or a *LTable* (rows are ordered but without keys). An indexing flow always has a top-level struct, containing all data within and managed by the flow. diff --git a/docs/docs/core/data_types.mdx b/docs/docs/core/data_types.mdx index 321eeedd1..57656683e 100644 --- a/docs/docs/core/data_types.mdx +++ b/docs/docs/core/data_types.mdx @@ -148,21 +148,27 @@ We have two specific types of *Table* types: *KTable* and *LTable*. #### KTable -*KTable* is a *Table* type whose first column serves as the key. +*KTable* is a *Table* type whose one or more columns together serve as the key. The row order of a *KTable* is not preserved. -Type of the first column (key column) must be a [key type](#key-types). +Each key column must be a [key type](#key-types). When multiple key columns are present, they form a composite key. -In Python, a *KTable* type is represented by `dict[K, V]`. -The `K` should be the type binding to a key type, -and the `V` should be the type binding to a *Struct* type representing the value fields of each row. -When the specific type annotation is not provided, -the key type is bound to a tuple with its key parts when it's a *Struct* type, the value type is bound to `dict[str, Any]`. +In Python, a *KTable* type is represented by `dict[K, V]`. +`K` represents the key and `V` represents the value for each row: + +- `K` can be a Struct type (either a frozen dataclass or a `NamedTuple`) that contains all key parts as fields. This is the general way to model multi-part keys. +- When there is only a single key part and it is a basic type (e.g. `str`, `int`), you may use that basic type directly as the dictionary key instead of wrapping it in a Struct. +- `V` should be the type bound to a *Struct* representing the non-key value fields of each row. + +When a specific type annotation is not provided: +- For composite keys (multiple key parts), the key binds to a Python tuple of the key parts, e.g. `tuple[str, str]`. +- For a single basic key part, the key binds to that basic Python type. +- The value binds to `dict[str, Any]`. For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*). It's bound to `dict[str, dict[str, Any]]` if you don't annotate the function argument with a specific type. -Note that if you want to use a *Struct* as the key, you need to ensure its value in Python is immutable. For `dataclass`, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example: +Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. 
For example: ```python @dataclass(frozen=True) @@ -175,8 +181,8 @@ class PersonKeyTuple(NamedTuple): id: str ``` -Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by `PersonKey` or `PersonKeyTuple`. -It's bound to `dict[(str, str), dict[str, Any]]` if you don't annotate the function argument with a specific type. +Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by both `id_kind` and `id`. +If you don't annotate the function argument with a specific type, it's bound to `dict[tuple[str, str], dict[str, Any]]`. #### LTable diff --git a/docs/docs/getting_started/quickstart.md b/docs/docs/getting_started/quickstart.md index b1a8e2af7..f9b2760c5 100644 --- a/docs/docs/getting_started/quickstart.md +++ b/docs/docs/getting_started/quickstart.md @@ -105,7 +105,7 @@ Notes: * `chunk`, representing each row of `chunks`. 3. A *data source* extracts data from an external source. - In this example, the `LocalFile` data source imports local files as a KTable (table with a key field, see [KTable](../core/data_types#ktable) for details), each row has `"filename"` and `"content"` fields. + In this example, the `LocalFile` data source imports local files as a KTable (table with key columns, see [KTable](../core/data_types#ktable) for details), each row has `"filename"` and `"content"` fields. 4. After defining the KTable, we extend a new field `"chunks"` to each row by *transforming* the `"content"` field using `SplitRecursively`. The output of the `SplitRecursively` is also a KTable representing each chunk of the document, with `"location"` and `"text"` fields. diff --git a/examples/postgres_source/main.py b/examples/postgres_source/main.py index 6465fff51..ecb087a65 100644 --- a/examples/postgres_source/main.py +++ b/examples/postgres_source/main.py @@ -97,8 +97,8 @@ def postgres_product_indexing_flow( with data_scope["products"].row() as product: product["full_description"] = flow_builder.transform( make_full_description, - product["_key"]["product_category"], - product["_key"]["product_name"], + product["product_category"], + product["product_name"], product["description"], ) product["total_value"] = flow_builder.transform( @@ -112,8 +112,8 @@ def postgres_product_indexing_flow( ) ) indexed_product.collect( - product_category=product["_key"]["product_category"], - product_name=product["_key"]["product_name"], + product_category=product["product_category"], + product_name=product["product_name"], description=product["description"], price=product["price"], amount=product["amount"], diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 9ec819161..0c0f9f045 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -14,7 +14,6 @@ import numpy as np from .typing import ( - KEY_FIELD_NAME, TABLE_TYPES, AnalyzedAnyType, AnalyzedBasicType, @@ -96,14 +95,24 @@ def encode_struct_list(value: Any) -> Any: f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. " f"If you want a free-formed dict, use `cocoindex.Json` instead." 
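
The composite-key bindings described in the docs above can be summarized in plain Python. This is a minimal, self-contained sketch; the `PersonKey`/`Person` field layouts are assumptions restated from the surrounding docs examples (`id_kind`/`id` key parts; `first_name`, `last_name`, `dob` value fields).

```python
from dataclasses import dataclass
import datetime

# Assumed layouts, mirroring the PersonKey/Person examples in the docs above.
@dataclass(frozen=True)
class PersonKey:
    id_kind: str
    id: str

@dataclass
class Person:
    first_name: str
    last_name: str
    dob: datetime.date

# With an explicit annotation, a composite-key KTable binds to dict[PersonKey, Person].
people: dict[PersonKey, Person] = {
    PersonKey("ssn", "123-45-6789"): Person("Alice", "Smith", datetime.date(1990, 1, 1)),
}

# Without an annotation, the same table binds to dict[tuple[str, str], dict[str, Any]]:
# keys become plain tuples of the key parts, values become free-form dicts.
untyped_people = {
    ("ssn", "123-45-6789"): {
        "first_name": "Alice",
        "last_name": "Smith",
        "dob": datetime.date(1990, 1, 1),
    },
}
```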
) + value_encoder = make_engine_value_encoder(value_type_info) - key_encoder = make_engine_value_encoder(analyze_type_info(variant.key_type)) - value_encoder = make_engine_value_encoder(analyze_type_info(variant.value_type)) + key_type_info = analyze_type_info(variant.key_type) + key_encoder = make_engine_value_encoder(key_type_info) + if isinstance(key_type_info.variant, AnalyzedBasicType): + + def encode_row(k: Any, v: Any) -> Any: + return [key_encoder(k)] + value_encoder(v) + + else: + + def encode_row(k: Any, v: Any) -> Any: + return key_encoder(k) + value_encoder(v) def encode_struct_dict(value: Any) -> Any: if not value: return [] - return [[key_encoder(k)] + value_encoder(v) for k, v in value.items()] + return [encode_row(k, v) for k, v in value.items()] return encode_struct_dict @@ -234,25 +243,47 @@ def decode(value: Any) -> Any | None: f"declared `{dst_type_info.core_type}`, a dict type expected" ) - key_field_schema = engine_fields_schema[0] - field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}") - key_decoder = make_engine_value_decoder( - field_path, - key_field_schema["type"], - analyze_type_info(key_type), - for_key=True, - ) - field_path.pop() + num_key_parts = src_type.get("num_key_parts", 1) + key_type_info = analyze_type_info(key_type) + key_decoder: Callable[..., Any] | None = None + if ( + isinstance( + key_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType) + ) + and num_key_parts == 1 + ): + single_key_decoder = make_engine_value_decoder( + field_path, + engine_fields_schema[0]["type"], + key_type_info, + for_key=True, + ) + + def key_decoder(value: list[Any]) -> Any: + return single_key_decoder(value[0]) + + else: + key_decoder = make_engine_struct_decoder( + field_path, + engine_fields_schema[0:num_key_parts], + key_type_info, + for_key=True, + ) value_decoder = make_engine_struct_decoder( field_path, - engine_fields_schema[1:], + engine_fields_schema[num_key_parts:], analyze_type_info(value_type), ) def decode(value: Any) -> Any | None: if value is None: return None - return {key_decoder(v[0]): value_decoder(v[1:]) for v in value} + return { + key_decoder(v[0:num_key_parts]): value_decoder( + v[num_key_parts:] + ) + for v in value + } return decode diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index d08c5e08f..e14b634bc 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -330,35 +330,50 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: def _encode_struct_schema( struct_type: type, key_type: type | None = None -) -> dict[str, Any]: +) -> tuple[dict[str, Any], int | None]: fields = [] - def add_field(name: str, t: Any) -> None: + def add_field(name: str, analyzed_type: AnalyzedTypeInfo) -> None: try: - type_info = encode_enriched_type_info(analyze_type_info(t)) + type_info = encode_enriched_type_info(analyzed_type) except ValueError as e: e.add_note( f"Failed to encode annotation for field - " - f"{struct_type.__name__}.{name}: {t}" + f"{struct_type.__name__}.{name}: {analyzed_type.core_type}" ) raise type_info["name"] = name fields.append(type_info) + def add_fields_from_struct(struct_type: type) -> None: + if dataclasses.is_dataclass(struct_type): + for field in dataclasses.fields(struct_type): + add_field(field.name, analyze_type_info(field.type)) + elif is_namedtuple_type(struct_type): + for name, field_type in struct_type.__annotations__.items(): + add_field(name, analyze_type_info(field_type)) + else: + raise ValueError(f"Unsupported struct type: {struct_type}") + + result: dict[str, Any] 
= {} + num_key_parts = None if key_type is not None: - add_field(KEY_FIELD_NAME, key_type) + key_type_info = analyze_type_info(key_type) + if isinstance(key_type_info.variant, AnalyzedBasicType): + add_field(KEY_FIELD_NAME, key_type_info) + num_key_parts = 1 + elif isinstance(key_type_info.variant, AnalyzedStructType): + add_fields_from_struct(key_type_info.variant.struct_type) + num_key_parts = len(fields) + else: + raise ValueError(f"Unsupported key type: {key_type}") - if dataclasses.is_dataclass(struct_type): - for field in dataclasses.fields(struct_type): - add_field(field.name, field.type) - elif is_namedtuple_type(struct_type): - for name, field_type in struct_type.__annotations__.items(): - add_field(name, field_type) + add_fields_from_struct(struct_type) - result: dict[str, Any] = {"fields": fields} + result["fields"] = fields if doc := inspect.getdoc(struct_type): result["description"] = doc - return result + return result, num_key_parts def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: @@ -374,7 +389,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: return {"kind": variant.kind} if isinstance(variant, AnalyzedStructType): - encoded_type = _encode_struct_schema(variant.struct_type) + encoded_type, _ = _encode_struct_schema(variant.struct_type) encoded_type["kind"] = "Struct" return encoded_type @@ -384,10 +399,8 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: if isinstance(elem_type_info.variant, AnalyzedStructType): if variant.vector_info is not None: raise ValueError("LTable type must not have a vector info") - return { - "kind": "LTable", - "row": _encode_struct_schema(elem_type_info.variant.struct_type), - } + row_type, _ = _encode_struct_schema(elem_type_info.variant.struct_type) + return {"kind": "LTable", "row": row_type} else: vector_info = variant.vector_info return { @@ -402,12 +415,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: raise ValueError( f"KTable value must have a Struct type, got {value_type_info.core_type}" ) + row_type, num_key_parts = _encode_struct_schema( + value_type_info.variant.struct_type, + variant.key_type, + ) return { "kind": "KTable", - "row": _encode_struct_schema( - value_type_info.variant.struct_type, - variant.key_type, - ), + "row": row_type, + "num_key_parts": num_key_parts, } if isinstance(variant, AnalyzedUnionType): diff --git a/src/base/schema.rs b/src/base/schema.rs index e9d3e6b56..769d53d65 100644 --- a/src/base/schema.rs +++ b/src/base/schema.rs @@ -136,13 +136,26 @@ impl std::fmt::Display for StructSchema { } #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub struct KTableInfo { + // Omit the field if num_key_parts is 1 for backward compatibility. + #[serde(default = "default_num_key_parts")] + pub num_key_parts: usize, +} + +fn default_num_key_parts() -> usize { + 1 +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "kind")] #[allow(clippy::enum_variant_names)] pub enum TableKind { /// An table with unordered rows, without key. UTable, - /// A table's first field is the key. + /// A table's first field is the key. The value is number of fields serving as the key #[serde(alias = "Table")] - KTable, + KTable(KTableInfo), + /// A table whose rows orders are preserved. 
#[serde(alias = "List")] LTable, @@ -152,7 +165,7 @@ impl std::fmt::Display for TableKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TableKind::UTable => write!(f, "Table"), - TableKind::KTable => write!(f, "KTable"), + TableKind::KTable(KTableInfo { num_key_parts }) => write!(f, "KTable({num_key_parts})"), TableKind::LTable => write!(f, "LTable"), } } @@ -160,28 +173,25 @@ impl std::fmt::Display for TableKind { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TableSchema { + #[serde(flatten)] pub kind: TableKind, + pub row: StructSchema, } +impl std::fmt::Display for TableSchema { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}({})", self.kind, self.row) + } +} + impl TableSchema { - pub fn has_key(&self) -> bool { - match self.kind { - TableKind::KTable => true, - TableKind::UTable | TableKind::LTable => false, - } + pub fn new(kind: TableKind, row: StructSchema) -> Self { + Self { kind, row } } - pub fn key_type(&self) -> Option<&EnrichedValueType> { - match self.kind { - TableKind::KTable => self - .row - .fields - .first() - .as_ref() - .map(|field| &field.value_type), - TableKind::UTable | TableKind::LTable => None, - } + pub fn has_key(&self) -> bool { + !self.key_schema().is_empty() } pub fn without_attrs(&self) -> Self { @@ -190,23 +200,11 @@ impl TableSchema { row: self.row.without_attrs(), } } -} - -impl std::fmt::Display for TableSchema { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}({})", self.kind, self.row) - } -} - -impl TableSchema { - pub fn new(kind: TableKind, row: StructSchema) -> Self { - Self { kind, row } - } - pub fn key_field(&self) -> Option<&FieldSchema> { + pub fn key_schema(&self) -> &[FieldSchema] { match self.kind { - TableKind::KTable => Some(self.row.fields.first().unwrap()), - TableKind::UTable | TableKind::LTable => None, + TableKind::KTable(KTableInfo { num_key_parts: n }) => &self.row.fields[..n], + TableKind::UTable | TableKind::LTable => &[], } } } @@ -224,11 +222,11 @@ pub enum ValueType { } impl ValueType { - pub fn key_type(&self) -> Option<&EnrichedValueType> { + pub fn key_schema(&self) -> &[FieldSchema] { match self { - ValueType::Basic(_) => None, - ValueType::Struct(_) => None, - ValueType::Table(c) => c.key_type(), + ValueType::Basic(_) => &[], + ValueType::Struct(_) => &[], + ValueType::Table(c) => c.key_schema(), } } diff --git a/src/base/value.rs b/src/base/value.rs index ba4987c0a..4930b2655 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -184,7 +184,11 @@ impl std::fmt::Display for KeyValue { } impl KeyValue { - pub fn from_json(value: serde_json::Value, fields_schema: &[FieldSchema]) -> Result { + /// For export purpose only for now. Will remove after switching export to using FullKeyValue. + pub fn from_json_for_export( + value: serde_json::Value, + fields_schema: &[FieldSchema], + ) -> Result { let value = if fields_schema.len() == 1 { Value::from_json(value, &fields_schema[0].value_type.typ)? } else { @@ -194,7 +198,10 @@ impl KeyValue { value.as_key() } - pub fn from_values<'a>(values: impl ExactSizeIterator) -> Result { + /// For export purpose only for now. Will remove after switching export to using FullKeyValue. + pub fn from_values_for_export<'a>( + values: impl ExactSizeIterator, + ) -> Result { let key = if values.len() == 1 { let mut values = values; values.next().ok_or_else(invariance_violation)?.as_key()? 
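
For orientation, the shapes implied by the Python encoder/decoder changes and the `KTableInfo` schema above can be sketched as follows. This is schematic only: per-field entries are abbreviated to bare names and the literal values are placeholders, not the exact enriched-type encoding.

```python
# Schematic encoded type for dict[PersonKey, Person] (field entries abbreviated).
ktable_type = {
    "kind": "KTable",
    # Omitted when 1, and defaulted to 1 on read, so existing single-key schemas stay compatible.
    "num_key_parts": 2,
    "row": {"fields": ["id_kind", "id", "first_name", "last_name", "dob"]},
}

# Each KTable row crosses the Python/Rust boundary as one flat list:
# the first num_key_parts entries are the key parts, the rest are value fields.
row = ["ssn", "123-45-6789", "Alice", "Smith", "1990-01-01"]
n = ktable_type["num_key_parts"]
key_parts, value_fields = row[:n], row[n:]
```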
@@ -204,7 +211,11 @@ impl KeyValue { Ok(key) } - pub fn fields_iter(&self, num_fields: usize) -> Result> { + /// For export purpose only for now. Will remove after switching export to using FullKeyValue. + pub fn fields_iter_for_export( + &self, + num_fields: usize, + ) -> Result> { let slice = if num_fields == 1 { std::slice::from_ref(self) } else { @@ -388,6 +399,126 @@ impl KeyValue { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct FullKeyValue(pub Box<[KeyValue]>); + +impl>> From for FullKeyValue { + fn from(value: T) -> Self { + FullKeyValue(value.into()) + } +} + +impl IntoIterator for FullKeyValue { + type Item = KeyValue; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a FullKeyValue { + type Item = &'a KeyValue; + type IntoIter = std::slice::Iter<'a, KeyValue>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl Deref for FullKeyValue { + type Target = [KeyValue]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::fmt::Display for FullKeyValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{{}}}", + self.0 + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + ) + } +} + +impl Serialize for FullKeyValue { + fn serialize(&self, serializer: S) -> Result { + if self.0.len() == 1 && !matches!(self.0[0], KeyValue::Struct(_)) { + self.0[0].serialize(serializer) + } else { + self.0.serialize(serializer) + } + } +} + +impl FullKeyValue { + pub fn from_single_part>(value: V) -> Self { + Self(Box::new([value.into()])) + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn from_json(value: serde_json::Value, schema: &[FieldSchema]) -> Result { + let field_values = if schema.len() == 1 + && matches!(schema[0].value_type.typ, ValueType::Basic(_)) + { + Box::from([KeyValue::from_json_for_export(value, schema)?]) + } else { + match value { + serde_json::Value::Array(arr) => std::iter::zip(arr.into_iter(), schema) + .map(|(v, s)| Value::::from_json(v, &s.value_type.typ)?.into_key()) + .collect::>>()?, + _ => anyhow::bail!("expected array value, but got {}", value), + } + }; + Ok(Self(field_values)) + } + + pub fn encode_to_strs(&self) -> Vec { + let capacity = self.0.iter().map(|k| k.num_parts()).sum(); + let mut output = Vec::with_capacity(capacity); + for part in self.0.iter() { + part.parts_to_strs(&mut output); + } + output + } + + pub fn decode_from_strs( + value: impl IntoIterator, + schema: &[FieldSchema], + ) -> Result { + let mut values_iter = value.into_iter(); + let keys: Box<[KeyValue]> = schema + .iter() + .map(|f| KeyValue::parts_from_str(&mut values_iter, &f.value_type.typ)) + .collect::>>()?; + if values_iter.next().is_some() { + api_bail!("Key parts more than expected"); + } + Ok(Self(keys)) + } + + pub fn to_values(&self) -> Vec { + self.0.iter().map(|v| v.into()).collect() + } + + pub fn single_part(&self) -> Result<&KeyValue> { + if self.0.len() != 1 { + api_bail!("expected single value, but got {}", self.0.len()); + } + Ok(&self.0[0]) + } +} + #[derive(Debug, Clone, PartialEq, Deserialize)] pub enum BasicValue { Bytes(Bytes), @@ -627,14 +758,14 @@ impl BasicValue { } } -#[derive(Debug, Clone, Default, PartialEq, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq)] pub enum Value { #[default] Null, Basic(BasicValue), Struct(FieldValues), UTable(Vec), - KTable(BTreeMap), + KTable(BTreeMap), LTable(Vec), } @@ -709,9 +840,7 @@ impl 
Value { .collect(), }), Value::UTable(v) => Value::UTable(v.into_iter().map(|v| v.into()).collect()), - Value::KTable(v) => { - Value::KTable(v.into_iter().map(|(k, v)| (k.clone(), v.into())).collect()) - } + Value::KTable(v) => Value::KTable(v.into_iter().map(|(k, v)| (k, v.into())).collect()), Value::LTable(v) => Value::LTable(v.into_iter().map(|v| v.into()).collect()), } } @@ -879,7 +1008,10 @@ impl Value { Value::KTable(v) => { v.iter() .map(|(k, v)| { - k.estimated_detached_byte_size() + v.estimated_detached_byte_size() + k.iter() + .map(|k| k.estimated_detached_byte_size()) + .sum::() + + v.estimated_detached_byte_size() }) .sum::() + v.len() * std::mem::size_of::<(String, ScopeValue)>() @@ -888,7 +1020,7 @@ impl Value { } } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, PartialEq)] pub struct FieldValues { pub fields: Vec>, } @@ -972,7 +1104,7 @@ where } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, PartialEq)] pub struct ScopeValue(pub FieldValues); impl EstimatedByteSize for ScopeValue { @@ -1127,7 +1259,7 @@ impl BasicValue { } } -struct TableEntry<'a>(&'a KeyValue, &'a ScopeValue); +struct TableEntry<'a>(&'a [KeyValue], &'a ScopeValue); impl serde::Serialize for Value { fn serialize(&self, serializer: S) -> Result { @@ -1151,8 +1283,10 @@ impl serde::Serialize for Value { impl serde::Serialize for TableEntry<'_> { fn serialize(&self, serializer: S) -> Result { let &TableEntry(key, value) = self; - let mut seq = serializer.serialize_seq(Some(value.0.fields.len() + 1))?; - seq.serialize_element(key)?; + let mut seq = serializer.serialize_seq(Some(key.len() + value.0.fields.len()))?; + for item in key.iter() { + seq.serialize_element(item)?; + } for item in value.0.fields.iter() { seq.serialize_element(item)?; } @@ -1171,68 +1305,78 @@ where (v, ValueType::Struct(s)) => { Value::::Struct(FieldValues::::from_json(v, &s.fields)?) } - (serde_json::Value::Array(v), ValueType::Table(s)) => match s.kind { - TableKind::UTable => { - let rows = v - .into_iter() - .map(|v| Ok(FieldValues::from_json(v, &s.row.fields)?.into())) - .collect::>>()?; - Value::LTable(rows) - } - TableKind::KTable => { - let rows = v - .into_iter() - .map(|v| { - let mut fields_iter = s.row.fields.iter(); - let key_field = fields_iter - .next() - .ok_or_else(|| api_error!("Empty struct field values"))?; - - match v { - serde_json::Value::Array(v) => { - let mut field_vals_iter = v.into_iter(); - let key = Self::from_json( - field_vals_iter.next().ok_or_else(|| { - api_error!("Empty struct field values") - })?, - &key_field.value_type.typ, - )? - .into_key()?; - let values = FieldValues::from_json_values( - fields_iter.zip(field_vals_iter), - )?; - Ok((key, values.into())) + (serde_json::Value::Array(v), ValueType::Table(s)) => { + match s.kind { + TableKind::UTable => { + let rows = v + .into_iter() + .map(|v| Ok(FieldValues::from_json(v, &s.row.fields)?.into())) + .collect::>>()?; + Value::LTable(rows) + } + TableKind::KTable(info) => { + let num_key_parts = info.num_key_parts; + let rows = + v.into_iter() + .map(|v| { + if s.row.fields.len() < num_key_parts { + anyhow::bail!("Invalid KTable schema: expect at least {} fields, got {}", num_key_parts, s.row.fields.len()); } - serde_json::Value::Object(mut v) => { - let key = Self::from_json( - std::mem::take(v.get_mut(&key_field.name).ok_or_else( - || { - api_error!( - "key field `{}` doesn't exist in value", - key_field.name - ) - }, - )?), - &key_field.value_type.typ, - )? 
- .into_key()?; - let values = FieldValues::from_json_object(v, fields_iter)?; - Ok((key, values.into())) + let mut fields_iter = s.row.fields.iter(); + match v { + serde_json::Value::Array(v) => { + if v.len() != fields_iter.len() { + anyhow::bail!("Invalid KTable value: expect {} values, received {}", fields_iter.len(), v.len()); + } + + let mut field_vals_iter = v.into_iter(); + let keys: Box<[KeyValue]> = (0..num_key_parts) + .map(|_| { + Self::from_json( + field_vals_iter.next().unwrap(), + &fields_iter.next().unwrap().value_type.typ, + )? + .into_key() + }) + .collect::>()?; + + let values = FieldValues::from_json_values( + std::iter::zip(fields_iter, field_vals_iter), + )?; + Ok((FullKeyValue(keys), values.into())) + } + serde_json::Value::Object(mut v) => { + let keys: Box<[KeyValue]> = (0..num_key_parts).map(|_| { + let f = fields_iter.next().unwrap(); + Self::from_json( + std::mem::take(v.get_mut(&f.name).ok_or_else( + || { + api_error!( + "key field `{}` doesn't exist in value", + f.name + ) + }, + )?), + &f.value_type.typ)?.into_key() + }).collect::>()?; + let values = FieldValues::from_json_object(v, fields_iter)?; + Ok((FullKeyValue(keys), values.into())) + } + _ => api_bail!("Table value must be a JSON array or object"), } - _ => api_bail!("Table value must be a JSON array or object"), - } - }) - .collect::>>()?; - Value::KTable(rows) - } - TableKind::LTable => { - let rows = v - .into_iter() - .map(|v| Ok(FieldValues::from_json(v, &s.row.fields)?.into())) - .collect::>>()?; - Value::LTable(rows) + }) + .collect::>>()?; + Value::KTable(rows) + } + TableKind::LTable => { + let rows = v + .into_iter() + .map(|v| Ok(FieldValues::from_json(v, &s.row.fields)?.into())) + .collect::>>()?; + Value::LTable(rows) + } } - }, + } (v, t) => { anyhow::bail!("Value and type not matched.\nTarget type {t:?}\nJSON value: {v}\n") } @@ -1280,10 +1424,10 @@ impl Serialize for TypedValue<'_> { (ValueType::Table(c), Value::KTable(rows)) => { let mut seq = serializer.serialize_seq(Some(rows.len()))?; for (k, v) in rows { + let keys: Box<[Value]> = k.iter().map(|k| Value::from(k.clone())).collect(); seq.serialize_element(&TypedFieldsValue { schema: &c.row.fields, - values_iter: std::iter::once(&Value::from(k.clone())) - .chain(v.fields.iter()), + values_iter: keys.iter().chain(v.fields.iter()), })?; } seq.end() @@ -1473,7 +1617,7 @@ mod tests { fn test_estimated_byte_size_ktable() { let mut map = BTreeMap::new(); map.insert( - KeyValue::Str(Arc::from("key1")), + FullKeyValue(Box::from([KeyValue::Str(Arc::from("key1"))])), ScopeValue(FieldValues { fields: vec![Value::::Basic(BasicValue::Str(Arc::from( "value1", @@ -1481,7 +1625,7 @@ mod tests { }), ); map.insert( - KeyValue::Str(Arc::from("key2")), + FullKeyValue(Box::from([KeyValue::Str(Arc::from("key2"))])), ScopeValue(FieldValues { fields: vec![Value::::Basic(BasicValue::Str(Arc::from( "value2", diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 517bdab86..82283a5f0 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -663,12 +663,7 @@ impl AnalyzerContext { .await?; let op_name = import_op.name.clone(); - let primary_key_type = output_type - .typ - .key_type() - .ok_or_else(|| api_error!("Source must produce a type with key: {op_name}"))? 
- .typ - .clone(); + let primary_key_schema = Box::from(output_type.typ.key_schema()); let output = op_scope.add_op_output(import_op.name, output_type)?; let concur_control_options = import_op @@ -683,7 +678,7 @@ impl AnalyzerContext { Ok(AnalyzedImportOp { executor, output, - primary_key_type, + primary_key_schema, name: op_name, refresh_options: import_op.spec.refresh_options, concurrency_controller: concur_control::CombinedConcurrencyController::new( diff --git a/src/builder/exec_ctx.rs b/src/builder/exec_ctx.rs index 1e6dda603..264c2911e 100644 --- a/src/builder/exec_ctx.rs +++ b/src/builder/exec_ctx.rs @@ -36,18 +36,26 @@ fn build_import_op_exec_ctx( existing_source_states: Option<&Vec<&setup::SourceSetupState>>, metadata: &mut setup::FlowSetupMetadata, ) -> Result { - let key_schema_no_attrs = import_op_output_type + let keys_schema_no_attrs = import_op_output_type .typ - .key_type() - .ok_or_else(|| api_error!("Source must produce a type with key"))? - .typ - .without_attrs(); + .key_schema() + .iter() + .map(|field| field.value_type.typ.without_attrs()) + .collect::>(); let existing_source_ids = existing_source_states .iter() .flat_map(|v| v.iter()) .filter_map(|state| { - if state.key_schema == key_schema_no_attrs { + let existing_keys_schema: &[schema::ValueType] = + if let Some(keys_schema) = &state.keys_schema { + keys_schema + } else if let Some(key_schema) = &state.key_schema { + std::slice::from_ref(key_schema) + } else { + &[] + }; + if existing_keys_schema == keys_schema_no_attrs.as_ref() { Some(state.source_id) } else { None @@ -67,7 +75,30 @@ fn build_import_op_exec_ctx( import_op.name.clone(), setup::SourceSetupState { source_id, - key_schema: key_schema_no_attrs, + + // Keep this field for backward compatibility, + // so users can still swap back to older version if needed. 
+ key_schema: Some(if keys_schema_no_attrs.len() == 1 { + keys_schema_no_attrs[0].clone() + } else { + schema::ValueType::Struct(schema::StructSchema { + fields: Arc::new( + import_op_output_type + .typ + .key_schema() + .iter() + .map(|field| { + schema::FieldSchema::new( + field.name.clone(), + field.value_type.clone(), + ) + }) + .collect(), + ), + description: None, + }) + }), + keys_schema: Some(keys_schema_no_attrs), source_kind: import_op.spec.source.kind.clone(), }, ); diff --git a/src/builder/plan.rs b/src/builder/plan.rs index 665c907c5..7ac6fa8bd 100644 --- a/src/builder/plan.rs +++ b/src/builder/plan.rs @@ -1,3 +1,4 @@ +use crate::base::schema::FieldSchema; use crate::prelude::*; use crate::ops::interface::*; @@ -54,7 +55,7 @@ pub struct AnalyzedImportOp { pub name: String, pub executor: Box, pub output: AnalyzedOpOutput, - pub primary_key_type: schema::ValueType, + pub primary_key_schema: Box<[FieldSchema]>, pub refresh_options: spec::SourceRefreshOptions, pub concurrency_controller: concur_control::CombinedConcurrencyController, diff --git a/src/execution/dumper.rs b/src/execution/dumper.rs index 410458c95..93f86b93b 100644 --- a/src/execution/dumper.rs +++ b/src/execution/dumper.rs @@ -47,7 +47,7 @@ impl Serialize for TargetExportData<'_> { #[derive(Serialize)] struct SourceOutputData<'a> { - key: value::TypedValue<'a>, + key: value::TypedFieldsValue<'a, std::slice::Iter<'a, value::Value>>, #[serde(skip_serializing_if = "Option::is_none")] exports: Option>>, @@ -69,7 +69,7 @@ impl<'a> Dumper<'a> { &'a self, import_op_idx: usize, import_op: &'a AnalyzedImportOp, - key: &value::KeyValue, + key: &value::FullKeyValue, key_aux_info: &serde_json::Value, collected_values_buffer: &'b mut Vec>, ) -> Result>>> @@ -116,7 +116,7 @@ impl<'a> Dumper<'a> { data: collected_values_buffer[collector_idx] .iter() .map(|v| -> Result<_> { - let key = row_indexer::extract_primary_key( + let key = row_indexer::extract_primary_key_for_export( &export_op.primary_key_def, v, )?; @@ -135,7 +135,7 @@ impl<'a> Dumper<'a> { &self, import_op_idx: usize, import_op: &AnalyzedImportOp, - key: value::KeyValue, + key: value::FullKeyValue, key_aux_info: serde_json::Value, file_path: PathBuf, ) -> Result<()> { @@ -157,11 +157,11 @@ impl<'a> Dumper<'a> { Ok(exports) => (exports, None), Err(e) => (None, Some(format!("{e:?}"))), }; - let key_value = value::Value::from(key); + let key_values: Vec = key.into_iter().map(|v| v.into()).collect::>(); let file_data = SourceOutputData { - key: value::TypedValue { - t: &import_op.primary_key_type, - v: &key_value, + key: value::TypedFieldsValue { + schema: &import_op.primary_key_schema, + values_iter: key_values.iter(), }, exports, error, @@ -188,7 +188,7 @@ impl<'a> Dumper<'a> { ) -> Result<()> { let mut keys_by_filename_prefix: IndexMap< String, - Vec<(value::KeyValue, serde_json::Value)>, + Vec<(value::FullKeyValue, serde_json::Value)>, > = IndexMap::new(); let mut rows_stream = import_op @@ -202,7 +202,7 @@ impl<'a> Dumper<'a> { for row in rows?.into_iter() { let mut s = row .key - .to_strs() + .encode_to_strs() .into_iter() .map(|s| urlencoding::encode(&s).into_owned()) .join(":"); diff --git a/src/execution/evaluator.rs b/src/execution/evaluator.rs index d1b4c3752..d9c26f1c5 100644 --- a/src/execution/evaluator.rs +++ b/src/execution/evaluator.rs @@ -60,7 +60,7 @@ impl ScopeValueBuilder { } fn augmented_from(source: &value::ScopeValue, schema: &schema::TableSchema) -> Result { - let val_index_base = if schema.has_key() { 1 } else { 0 }; + let val_index_base = 
schema.key_schema().len(); let len = schema.row.fields.len() - val_index_base; let mut builder = Self::new(len); @@ -120,24 +120,26 @@ enum ScopeKey<'a> { /// For root struct and UTable. None, /// For KTable row. - MapKey(&'a value::KeyValue), + MapKey(&'a value::FullKeyValue), /// For LTable row. ListIndex(usize), } impl<'a> ScopeKey<'a> { - pub fn key(&self) -> Option> { + pub fn key(&self) -> Option> { match self { ScopeKey::None => None, - ScopeKey::MapKey(k) => Some(Cow::Borrowed(k)), - ScopeKey::ListIndex(i) => Some(Cow::Owned(value::KeyValue::Int64(*i as i64))), + ScopeKey::MapKey(k) => Some(Cow::Borrowed(&k)), + ScopeKey::ListIndex(i) => { + Some(Cow::Owned(value::FullKeyValue::from_single_part(*i as i64))) + } } } - pub fn value_field_index_base(&self) -> u32 { + pub fn value_field_index_base(&self) -> usize { match *self { ScopeKey::None => 0, - ScopeKey::MapKey(_) => 1, + ScopeKey::MapKey(v) => v.len(), ScopeKey::ListIndex(_) => 0, } } @@ -147,7 +149,7 @@ impl std::fmt::Display for ScopeKey<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ScopeKey::None => write!(f, "()"), - ScopeKey::MapKey(k) => write!(f, "{{{k}}}"), + ScopeKey::MapKey(k) => write!(f, "{k}"), ScopeKey::ListIndex(i) => write!(f, "[{i}]"), } } @@ -226,7 +228,7 @@ impl<'a> ScopeEntry<'a> { &self, field_ref: &AnalyzedLocalFieldReference, ) -> &value::Value { - let first_index = field_ref.fields_idx[0]; + let first_index = field_ref.fields_idx[0] as usize; let index_base = self.key.value_field_index_base(); let val = self.value.fields[(first_index - index_base) as usize] .get() @@ -235,11 +237,12 @@ impl<'a> ScopeEntry<'a> { } fn get_field(&self, field_ref: &AnalyzedLocalFieldReference) -> value::Value { - let first_index = field_ref.fields_idx[0]; + let first_index = field_ref.fields_idx[0] as usize; let index_base = self.key.value_field_index_base(); if first_index < index_base { - let key_val = self.key.key().unwrap().into_owned(); - let key_part = Self::get_local_key_field(&key_val, &field_ref.fields_idx[1..]); + let key_val = self.key.key().unwrap(); + let key_part = + Self::get_local_key_field(&key_val[first_index], &field_ref.fields_idx[1..]); key_part.clone().into() } else { let val = self.value.fields[(first_index - index_base) as usize] @@ -491,7 +494,7 @@ pub struct SourceRowEvaluationContext<'a> { pub plan: &'a ExecutionPlan, pub import_op: &'a AnalyzedImportOp, pub schema: &'a schema::FlowSchema, - pub key: &'a value::KeyValue, + pub key: &'a value::FullKeyValue, pub import_op_idx: usize, } diff --git a/src/execution/row_indexer.rs b/src/execution/row_indexer.rs index 513689463..1fa451462 100644 --- a/src/execution/row_indexer.rs +++ b/src/execution/row_indexer.rs @@ -21,13 +21,13 @@ use crate::ops::interface::{ use crate::utils::db::WriteAction; use crate::utils::fingerprint::{Fingerprint, Fingerprinter}; -pub fn extract_primary_key( +pub fn extract_primary_key_for_export( primary_key_def: &AnalyzedPrimaryKeyDef, record: &FieldValues, ) -> Result { match primary_key_def { AnalyzedPrimaryKeyDef::Fields(fields) => { - KeyValue::from_values(fields.iter().map(|field| &record.fields[*field])) + KeyValue::from_values_for_export(fields.iter().map(|field| &record.fields[*field])) } } } @@ -582,7 +582,8 @@ impl<'a> RowIndexer<'a> { let collected_values = &data.evaluate_output.collected_values[export_op.input.collector_idx as usize]; for value in collected_values.iter() { - let primary_key = extract_primary_key(&export_op.primary_key_def, value)?; + let primary_key = + 
extract_primary_key_for_export(&export_op.primary_key_def, value)?; let primary_key_json = serde_json::to_value(&primary_key)?; let mut field_values = FieldValues { diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs index 2f08b79bf..e2ad2b1fd 100644 --- a/src/execution/source_indexer.rs +++ b/src/execution/source_indexer.rs @@ -39,7 +39,7 @@ impl Default for SourceRowIndexingState { } struct SourceIndexingState { - rows: HashMap, + rows: HashMap, scan_generation: usize, } @@ -55,7 +55,7 @@ pub struct SourceIndexingContext { pub const NO_ACK: Option Ready>> = None; struct LocalSourceRowStateOperator<'a> { - key: &'a value::KeyValue, + key: &'a value::FullKeyValue, indexing_state: &'a Mutex, update_stats: &'a Arc, @@ -75,7 +75,7 @@ enum RowStateAdvanceOutcome { impl<'a> LocalSourceRowStateOperator<'a> { fn new( - key: &'a value::KeyValue, + key: &'a value::FullKeyValue, indexing_state: &'a Mutex, update_stats: &'a Arc, ) -> Self { @@ -192,13 +192,12 @@ impl SourceIndexingContext { ); while let Some(key_metadata) = key_metadata_stream.next().await { let key_metadata = key_metadata?; - let source_key = value::Value::::from_json( + let source_pk = value::FullKeyValue::from_json( key_metadata.source_key, - &import_op.primary_key_type, - )? - .into_key()?; + &import_op.primary_key_schema, + )?; rows.insert( - source_key, + source_pk, SourceRowIndexingState { source_version: SourceVersion::from_stored( key_metadata.processed_source_ordinal, @@ -230,7 +229,7 @@ impl SourceIndexingContext { AckFn: FnOnce() -> AckFut, >( self: Arc, - key: value::KeyValue, + key: value::FullKeyValue, update_stats: Arc, _concur_permit: concur_control::CombinedConcurrencyControllerPermit, ack_fn: Option, diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 147a181a3..f3d3f993f 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -932,7 +932,10 @@ impl SimpleFunctionExecutor for Executor { let output_start = chunk_output.start_pos.output.unwrap(); let output_end = chunk_output.end_pos.output.unwrap(); ( - RangeValue::new(output_start.char_offset, output_end.char_offset).into(), + FullKeyValue::from_single_part(RangeValue::new( + output_start.char_offset, + output_end.char_offset, + )), fields_value!( Arc::::from(chunk_output.text), output_start.into_output(), @@ -1022,11 +1025,14 @@ impl SimpleFunctionFactoryBase for Factory { attrs: Default::default(), }, )); - let output_schema = make_output_type(TableSchema::new(TableKind::KTable, struct_schema)) - .with_attr( - field_attrs::CHUNK_BASE_TEXT, - serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, - ); + let output_schema = make_output_type(TableSchema::new( + TableKind::KTable(KTableInfo { num_key_parts: 1 }), + struct_schema, + )) + .with_attr( + field_attrs::CHUNK_BASE_TEXT, + serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, + ); Ok((args, output_schema)) } @@ -1073,12 +1079,12 @@ mod tests { } // Creates a default RecursiveChunker for testing, assuming no language-specific parsing. 
- fn create_test_chunker( - text: &str, + fn create_test_chunker<'a>( + text: &'a str, chunk_size: usize, min_chunk_size: usize, chunk_overlap: usize, - ) -> RecursiveChunker { + ) -> RecursiveChunker<'a> { RecursiveChunker { full_text: text, chunk_size, @@ -1147,7 +1153,7 @@ mod tests { ]; for (range, expected_text) in expected_chunks { - let key: KeyValue = range.into(); + let key = FullKeyValue::from_single_part(range); match table.get(&key) { Some(scope_value_ref) => { let chunk_text = diff --git a/src/ops/interface.rs b/src/ops/interface.rs index 3dfe88848..f592247b5 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -49,7 +49,7 @@ impl TryFrom> for Ordinal { } pub struct PartialSourceRowMetadata { - pub key: KeyValue, + pub key: FullKeyValue, /// Auxiliary information for the source row, to be used when reading the content. /// e.g. it can be used to uniquely identify version of the row. /// Use serde_json::Value::Null to represent no auxiliary information. @@ -93,7 +93,7 @@ impl SourceValue { } pub struct SourceChange { - pub key: KeyValue, + pub key: FullKeyValue, /// Auxiliary information for the source row, to be used when reading the content. /// e.g. it can be used to uniquely identify version of the row. pub key_aux_info: serde_json::Value, @@ -138,7 +138,7 @@ pub trait SourceExecutor: Send + Sync { // Get the value for the given key. async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result; diff --git a/src/ops/sources/amazon_s3.rs b/src/ops/sources/amazon_s3.rs index 6ff11b98e..cc132c2db 100644 --- a/src/ops/sources/amazon_s3.rs +++ b/src/ops/sources/amazon_s3.rs @@ -86,7 +86,7 @@ impl SourceExecutor for Executor { if key.ends_with('/') { continue; } if self.pattern_matcher.is_file_included(key) { batch.push(PartialSourceRowMetadata { - key: KeyValue::Str(key.to_string().into()), + key: FullKeyValue::from_single_part(key.to_string()), key_aux_info: serde_json::Value::Null, ordinal: obj.last_modified().map(datetime_to_ordinal), content_version_fp: None, @@ -110,11 +110,11 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result { - let key_str = key.str_value()?; + let key_str = key.single_part()?.str_value()?; if !self.pattern_matcher.is_file_included(key_str) { return Ok(PartialSourceRowData { value: Some(SourceValue::NonExistence), @@ -257,7 +257,7 @@ impl Executor { { let decoded_key = decode_form_encoded_url(&s3.object.key)?; changes.push(SourceChange { - key: KeyValue::Str(decoded_key), + key: FullKeyValue::from_single_part(decoded_key), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData::default(), }); @@ -317,7 +317,7 @@ impl SourceFactoryBase for Factory { ), )); Ok(make_output_type(TableSchema::new( - TableKind::KTable, + TableKind::KTable(KTableInfo { num_key_parts: 1 }), struct_schema, ))) } diff --git a/src/ops/sources/azure_blob.rs b/src/ops/sources/azure_blob.rs index 583ca3b9e..c6ee5ebe7 100644 --- a/src/ops/sources/azure_blob.rs +++ b/src/ops/sources/azure_blob.rs @@ -76,7 +76,7 @@ impl SourceExecutor for Executor { if self.pattern_matcher.is_file_included(key) { let ordinal = Some(datetime_to_ordinal(&blob.properties.last_modified)); batch.push(PartialSourceRowMetadata { - key: KeyValue::Str(key.clone().into()), + key: FullKeyValue::from_single_part(key.clone()), key_aux_info: serde_json::Value::Null, ordinal, 
content_version_fp: None, @@ -99,11 +99,11 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result { - let key_str = key.str_value()?; + let key_str = key.single_part()?.str_value()?; if !self.pattern_matcher.is_file_included(key_str) { return Ok(PartialSourceRowData { value: Some(SourceValue::NonExistence), @@ -199,7 +199,7 @@ impl SourceFactoryBase for Factory { ), )); Ok(make_output_type(TableSchema::new( - TableKind::KTable, + TableKind::KTable(KTableInfo { num_key_parts: 1 }), struct_schema, ))) } diff --git a/src/ops/sources/google_drive.rs b/src/ops/sources/google_drive.rs index 6f0cc8029..28c2cbb09 100644 --- a/src/ops/sources/google_drive.rs +++ b/src/ops/sources/google_drive.rs @@ -134,7 +134,7 @@ impl Executor { None } else if is_supported_file_type(&mime_type) { Some(PartialSourceRowMetadata { - key: KeyValue::Str(id), + key: FullKeyValue::from_single_part(id), key_aux_info: serde_json::Value::Null, ordinal: file.modified_time.map(|t| t.try_into()).transpose()?, content_version_fp: None, @@ -211,7 +211,7 @@ impl Executor { let file_id = file.id.ok_or_else(|| anyhow!("File has no id"))?; if self.is_file_covered(&file_id).await? { changes.push(SourceChange { - key: KeyValue::Str(Arc::from(file_id)), + key: FullKeyValue::from_single_part(file_id), key_aux_info: serde_json::Value::Null, data: PartialSourceRowData::default(), }); @@ -325,11 +325,11 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result { - let file_id = key.str_value()?; + let file_id = key.single_part()?.str_value()?; let fields = format!( "id,name,mimeType,trashed{}", optional_modified_time(options.include_ordinal) @@ -480,7 +480,7 @@ impl SourceFactoryBase for Factory { ), )); Ok(make_output_type(TableSchema::new( - TableKind::KTable, + TableKind::KTable(KTableInfo { num_key_parts: 1 }), struct_schema, ))) } diff --git a/src/ops/sources/local_file.rs b/src/ops/sources/local_file.rs index 72b1ad38b..3e12064dd 100644 --- a/src/ops/sources/local_file.rs +++ b/src/ops/sources/local_file.rs @@ -55,7 +55,7 @@ impl SourceExecutor for Executor { None }; yield vec![PartialSourceRowMetadata { - key: KeyValue::Str(relative_path.into()), + key: FullKeyValue::from_single_part(relative_path.to_string()), key_aux_info: serde_json::Value::Null, ordinal, content_version_fp: None, @@ -70,21 +70,19 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result { - if !self - .pattern_matcher - .is_file_included(key.str_value()?.as_ref()) - { + let path = key.single_part()?.str_value()?.as_ref(); + if !self.pattern_matcher.is_file_included(path) { return Ok(PartialSourceRowData { value: Some(SourceValue::NonExistence), ordinal: Some(Ordinal::unavailable()), content_version_fp: None, }); } - let path = self.root_path.join(key.str_value()?.as_ref()); + let path = self.root_path.join(path); let ordinal = if options.include_ordinal { Some(path.metadata()?.modified()?.try_into()?) 
} else { @@ -151,7 +149,7 @@ impl SourceFactoryBase for Factory { )); Ok(make_output_type(TableSchema::new( - TableKind::KTable, + TableKind::KTable(KTableInfo { num_key_parts: 1 }), struct_schema, ))) } diff --git a/src/ops/sources/postgres.rs b/src/ops/sources/postgres.rs index 413eb9197..342998cb4 100644 --- a/src/ops/sources/postgres.rs +++ b/src/ops/sources/postgres.rs @@ -1,6 +1,6 @@ use crate::ops::sdk::*; -use crate::ops::shared::postgres::{bind_key_field, get_db_pool, key_value_fields_iter}; +use crate::ops::shared::postgres::{bind_key_field, get_db_pool}; use crate::settings::DatabaseConnectionSpec; use sqlx::postgres::types::PgInterval; use sqlx::{PgPool, Row}; @@ -345,14 +345,9 @@ impl SourceExecutor for Executor { .primary_key_columns .iter() .enumerate() - .map(|(i, info)| (info.decoder)(&row, i)) - .collect::>>()?; - if parts.iter().any(|v| v.is_null()) { - Err(anyhow::anyhow!( - "Composite primary key contains NULL component" - ))?; - } - let key = KeyValue::from_values(parts.iter())?; + .map(|(i, info)| (info.decoder)(&row, i)?.into_key()) + .collect::>>()?; + let key = FullKeyValue(parts); // Compute ordinal if requested let ordinal = if options.include_ordinal { @@ -385,7 +380,7 @@ impl SourceExecutor for Executor { async fn get_value( &self, - key: &KeyValue, + key: &FullKeyValue, _key_aux_info: &serde_json::Value, options: &SourceExecutorGetOptions, ) -> Result { @@ -420,17 +415,10 @@ impl SourceExecutor for Executor { qb.push(&self.table_name); qb.push("\" WHERE "); - let key_values = key_value_fields_iter( - self.table_schema - .primary_key_columns - .iter() - .map(|i| &i.schema), - key, - )?; - if key_values.len() != self.table_schema.primary_key_columns.len() { + if key.len() != self.table_schema.primary_key_columns.len() { bail!( "Composite key has {} values but table has {} primary key columns", - key_values.len(), + key.len(), self.table_schema.primary_key_columns.len() ); } @@ -439,7 +427,7 @@ impl SourceExecutor for Executor { .table_schema .primary_key_columns .iter() - .zip(key_values.iter()) + .zip(key.iter()) .enumerate() { if i > 0 { @@ -524,56 +512,22 @@ impl SourceFactoryBase for Factory { ) .await?; - // Build fields: first key, then value columns - let mut fields: Vec = Vec::new(); - - if table_schema.primary_key_columns.len() == 1 { - let pk_col = &table_schema.primary_key_columns[0]; - fields.push(FieldSchema::new( - &pk_col.schema.name, - pk_col.schema.value_type.clone(), - )); - } else { - // Composite primary key - put all PK columns into a nested struct `_key` - let key_fields: Vec = table_schema - .primary_key_columns - .iter() - .map(|pk_col| { - FieldSchema::new(&pk_col.schema.name, pk_col.schema.value_type.clone()) - }) - .collect(); - let key_struct_schema = StructSchema { - fields: Arc::new(key_fields), - description: None, - }; - fields.push(FieldSchema::new( - "_key", - make_output_type(key_struct_schema), - )); - } - - for value_col in &table_schema.value_columns { - fields.push(FieldSchema::new( - &value_col.schema.name, - value_col.schema.value_type.clone(), - )); - } - - // Log schema information for debugging - if table_schema.primary_key_columns.len() > 1 { - info!( - "Composite primary key detected: {} columns", - table_schema.primary_key_columns.len() - ); - } - - let struct_schema = StructSchema { - fields: Arc::new(fields), - description: None, - }; Ok(make_output_type(TableSchema::new( - TableKind::KTable, - struct_schema, + TableKind::KTable(KTableInfo { + num_key_parts: table_schema.primary_key_columns.len(), + }), + 
StructSchema { + fields: Arc::new( + (table_schema.primary_key_columns.into_iter().map(|pk_col| { + FieldSchema::new(&pk_col.schema.name, pk_col.schema.value_type) + })) + .chain(table_schema.value_columns.into_iter().map(|value_col| { + FieldSchema::new(&value_col.schema.name, value_col.schema.value_type) + })) + .collect(), + ), + description: None, + }, ))) } @@ -599,11 +553,6 @@ impl SourceFactoryBase for Factory { table_schema, }; - let filter_info = match &spec.included_columns { - Some(cols) => format!(" (filtered to {} specified columns)", cols.len()), - None => " (all columns)".to_string(), - }; - Ok(Box::new(executor)) } } diff --git a/src/ops/targets/kuzu.rs b/src/ops/targets/kuzu.rs index 346c89535..a650accd7 100644 --- a/src/ops/targets/kuzu.rs +++ b/src/ops/targets/kuzu.rs @@ -422,13 +422,12 @@ fn append_value( cypher.query_mut().push('['); let mut prefix = ""; for (k, v) in map.iter() { - let key_value = value::Value::from(k); cypher.query_mut().push_str(prefix); cypher.query_mut().push('{'); append_struct_fields( cypher, &row_schema.fields, - std::iter::once(&key_value).chain(v.fields.iter()), + k.to_values().iter().chain(v.fields.iter()), )?; cypher.query_mut().push('}'); prefix = ", "; @@ -529,7 +528,7 @@ fn append_upsert_node( &data_coll.schema.key_fields, upsert_entry .key - .fields_iter(data_coll.schema.key_fields.len())? + .fields_iter_for_export(data_coll.schema.key_fields.len())? .map(|f| Cow::Owned(value::Value::from(f))), )?; write!(cypher.query_mut(), ")")?; @@ -608,7 +607,7 @@ fn append_upsert_rel( &data_coll.schema.key_fields, upsert_entry .key - .fields_iter(data_coll.schema.key_fields.len())? + .fields_iter_for_export(data_coll.schema.key_fields.len())? .map(|f| Cow::Owned(value::Value::from(f))), )?; write!(cypher.query_mut(), "]->({TGT_NODE_VAR_NAME})")?; @@ -636,7 +635,7 @@ fn append_delete_node( append_key_pattern( cypher, &data_coll.schema.key_fields, - key.fields_iter(data_coll.schema.key_fields.len())? + key.fields_iter_for_export(data_coll.schema.key_fields.len())? .map(|f| Cow::Owned(value::Value::from(f))), )?; writeln!(cypher.query_mut(), ")")?; @@ -674,7 +673,7 @@ fn append_delete_rel( cypher, src_key_schema, src_node_key - .fields_iter(src_key_schema.len())? + .fields_iter_for_export(src_key_schema.len())? .map(|k| Cow::Owned(value::Value::from(k))), )?; @@ -683,7 +682,7 @@ fn append_delete_rel( append_key_pattern( cypher, key_schema, - key.fields_iter(key_schema.len())? + key.fields_iter_for_export(key_schema.len())? .map(|k| Cow::Owned(value::Value::from(k))), )?; @@ -697,7 +696,7 @@ fn append_delete_rel( cypher, tgt_key_schema, tgt_node_key - .fields_iter(tgt_key_schema.len())? + .fields_iter_for_export(tgt_key_schema.len())? .map(|k| Cow::Owned(value::Value::from(k))), )?; write!(cypher.query_mut(), ") DELETE {REL_VAR_NAME}")?; @@ -716,7 +715,7 @@ fn append_maybe_gc_node( append_key_pattern( cypher, &schema.key_fields, - key.fields_iter(schema.key_fields.len())? + key.fields_iter_for_export(schema.key_fields.len())? 
.map(|f| Cow::Owned(value::Value::from(f))), )?; writeln!(cypher.query_mut(), ")")?; @@ -976,11 +975,11 @@ impl TargetFactoryBase for Factory { delete.additional_key ); } - let src_key = KeyValue::from_json( + let src_key = KeyValue::from_json_for_export( additional_keys[0].take(), &rel.source.schema.key_fields, )?; - let tgt_key = KeyValue::from_json( + let tgt_key = KeyValue::from_json_for_export( additional_keys[1].take(), &rel.target.schema.key_fields, )?; diff --git a/src/ops/targets/neo4j.rs b/src/ops/targets/neo4j.rs index f77ea49bf..73db8b446 100644 --- a/src/ops/targets/neo4j.rs +++ b/src/ops/targets/neo4j.rs @@ -267,8 +267,7 @@ fn value_to_bolt(value: &Value, schema: &schema::ValueType) -> Result .iter() .map(|(k, v)| { field_values_to_bolt( - std::iter::once(&Into::::into(k.clone())) - .chain(v.0.fields.iter()), + k.to_values().iter().chain(v.0.fields.iter()), t.row.fields.iter(), ) }) @@ -458,7 +457,7 @@ impl ExportContext { ) -> Result { let mut query = query; for (i, val) in val - .fields_iter(self.analyzed_data_coll.schema.key_fields.len())? + .fields_iter_for_export(self.analyzed_data_coll.schema.key_fields.len())? .enumerate() { query = query.param( diff --git a/src/ops/targets/shared/property_graph.rs b/src/ops/targets/shared/property_graph.rs index cc3b12619..d0079b9a3 100644 --- a/src/ops/targets/shared/property_graph.rs +++ b/src/ops/targets/shared/property_graph.rs @@ -123,7 +123,7 @@ pub struct GraphElementInputFieldsIdx { impl GraphElementInputFieldsIdx { pub fn extract_key(&self, fields: &[value::Value]) -> Result { - value::KeyValue::from_values(self.key.iter().map(|idx| &fields[*idx])) + value::KeyValue::from_values_for_export(self.key.iter().map(|idx| &fields[*idx])) } } diff --git a/src/py/convert.rs b/src/py/convert.rs index 07a946a37..5ed2b4a0d 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -1,3 +1,4 @@ +use crate::base::value::FullKeyValue; use crate::prelude::*; use bytes::Bytes; @@ -108,10 +109,9 @@ pub fn value_to_py_object<'py>(py: Python<'py>, v: &value::Value) -> PyResult = + k.into_iter().map(|k| value::Value::from(k)).collect(); + field_values_to_py_object(py, k.iter().chain(v.0.fields.iter())) }) .collect::>>()?; PyList::new(py, rows)?.into_any() @@ -333,21 +333,32 @@ pub fn value_from_py_object<'py>( value::Value::LTable(values.into_iter().map(|v| v.into()).collect()) } - schema::TableKind::KTable => value::Value::KTable( - values - .into_iter() - .map(|v| { - let mut iter = v.fields.into_iter(); - let key = iter.next().unwrap().into_key().into_py_result()?; - Ok(( - key, - value::ScopeValue(value::FieldValues { + schema::TableKind::KTable(info) => { + let num_key_parts = info.num_key_parts; + value::Value::KTable( + values + .into_iter() + .map(|v| { + let mut iter = v.fields.into_iter(); + if iter.len() < num_key_parts { + anyhow::bail!( + "Invalid KTable value: expect at least {} fields, got {}", + num_key_parts, + iter.len() + ); + } + let keys: Box<[value::KeyValue]> = (0..num_key_parts) + .map(|_| iter.next().unwrap().into_key()) + .collect::>()?; + let values = value::FieldValues { fields: iter.collect::>(), - }), - )) - }) - .collect::>>()?, - ), + }; + Ok((FullKeyValue(keys), values.into())) + }) + .collect::>>() + .into_py_result()?, + ) + } } } } @@ -516,7 +527,7 @@ mod tests { // KTable let ktable_schema = schema::TableSchema { - kind: schema::TableKind::KTable, + kind: schema::TableKind::KTable(schema::KTableInfo { num_key_parts: 1 }), row: (*row_schema_struct).clone(), }; let mut ktable_data = BTreeMap::new(); @@ -547,8 
+558,8 @@ mod tests { .into_key() .unwrap(); - ktable_data.insert(key1, row1_scope_val.clone()); - ktable_data.insert(key2, row2_scope_val.clone()); + ktable_data.insert(FullKeyValue(Box::from([key1])), row1_scope_val.clone()); + ktable_data.insert(FullKeyValue(Box::from([key2])), row2_scope_val.clone()); let ktable_val = value::Value::KTable(ktable_data); let ktable_typ = schema::ValueType::Table(ktable_schema); diff --git a/src/service/flows.rs b/src/service/flows.rs index 03971a0f9..dbb87b16e 100644 --- a/src/service/flows.rs +++ b/src/service/flows.rs @@ -60,8 +60,8 @@ pub struct GetKeysParam { #[derive(Serialize)] pub struct GetKeysResponse { - key_type: schema::EnrichedValueType, - keys: Vec<(value::KeyValue, serde_json::Value)>, + key_schema: Vec, + keys: Vec<(value::FullKeyValue, serde_json::Value)>, } pub async fn get_keys( @@ -82,16 +82,10 @@ pub async fn get_keys( StatusCode::BAD_REQUEST, ) })?; - let key_type = schema.fields[field_idx] - .value_type - .typ - .key_type() - .ok_or_else(|| { - ApiError::new( - &format!("field has no key: {}", query.field), - StatusCode::BAD_REQUEST, - ) - })?; + let pk_schema = schema.fields[field_idx].value_type.typ.key_schema(); + if pk_schema.is_empty() { + api_bail!("field has no key: {}", query.field); + } let execution_plan = flow_ctx.flow.get_execution_plan().await?; let import_op = execution_plan @@ -117,7 +111,7 @@ pub async fn get_keys( keys.extend(rows?.into_iter().map(|row| (row.key, row.key_aux_info))); } Ok(Json(GetKeysResponse { - key_type: key_type.clone(), + key_schema: pk_schema.to_vec(), keys, })) } @@ -139,7 +133,7 @@ struct SourceRowKeyContextHolder<'a> { plan: Arc, import_op_idx: usize, schema: &'a FlowSchema, - key: value::KeyValue, + key: value::FullKeyValue, key_aux_info: serde_json::Value, } @@ -165,10 +159,8 @@ impl<'a> SourceRowKeyContextHolder<'a> { schema::ValueType::Table(table) => table, _ => api_bail!("field is not a table: {}", source_row_key.field), }; - let key_field = table_schema - .key_field() - .ok_or_else(|| api_error!("field {} does not have a key", source_row_key.field))?; - let key = value::KeyValue::from_strs(source_row_key.key, &key_field.value_type.typ)?; + let key_schema = table_schema.key_schema(); + let key = value::FullKeyValue::decode_from_strs(source_row_key.key, key_schema)?; let key_aux_info = source_row_key .key_aux .map(|s| serde_json::from_str(&s)) diff --git a/src/setup/states.rs b/src/setup/states.rs index e49c41402..6c96352c8 100644 --- a/src/setup/states.rs +++ b/src/setup/states.rs @@ -147,7 +147,13 @@ impl StateChange { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct SourceSetupState { pub source_id: i32, - pub key_schema: schema::ValueType, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub keys_schema: Option>, + + /// DEPRECATED. For backward compatibility. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub key_schema: Option, // Allow empty string during deserialization for backward compatibility. #[serde(default)]
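
Taken together, the user-visible effect is the one shown in the `examples/postgres_source/main.py` hunk earlier in this diff: composite key columns are now plain row fields, with no `_key` struct indirection. A condensed sketch of that flow body (the surrounding flow definition, `make_full_description`, and the `indexed_product` collector are as defined in that example):

```python
# Condensed from examples/postgres_source/main.py after this change.
with data_scope["products"].row() as product:
    product["full_description"] = flow_builder.transform(
        make_full_description,
        product["product_category"],  # key part, read like any other field
        product["product_name"],      # key part, read like any other field
        product["description"],
    )
    indexed_product.collect(
        product_category=product["product_category"],
        product_name=product["product_name"],
        description=product["description"],
        price=product["price"],
        amount=product["amount"],
    )
```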