Skip to content

Commit e4c3229

Browse files
committed
feat(kuzu): add HNSW vector index support
Implements HNSW vector index support for Kuzu following the same pattern as the Postgres implementation in PR cocoindex-io#1050.

Changes:
- Remove blanket "Vector indexes are not supported for Kuzu yet" error
- Add validation to accept HNSW and reject IVFFlat with clear error message
- Implement CREATE_VECTOR_INDEX and DROP_VECTOR_INDEX Cypher generation
- Map cocoindex HNSW parameters to Kuzu format (m→mu/ml, ef_construction→efc)
- Add vector index lifecycle management (create, update, delete)
- Install Kuzu vector extension automatically when needed
- Support all similarity metrics (cosine, l2, dotproduct)

Technical details:
- Add VectorIndexState struct to track index configuration
- Update SetupState and GraphElementDataSetupChange for index tracking
- Implement diff_setup_states logic for index change computation
- Add vector index compatibility checking in check_state_compatibility
- Integrate vector index operations in apply_setup_changes

Fixes cocoindex-io#1055
Related to cocoindex-io#1051
Follows pattern from cocoindex-io#1050
1 parent c25d179 commit e4c3229

File tree

1 file changed

+187
-2
lines changed

1 file changed

+187
-2
lines changed

src/ops/targets/kuzu.rs

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ struct SetupState {
165165

166166
#[serde(default, skip_serializing_if = "Option::is_none")]
167167
referenced_node_tables: Option<(ReferencedNodeTable, ReferencedNodeTable)>,
168+
169+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
170+
vector_indexes: Vec<VectorIndexState>,
171+
}
172+
173+
/// Configuration of one vector index as recorded in `SetupState::vector_indexes`.
///
/// Values are copied from the collection's `index_options.vector_indexes` when the
/// setup state is built, and are compared against a previously recorded state to
/// decide whether an index has to be created, dropped, or rebuilt.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
174+
struct VectorIndexState {
175+
/// Name of the field the index is built on; also used to look up the matching
/// existing index when diffing setup states.
field_name: String,
176+
/// Similarity metric of the index. A different metric for the same field causes
/// the index to be dropped and created again.
metric: spec::VectorSimilarityMetric,
177+
/// Requested index method, if any. With `None`, no HNSW tuning options are
/// written into the generated `CREATE_VECTOR_INDEX` call.
method: Option<spec::VectorIndexMethod>,
168178
}
169179

170180
impl<'a> From<&'a SetupState> for Cow<'a, TableColumnsSchema<String>> {
@@ -178,6 +188,8 @@ struct GraphElementDataSetupChange {
178188
actions: TableMainSetupAction<String>,
179189
referenced_node_tables: Option<(String, String)>,
180190
drop_affected_referenced_node_tables: IndexSet<String>,
191+
vector_indexes_to_create: Vec<spec::VectorIndexDef>,
192+
vector_indexes_to_drop: Vec<String>, // field names
181193
}
182194

183195
impl setup::ResourceSetupChange for GraphElementDataSetupChange {
@@ -190,6 +202,81 @@ impl setup::ResourceSetupChange for GraphElementDataSetupChange {
190202
}
191203
}
192204

205+
////////////////////////////////////////////////////////////
206+
// Vector Index Support Functions
207+
////////////////////////////////////////////////////////////
208+
209+
fn validate_vector_index_method(method: &Option<spec::VectorIndexMethod>) -> Result<()> {
210+
if let Some(method) = method {
211+
match method {
212+
spec::VectorIndexMethod::IvfFlat { .. } => {
213+
api_bail!(
214+
"IVFFlat vector index method is not supported by Kuzu. Only HNSW is supported."
215+
)
216+
}
217+
spec::VectorIndexMethod::Hnsw { .. } => Ok(()),
218+
}
219+
} else {
220+
Ok(())
221+
}
222+
}
223+
224+
fn append_create_vector_index(
225+
cypher: &mut CypherBuilder,
226+
table_name: &str,
227+
index_def: &spec::VectorIndexDef,
228+
) -> Result<()> {
229+
let index_name = format!("{}_{}_vector_idx", table_name, index_def.field_name);
230+
231+
write!(
232+
cypher.query_mut(),
233+
"CALL CREATE_VECTOR_INDEX('{}', '{}', '{}'",
234+
table_name, index_name, index_def.field_name
235+
)?;
236+
237+
let mut params = Vec::new();
238+
239+
// Map parameters from cocoindex to Kuzu
240+
if let Some(spec::VectorIndexMethod::Hnsw { m, ef_construction }) = &index_def.method {
241+
if let Some(m_val) = m {
242+
params.push(format!("mu := {}", m_val));
243+
params.push(format!("ml := {}", m_val * 2));
244+
}
245+
if let Some(ef_val) = ef_construction {
246+
params.push(format!("efc := {}", ef_val));
247+
}
248+
}
249+
250+
// Map metric
251+
let metric = match index_def.metric {
252+
spec::VectorSimilarityMetric::CosineSimilarity => "cosine",
253+
spec::VectorSimilarityMetric::L2Distance => "l2",
254+
spec::VectorSimilarityMetric::InnerProduct => "dotproduct",
255+
};
256+
params.push(format!("metric := '{}'", metric));
257+
258+
if !params.is_empty() {
259+
write!(cypher.query_mut(), ", {}", params.join(", "))?;
260+
}
261+
262+
writeln!(cypher.query_mut(), ");")?;
263+
Ok(())
264+
}
265+
266+
/// Appends a `CALL DROP_VECTOR_INDEX(...)` statement to `cypher` for the vector
/// index on `field_name` of `table_name`.
///
/// The index name is derived as `{table_name}_{field_name}_vector_idx`, the same
/// scheme used when the index is created.
fn append_drop_vector_index(
    cypher: &mut CypherBuilder,
    table_name: &str,
    field_name: &str,
) -> Result<()> {
    // `{0}` is the table name, reused as the prefix of the derived index name.
    writeln!(
        cypher.query_mut(),
        "CALL DROP_VECTOR_INDEX('{0}', '{0}_{1}_vector_idx');",
        table_name, field_name
    )?;
    Ok(())
}
279+
193280
fn append_drop_table(
194281
cypher: &mut CypherBuilder,
195282
setup_change: &GraphElementDataSetupChange,
@@ -772,8 +859,9 @@ impl TargetFactoryBase for Factory {
772859
let data_coll_outputs: Vec<TypedExportDataCollectionBuildOutput<Self>> =
773860
std::iter::zip(data_collections, analyzed_data_colls.into_iter())
774861
.map(|(data_coll, analyzed)| {
775-
if !data_coll.index_options.vector_indexes.is_empty() {
776-
api_bail!("Vector indexes are not supported for Kuzu yet");
862+
// Validate vector index methods
863+
for vector_index in &data_coll.index_options.vector_indexes {
864+
validate_vector_index_method(&vector_index.method)?;
777865
}
778866
fn to_dep_table(
779867
field_mapping: &AnalyzedGraphElementFieldMapping,
@@ -797,6 +885,14 @@ impl TargetFactoryBase for Factory {
797885
anyhow::Ok((to_dep_table(&rel.source)?, to_dep_table(&rel.target)?))
798886
})
799887
.transpose()?,
888+
vector_indexes: data_coll.index_options.vector_indexes
889+
.iter()
890+
.map(|vi| VectorIndexState {
891+
field_name: vi.field_name.clone(),
892+
metric: vi.metric,
893+
method: vi.method.clone(),
894+
})
895+
.collect(),
800896
};
801897

802898
let export_context = ExportContext {
@@ -824,6 +920,7 @@ impl TargetFactoryBase for Factory {
824920
value_columns: to_kuzu_cols(&graph_elem_schema.value_fields)?,
825921
},
826922
referenced_node_tables: None,
923+
vector_indexes: Vec::new(),
827924
};
828925
let setup_key = GraphElementType {
829926
connection: decl.connection,
@@ -847,8 +944,10 @@ impl TargetFactoryBase for Factory {
847944
.possible_versions()
848945
.any(|v| v.referenced_node_tables != desired.referenced_node_tables)
849946
});
947+
850948
let actions =
851949
TableMainSetupAction::from_states(desired.as_ref(), &existing, existing_invalidated);
950+
852951
let drop_affected_referenced_node_tables = if actions.drop_existing {
853952
existing
854953
.possible_versions()
@@ -858,12 +957,69 @@ impl TargetFactoryBase for Factory {
858957
} else {
859958
IndexSet::new()
860959
};
960+
961+
// Compute vector index changes
962+
let (vector_indexes_to_create, vector_indexes_to_drop) = match &desired {
963+
Some(desired_state) => {
964+
let existing_indexes: Vec<&VectorIndexState> = existing
965+
.possible_versions()
966+
.flat_map(|v| &v.vector_indexes)
967+
.collect();
968+
969+
let existing_index_map: std::collections::HashMap<&str, &VectorIndexState> =
970+
existing_indexes.iter()
971+
.map(|vi| (vi.field_name.as_str(), *vi))
972+
.collect();
973+
974+
let mut to_create = Vec::new();
975+
let mut to_drop = Vec::new();
976+
977+
for desired_vi in &desired_state.vector_indexes {
978+
if let Some(existing_vi) = existing_index_map.get(desired_vi.field_name.as_str()) {
979+
if existing_vi.metric != desired_vi.metric || existing_vi.method != desired_vi.method {
980+
to_drop.push(desired_vi.field_name.clone());
981+
} else {
982+
continue;
983+
}
984+
}
985+
to_create.push(spec::VectorIndexDef {
986+
field_name: desired_vi.field_name.clone(),
987+
metric: desired_vi.metric,
988+
method: desired_vi.method.clone(),
989+
});
990+
}
991+
992+
let desired_fields: std::collections::HashSet<&str> =
993+
desired_state.vector_indexes.iter()
994+
.map(|vi| vi.field_name.as_str())
995+
.collect();
996+
997+
for existing_vi in &existing_indexes {
998+
if !desired_fields.contains(existing_vi.field_name.as_str()) {
999+
to_drop.push(existing_vi.field_name.clone());
1000+
}
1001+
}
1002+
1003+
(to_create, to_drop)
1004+
}
1005+
None => {
1006+
let to_drop = existing
1007+
.possible_versions()
1008+
.flat_map(|v| &v.vector_indexes)
1009+
.map(|vi| vi.field_name.clone())
1010+
.collect();
1011+
(Vec::new(), to_drop)
1012+
}
1013+
};
1014+
8611015
Ok(GraphElementDataSetupChange {
8621016
actions,
8631017
referenced_node_tables: desired
8641018
.and_then(|desired| desired.referenced_node_tables)
8651019
.map(|(src, tgt)| (src.table_name, tgt.table_name)),
8661020
drop_affected_referenced_node_tables,
1021+
vector_indexes_to_create,
1022+
vector_indexes_to_drop,
8671023
})
8681024
}
8691025

@@ -875,6 +1031,8 @@ impl TargetFactoryBase for Factory {
8751031
Ok(
8761032
if desired.referenced_node_tables != existing.referenced_node_tables {
8771033
SetupStateCompatibility::NotCompatible
1034+
} else if desired.vector_indexes != existing.vector_indexes {
1035+
SetupStateCompatibility::Compatible
8781036
} else {
8791037
check_table_compatibility(&desired.schema, &existing.schema)
8801038
},
@@ -1080,6 +1238,33 @@ impl TargetFactoryBase for Factory {
10801238
append_delete_orphaned_nodes(&mut cypher, table)?;
10811239
}
10821240

1241+
// Install vector extension if needed
1242+
let has_vector_changes = node_changes.iter().any(|c| {
1243+
!c.setup_change.vector_indexes_to_create.is_empty()
1244+
|| !c.setup_change.vector_indexes_to_drop.is_empty()
1245+
});
1246+
1247+
if has_vector_changes {
1248+
writeln!(cypher.query_mut(), "INSTALL vector;")?;
1249+
writeln!(cypher.query_mut(), "LOAD vector;")?;
1250+
}
1251+
1252+
// Drop vector indexes first
1253+
for change in node_changes.iter() {
1254+
let table_name = change.key.typ.label();
1255+
for field_name in &change.setup_change.vector_indexes_to_drop {
1256+
append_drop_vector_index(&mut cypher, table_name, field_name)?;
1257+
}
1258+
}
1259+
1260+
// Create vector indexes
1261+
for change in node_changes.iter() {
1262+
let table_name = change.key.typ.label();
1263+
for index_def in &change.setup_change.vector_indexes_to_create {
1264+
append_create_vector_index(&mut cypher, table_name, index_def)?;
1265+
}
1266+
}
1267+
10831268
kuzu_client.run_cypher(cypher).await?;
10841269
}
10851270
Ok(())

0 commit comments

Comments
 (0)