Skip to content

Commit c180bc4

Browse files
committed
feat(native-composite-key): support multiple key parts in KTable
1 parent 4f02279 commit c180bc4

File tree

25 files changed

+496
-309
lines changed

25 files changed

+496
-309
lines changed

examples/postgres_source/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def postgres_product_indexing_flow(
9797
with data_scope["products"].row() as product:
9898
product["full_description"] = flow_builder.transform(
9999
make_full_description,
100-
product["_key"]["product_category"],
101-
product["_key"]["product_name"],
100+
product["product_category"],
101+
product["product_name"],
102102
product["description"],
103103
)
104104
product["total_value"] = flow_builder.transform(
@@ -112,8 +112,8 @@ def postgres_product_indexing_flow(
112112
)
113113
)
114114
indexed_product.collect(
115-
product_category=product["_key"]["product_category"],
116-
product_name=product["_key"]["product_name"],
115+
product_category=product["product_category"],
116+
product_name=product["product_name"],
117117
description=product["description"],
118118
price=product["price"],
119119
amount=product["amount"],

python/cocoindex/convert.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515

1616
from .typing import (
17-
KEY_FIELD_NAME,
1817
TABLE_TYPES,
1918
AnalyzedAnyType,
2019
AnalyzedBasicType,
@@ -96,14 +95,24 @@ def encode_struct_list(value: Any) -> Any:
9695
f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. "
9796
f"If you want a free-formed dict, use `cocoindex.Json` instead."
9897
)
98+
value_encoder = make_engine_value_encoder(value_type_info)
9999

100-
key_encoder = make_engine_value_encoder(analyze_type_info(variant.key_type))
101-
value_encoder = make_engine_value_encoder(analyze_type_info(variant.value_type))
100+
key_type_info = analyze_type_info(variant.key_type)
101+
key_encoder = make_engine_value_encoder(key_type_info)
102+
if isinstance(key_type_info.variant, AnalyzedBasicType):
103+
104+
def encode_row(k: Any, v: Any) -> Any:
105+
return [key_encoder(k)] + value_encoder(v)
106+
107+
else:
108+
109+
def encode_row(k: Any, v: Any) -> Any:
110+
return key_encoder(k) + value_encoder(v)
102111

103112
def encode_struct_dict(value: Any) -> Any:
104113
if not value:
105114
return []
106-
return [[key_encoder(k)] + value_encoder(v) for k, v in value.items()]
115+
return [encode_row(k, v) for k, v in value.items()]
107116

108117
return encode_struct_dict
109118

@@ -234,25 +243,47 @@ def decode(value: Any) -> Any | None:
234243
f"declared `{dst_type_info.core_type}`, a dict type expected"
235244
)
236245

237-
key_field_schema = engine_fields_schema[0]
238-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
239-
key_decoder = make_engine_value_decoder(
240-
field_path,
241-
key_field_schema["type"],
242-
analyze_type_info(key_type),
243-
for_key=True,
244-
)
245-
field_path.pop()
246+
num_key_parts = src_type.get("num_key_parts", 1)
247+
key_type_info = analyze_type_info(key_type)
248+
key_decoder: Callable[..., Any] | None = None
249+
if (
250+
isinstance(
251+
key_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType)
252+
)
253+
and num_key_parts == 1
254+
):
255+
single_key_decoder = make_engine_value_decoder(
256+
field_path,
257+
engine_fields_schema[0]["type"],
258+
key_type_info,
259+
for_key=True,
260+
)
261+
262+
def key_decoder(value: list[Any]) -> Any:
263+
return single_key_decoder(value[0])
264+
265+
else:
266+
key_decoder = make_engine_struct_decoder(
267+
field_path,
268+
engine_fields_schema[0:num_key_parts],
269+
key_type_info,
270+
for_key=True,
271+
)
246272
value_decoder = make_engine_struct_decoder(
247273
field_path,
248-
engine_fields_schema[1:],
274+
engine_fields_schema[num_key_parts:],
249275
analyze_type_info(value_type),
250276
)
251277

252278
def decode(value: Any) -> Any | None:
253279
if value is None:
254280
return None
255-
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
281+
return {
282+
key_decoder(v[0:num_key_parts]): value_decoder(
283+
v[num_key_parts:]
284+
)
285+
for v in value
286+
}
256287

257288
return decode
258289

python/cocoindex/typing.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -330,35 +330,50 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
330330

331331
def _encode_struct_schema(
332332
struct_type: type, key_type: type | None = None
333-
) -> dict[str, Any]:
333+
) -> tuple[dict[str, Any], int | None]:
334334
fields = []
335335

336-
def add_field(name: str, t: Any) -> None:
336+
def add_field(name: str, analyzed_type: AnalyzedTypeInfo) -> None:
337337
try:
338-
type_info = encode_enriched_type_info(analyze_type_info(t))
338+
type_info = encode_enriched_type_info(analyzed_type)
339339
except ValueError as e:
340340
e.add_note(
341341
f"Failed to encode annotation for field - "
342-
f"{struct_type.__name__}.{name}: {t}"
342+
f"{struct_type.__name__}.{name}: {analyzed_type.core_type}"
343343
)
344344
raise
345345
type_info["name"] = name
346346
fields.append(type_info)
347347

348+
def add_fields_from_struct(struct_type: type) -> None:
349+
if dataclasses.is_dataclass(struct_type):
350+
for field in dataclasses.fields(struct_type):
351+
add_field(field.name, analyze_type_info(field.type))
352+
elif is_namedtuple_type(struct_type):
353+
for name, field_type in struct_type.__annotations__.items():
354+
add_field(name, analyze_type_info(field_type))
355+
else:
356+
raise ValueError(f"Unsupported struct type: {struct_type}")
357+
358+
result: dict[str, Any] = {}
359+
num_key_parts = None
348360
if key_type is not None:
349-
add_field(KEY_FIELD_NAME, key_type)
361+
key_type_info = analyze_type_info(key_type)
362+
if isinstance(key_type_info.variant, AnalyzedBasicType):
363+
add_field(KEY_FIELD_NAME, key_type_info)
364+
num_key_parts = 1
365+
elif isinstance(key_type_info.variant, AnalyzedStructType):
366+
add_fields_from_struct(key_type_info.variant.struct_type)
367+
num_key_parts = len(fields)
368+
else:
369+
raise ValueError(f"Unsupported key type: {key_type}")
350370

351-
if dataclasses.is_dataclass(struct_type):
352-
for field in dataclasses.fields(struct_type):
353-
add_field(field.name, field.type)
354-
elif is_namedtuple_type(struct_type):
355-
for name, field_type in struct_type.__annotations__.items():
356-
add_field(name, field_type)
371+
add_fields_from_struct(struct_type)
357372

358-
result: dict[str, Any] = {"fields": fields}
373+
result["fields"] = fields
359374
if doc := inspect.getdoc(struct_type):
360375
result["description"] = doc
361-
return result
376+
return result, num_key_parts
362377

363378

364379
def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
@@ -374,7 +389,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
374389
return {"kind": variant.kind}
375390

376391
if isinstance(variant, AnalyzedStructType):
377-
encoded_type = _encode_struct_schema(variant.struct_type)
392+
encoded_type, _ = _encode_struct_schema(variant.struct_type)
378393
encoded_type["kind"] = "Struct"
379394
return encoded_type
380395

@@ -384,10 +399,8 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
384399
if isinstance(elem_type_info.variant, AnalyzedStructType):
385400
if variant.vector_info is not None:
386401
raise ValueError("LTable type must not have a vector info")
387-
return {
388-
"kind": "LTable",
389-
"row": _encode_struct_schema(elem_type_info.variant.struct_type),
390-
}
402+
row_type, _ = _encode_struct_schema(elem_type_info.variant.struct_type)
403+
return {"kind": "LTable", "row": row_type}
391404
else:
392405
vector_info = variant.vector_info
393406
return {
@@ -402,12 +415,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
402415
raise ValueError(
403416
f"KTable value must have a Struct type, got {value_type_info.core_type}"
404417
)
418+
row_type, num_key_parts = _encode_struct_schema(
419+
value_type_info.variant.struct_type,
420+
variant.key_type,
421+
)
405422
return {
406423
"kind": "KTable",
407-
"row": _encode_struct_schema(
408-
value_type_info.variant.struct_type,
409-
variant.key_type,
410-
),
424+
"row": row_type,
425+
"num_key_parts": num_key_parts,
411426
}
412427

413428
if isinstance(variant, AnalyzedUnionType):

src/base/schema.rs

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -136,52 +136,63 @@ impl std::fmt::Display for StructSchema {
136136
}
137137

138138
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
139+
#[serde(tag = "kind")]
139140
#[allow(clippy::enum_variant_names)]
140141
pub enum TableKind {
141142
/// An table with unordered rows, without key.
142143
UTable,
143-
/// A table's first field is the key.
144+
/// A table's first field is the key. The value is number of fields serving as the key
144145
#[serde(alias = "Table")]
145-
KTable,
146+
KTable {
147+
// Omit the field if num_key_parts is 1 for backward compatibility.
148+
#[serde(default = "default_num_key_parts", skip_serializing_if = "is_one")]
149+
num_key_parts: usize,
150+
},
151+
146152
/// A table whose rows orders are preserved.
147153
#[serde(alias = "List")]
148154
LTable,
149155
}
150156

157+
fn default_num_key_parts() -> usize {
158+
1
159+
}
160+
161+
fn is_one(value: &usize) -> bool {
162+
*value == 1
163+
}
164+
151165
impl std::fmt::Display for TableKind {
152166
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153167
match self {
154168
TableKind::UTable => write!(f, "Table"),
155-
TableKind::KTable => write!(f, "KTable"),
169+
TableKind::KTable { num_key_parts } => write!(f, "KTable({num_key_parts})"),
156170
TableKind::LTable => write!(f, "LTable"),
157171
}
158172
}
159173
}
160174

161175
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
162176
pub struct TableSchema {
177+
#[serde(flatten)]
163178
pub kind: TableKind,
179+
164180
pub row: StructSchema,
165181
}
166182

183+
impl std::fmt::Display for TableSchema {
184+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185+
write!(f, "{}({})", self.kind, self.row)
186+
}
187+
}
188+
167189
impl TableSchema {
168-
pub fn has_key(&self) -> bool {
169-
match self.kind {
170-
TableKind::KTable => true,
171-
TableKind::UTable | TableKind::LTable => false,
172-
}
190+
pub fn new(kind: TableKind, row: StructSchema) -> Self {
191+
Self { kind, row }
173192
}
174193

175-
pub fn key_type(&self) -> Option<&EnrichedValueType> {
176-
match self.kind {
177-
TableKind::KTable => self
178-
.row
179-
.fields
180-
.first()
181-
.as_ref()
182-
.map(|field| &field.value_type),
183-
TableKind::UTable | TableKind::LTable => None,
184-
}
194+
pub fn has_key(&self) -> bool {
195+
!self.key_schema().is_empty()
185196
}
186197

187198
pub fn without_attrs(&self) -> Self {
@@ -190,23 +201,11 @@ impl TableSchema {
190201
row: self.row.without_attrs(),
191202
}
192203
}
193-
}
194-
195-
impl std::fmt::Display for TableSchema {
196-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197-
write!(f, "{}({})", self.kind, self.row)
198-
}
199-
}
200-
201-
impl TableSchema {
202-
pub fn new(kind: TableKind, row: StructSchema) -> Self {
203-
Self { kind, row }
204-
}
205204

206-
pub fn key_field(&self) -> Option<&FieldSchema> {
205+
pub fn key_schema(&self) -> &[FieldSchema] {
207206
match self.kind {
208-
TableKind::KTable => Some(self.row.fields.first().unwrap()),
209-
TableKind::UTable | TableKind::LTable => None,
207+
TableKind::KTable { num_key_parts: n } => &self.row.fields[..n],
208+
TableKind::UTable | TableKind::LTable => &[],
210209
}
211210
}
212211
}
@@ -224,11 +223,11 @@ pub enum ValueType {
224223
}
225224

226225
impl ValueType {
227-
pub fn key_type(&self) -> Option<&EnrichedValueType> {
226+
pub fn key_schema(&self) -> &[FieldSchema] {
228227
match self {
229-
ValueType::Basic(_) => None,
230-
ValueType::Struct(_) => None,
231-
ValueType::Table(c) => c.key_type(),
228+
ValueType::Basic(_) => &[],
229+
ValueType::Struct(_) => &[],
230+
ValueType::Table(c) => c.key_schema(),
232231
}
233232
}
234233

0 commit comments

Comments
 (0)