Skip to content

Commit 7eeeaf5

Browse files
authored
feat(kuzu): support kuzu as storage target (#561)
* feat(kuzu): support setup * style: rearrange `kuzu.rs` into sections * refactor(kuzu): expose single `register()` method to seal internal types * refactor: use standalone methods to build setup query * feat(kuzu): support kuzu as target
1 parent b246372 commit 7eeeaf5

File tree

15 files changed

+1276
-69
lines changed

15 files changed

+1276
-69
lines changed

docs/docs/core/settings.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description: Provide settings for CocoIndex, e.g. database connection, app names
66
import Tabs from '@theme/Tabs';
77
import TabItem from '@theme/TabItem';
88

9-
# CocoIndex Settings
9+
# CocoIndex Setting
1010

1111
Certain settings need to be provided for CocoIndex to work, e.g. database connections, app namespace, etc.
1212

examples/docs_to_knowledge_graph/main.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,41 @@ class Relationship:
2626
object: str
2727

2828

29+
neo4j_conn_spec = cocoindex.add_auth_entry(
30+
"Neo4jConnection",
31+
cocoindex.storages.Neo4jConnection(
32+
uri="bolt://localhost:7687",
33+
user="neo4j",
34+
password="cocoindex",
35+
),
36+
)
37+
kuzu_conn_spec = cocoindex.add_auth_entry(
38+
"KuzuConnection",
39+
cocoindex.storages.KuzuConnection(
40+
api_server_url="http://localhost:8123",
41+
),
42+
)
43+
44+
# Use Neo4j as the graph database
45+
GraphDbSpec = cocoindex.storages.Neo4j
46+
GraphDbConnection = cocoindex.storages.Neo4jConnection
47+
GraphDbDeclaration = cocoindex.storages.Neo4jDeclaration
48+
conn_spec = neo4j_conn_spec
49+
50+
# Use Kuzu as the graph database
51+
# GraphDbSpec = cocoindex.storages.Kuzu
52+
# GraphDbConnection = cocoindex.storages.KuzuConnection
53+
# GraphDbDeclaration = cocoindex.storages.KuzuDeclaration
54+
# conn_spec = kuzu_conn_spec
55+
56+
2957
@cocoindex.flow_def(name="DocsToKG")
3058
def docs_to_kg_flow(
3159
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
32-
):
60+
) -> None:
3361
"""
3462
Define an example flow that extracts relationship from files and build knowledge graph.
3563
"""
36-
# configure neo4j connection
37-
conn_spec = cocoindex.add_auth_entry(
38-
"Neo4jConnection",
39-
cocoindex.storages.Neo4jConnection(
40-
uri="bolt://localhost:7687",
41-
user="neo4j",
42-
password="cocoindex",
43-
),
44-
)
45-
4664
data_scope["documents"] = flow_builder.add_source(
4765
cocoindex.sources.LocalFile(
4866
path="../../docs/docs/core", included_patterns=["*.md", "*.mdx"]
@@ -112,22 +130,22 @@ def docs_to_kg_flow(
112130
# export to neo4j
113131
document_node.export(
114132
"document_node",
115-
cocoindex.storages.Neo4j(
133+
GraphDbSpec(
116134
connection=conn_spec, mapping=cocoindex.storages.Nodes(label="Document")
117135
),
118136
primary_key_fields=["filename"],
119137
)
120138
# Declare reference Node to reference entity node in a relationship
121139
flow_builder.declare(
122-
cocoindex.storages.Neo4jDeclaration(
140+
GraphDbDeclaration(
123141
connection=conn_spec,
124142
nodes_label="Entity",
125143
primary_key_fields=["value"],
126144
)
127145
)
128146
entity_relationship.export(
129147
"entity_relationship",
130-
cocoindex.storages.Neo4j(
148+
GraphDbSpec(
131149
connection=conn_spec,
132150
mapping=cocoindex.storages.Relationships(
133151
rel_type="RELATIONSHIP",
@@ -153,7 +171,7 @@ def docs_to_kg_flow(
153171
)
154172
entity_mention.export(
155173
"entity_mention",
156-
cocoindex.storages.Neo4j(
174+
GraphDbSpec(
157175
connection=conn_spec,
158176
mapping=cocoindex.storages.Relationships(
159177
rel_type="MENTION",

examples/product_recommendation/main.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,38 @@ def extract_product_info(product: cocoindex.Json, filename: str) -> ProductInfo:
7777
)
7878

7979

80+
neo4j_conn_spec = cocoindex.add_auth_entry(
81+
"Neo4jConnection",
82+
cocoindex.storages.Neo4jConnection(
83+
uri="bolt://localhost:7687",
84+
user="neo4j",
85+
password="cocoindex",
86+
),
87+
)
88+
kuzu_conn_spec = cocoindex.add_auth_entry(
89+
"KuzuConnection",
90+
cocoindex.storages.KuzuConnection(
91+
api_server_url="http://localhost:8123",
92+
),
93+
)
94+
95+
# Use Neo4j as the graph database
96+
GraphDbSpec = cocoindex.storages.Neo4j
97+
GraphDbConnection = cocoindex.storages.Neo4jConnection
98+
GraphDbDeclaration = cocoindex.storages.Neo4jDeclaration
99+
conn_spec = neo4j_conn_spec
100+
101+
# Use Kuzu as the graph database
102+
# GraphDbSpec = cocoindex.storages.Kuzu
103+
# GraphDbConnection = cocoindex.storages.KuzuConnection
104+
# GraphDbDeclaration = cocoindex.storages.KuzuDeclaration
105+
# conn_spec = kuzu_conn_spec
106+
107+
80108
@cocoindex.flow_def(name="StoreProduct")
81109
def store_product_flow(
82110
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
83-
):
111+
) -> None:
84112
"""
85113
Define an example flow that extracts triples from files and build knowledge graph.
86114
"""
@@ -122,25 +150,16 @@ def store_product_flow(
122150
taxonomy=t["name"],
123151
)
124152

125-
conn_spec = cocoindex.add_auth_entry(
126-
"Neo4jConnection",
127-
cocoindex.storages.Neo4jConnection(
128-
uri="bolt://localhost:7687",
129-
user="neo4j",
130-
password="cocoindex",
131-
),
132-
)
133-
134153
product_node.export(
135154
"product_node",
136-
cocoindex.storages.Neo4j(
155+
GraphDbSpec(
137156
connection=conn_spec, mapping=cocoindex.storages.Nodes(label="Product")
138157
),
139158
primary_key_fields=["id"],
140159
)
141160

142161
flow_builder.declare(
143-
cocoindex.storages.Neo4jDeclaration(
162+
GraphDbDeclaration(
144163
connection=conn_spec,
145164
nodes_label="Taxonomy",
146165
primary_key_fields=["value"],
@@ -149,7 +168,7 @@ def store_product_flow(
149168

150169
product_taxonomy.export(
151170
"product_taxonomy",
152-
cocoindex.storages.Neo4j(
171+
GraphDbSpec(
153172
connection=conn_spec,
154173
mapping=cocoindex.storages.Relationships(
155174
rel_type="PRODUCT_TAXONOMY",
@@ -175,7 +194,7 @@ def store_product_flow(
175194
)
176195
product_complementary_taxonomy.export(
177196
"product_complementary_taxonomy",
178-
cocoindex.storages.Neo4j(
197+
GraphDbSpec(
179198
connection=conn_spec,
180199
mapping=cocoindex.storages.Relationships(
181200
rel_type="PRODUCT_COMPLEMENTARY_TAXONOMY",

python/cocoindex/storages.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,6 @@ class Qdrant(op.StorageSpec):
2525
api_key: str | None = None
2626

2727

28-
@dataclass
29-
class Neo4jConnection:
30-
"""Connection spec for Neo4j."""
31-
32-
uri: str
33-
user: str
34-
password: str
35-
db: str | None = None
36-
37-
3828
@dataclass
3929
class TargetFieldMapping:
4030
"""Mapping for a graph element (node or relationship) field."""
@@ -88,6 +78,16 @@ class Relationships:
8878
NodeReferenceMapping = NodeFromFields
8979

9080

81+
@dataclass
82+
class Neo4jConnection:
83+
"""Connection spec for Neo4j."""
84+
85+
uri: str
86+
user: str
87+
password: str
88+
db: str | None = None
89+
90+
9191
class Neo4j(op.StorageSpec):
9292
"""Graph storage powered by Neo4j."""
9393

@@ -103,3 +103,26 @@ class Neo4jDeclaration(op.DeclarationSpec):
103103
nodes_label: str
104104
primary_key_fields: Sequence[str]
105105
vector_indexes: Sequence[index.VectorIndexDef] = ()
106+
107+
108+
@dataclass
109+
class KuzuConnection:
110+
"""Connection spec for Kuzu."""
111+
112+
api_server_url: str
113+
114+
115+
class Kuzu(op.StorageSpec):
116+
"""Graph storage powered by Kuzu."""
117+
118+
connection: AuthEntryReference[KuzuConnection]
119+
mapping: Nodes | Relationships
120+
121+
122+
class KuzuDeclaration(op.DeclarationSpec):
123+
"""Declarations for Kuzu."""
124+
125+
kind = "Kuzu"
126+
connection: AuthEntryReference[KuzuConnection]
127+
nodes_label: str
128+
primary_key_fields: Sequence[str]

src/base/schema.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl std::fmt::Display for BasicValueType {
8686
}
8787
}
8888

89-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
89+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
9090
pub struct StructSchema {
9191
pub fields: Arc<Vec<FieldSchema>>,
9292

@@ -138,7 +138,7 @@ impl std::fmt::Display for TableKind {
138138
}
139139
}
140140

141-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
141+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
142142
pub struct TableSchema {
143143
pub kind: TableKind,
144144
pub row: StructSchema,
@@ -191,7 +191,7 @@ impl TableSchema {
191191
}
192192
}
193193

194-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
194+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
195195
#[serde(tag = "kind")]
196196
pub enum ValueType {
197197
Struct(StructSchema),
@@ -222,7 +222,7 @@ impl ValueType {
222222
}
223223
}
224224

225-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
225+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
226226
pub struct EnrichedValueType<DataType = ValueType> {
227227
#[serde(rename = "type")]
228228
pub typ: DataType,
@@ -295,7 +295,7 @@ impl std::fmt::Display for ValueType {
295295
}
296296
}
297297

298-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
298+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
299299
pub struct FieldSchema<DataType = ValueType> {
300300
/// ID is used to identify the field in the schema.
301301
pub name: FieldName,
@@ -342,7 +342,7 @@ impl std::fmt::Display for FieldSchema {
342342
pub struct CollectorSchema {
343343
pub fields: Vec<FieldSchema>,
344344
/// If specified, the collector will have an automatically generated UUID field with the given index.
345-
pub auto_uuid_field_idx: Option<u32>,
345+
pub auto_uuid_field_idx: Option<usize>,
346346
}
347347

348348
impl std::fmt::Display for CollectorSchema {

src/base/value.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::schema::*;
22
use crate::base::duration::parse_duration;
3+
use crate::prelude::invariance_violation;
34
use crate::{api_bail, api_error};
45
use anyhow::Result;
56
use base64::prelude::*;
@@ -175,6 +176,26 @@ impl std::fmt::Display for KeyValue {
175176
}
176177

177178
impl KeyValue {
179+
pub fn from_json(value: serde_json::Value, fields_schema: &[FieldSchema]) -> Result<Self> {
180+
let value = if fields_schema.len() == 1 {
181+
Value::from_json(value, &fields_schema[0].value_type.typ)?
182+
} else {
183+
let field_values: FieldValues = FieldValues::from_json(value, fields_schema)?;
184+
Value::Struct(field_values)
185+
};
186+
Ok(value.as_key()?)
187+
}
188+
189+
pub fn from_values<'a>(values: impl ExactSizeIterator<Item = &'a Value>) -> Result<Self> {
190+
let key = if values.len() == 1 {
191+
let mut values = values;
192+
values.next().ok_or_else(invariance_violation)?.as_key()?
193+
} else {
194+
KeyValue::Struct(values.map(|v| v.as_key()).collect::<Result<Vec<_>>>()?)
195+
};
196+
Ok(key)
197+
}
198+
178199
pub fn fields_iter(&self, num_fields: usize) -> Result<impl Iterator<Item = &KeyValue>> {
179200
let slice = if num_fields == 1 {
180201
std::slice::from_ref(self)

src/builder/analyzer.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,6 @@ impl AnalyzerContext<'_> {
987987
.fields
988988
.iter()
989989
.position(|field| &field.name == f)
990-
.map(|idx| idx as u32)
991990
.ok_or_else(|| anyhow!("field not found: {}", f))
992991
})
993992
.collect::<Result<Vec<_>>>()?;
@@ -1007,7 +1006,7 @@ impl AnalyzerContext<'_> {
10071006
let mut value_fields_schema: Vec<FieldSchema> = vec![];
10081007
let mut value_fields_idx = vec![];
10091008
for (idx, field) in collector_schema.fields.iter().enumerate() {
1010-
if !pk_fields_idx.contains(&(idx as u32)) {
1009+
if !pk_fields_idx.contains(&idx) {
10111010
value_fields_schema.push(field.clone());
10121011
value_fields_idx.push(idx as u32);
10131012
}

src/builder/plan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub struct AnalyzedCollectOp {
9494
}
9595

9696
pub enum AnalyzedPrimaryKeyDef {
97-
Fields(Vec<u32>),
97+
Fields(Vec<usize>),
9898
}
9999

100100
pub struct AnalyzedExportOp {

src/execution/row_indexer.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,11 @@ pub fn extract_primary_key(
2424
primary_key_def: &AnalyzedPrimaryKeyDef,
2525
record: &FieldValues,
2626
) -> Result<KeyValue> {
27-
let key = match primary_key_def {
27+
match primary_key_def {
2828
AnalyzedPrimaryKeyDef::Fields(fields) => {
29-
if fields.len() == 1 {
30-
record.fields[fields[0] as usize].as_key()?
31-
} else {
32-
let mut key_values = Vec::with_capacity(fields.len());
33-
for field in fields.iter() {
34-
key_values.push(record.fields[*field as usize].as_key()?);
35-
}
36-
KeyValue::Struct(key_values)
37-
}
29+
KeyValue::from_values(fields.iter().map(|field| &record.fields[*field as usize]))
3830
}
39-
};
40-
Ok(key)
31+
}
4132
}
4233

4334
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]

src/ops/registration.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use anyhow::Result;
66
use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
77

88
fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
9+
let reqwest_client = reqwest::Client::new();
10+
911
sources::local_file::Factory.register(registry)?;
1012
sources::google_drive::Factory.register(registry)?;
1113
sources::amazon_s3::Factory.register(registry)?;
@@ -16,6 +18,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
1618

1719
storages::postgres::Factory::default().register(registry)?;
1820
Arc::new(storages::qdrant::Factory::default()).register(registry)?;
21+
storages::kuzu::register(registry, reqwest_client)?;
1922

2023
storages::neo4j::Factory::new().register(registry)?;
2124

0 commit comments

Comments
 (0)