diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index 1744ef4e3..bdb2e7d1f 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -1,5 +1,6 @@ """All builtin storages.""" from dataclasses import dataclass +from typing import Sequence from . import op from . import index @@ -44,8 +45,9 @@ class Neo4jRelationshipEndSpec: @dataclass class Neo4jRelationshipNodeSpec: """Spec for a Neo4j node type.""" - primary_key_fields: list[str] - vector_indexes: list[index.VectorIndexDef] | None = None + primary_key_fields: Sequence[str] + vector_indexes: Sequence[index.VectorIndexDef] = () + class Neo4jRelationship(op.StorageSpec): """Graph storage powered by Neo4j.""" diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 17b8bdd14..297265016 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use crate::setup::components::{self, State}; use crate::setup::{ResourceSetupStatusCheck, SetupChangeType}; use crate::{ops::sdk::*, setup::CombinedState}; @@ -7,7 +8,7 @@ use tokio::sync::OnceCell; const DEFAULT_DB: &str = "neo4j"; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] pub struct ConnectionSpec { uri: String, user: String, @@ -532,22 +533,159 @@ impl ExportTargetExecutor for RelationshipStorageExecutor { } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -struct VectorIndexState { - label: String, - field_name: String, - vector_size: usize, - metric: spec::VectorSimilarityMetric, +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RelationshipSetupState { + key_field_names: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + node_labels: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + sub_components: Vec, } -impl VectorIndexState { +impl RelationshipSetupState { fn new( - label: &str, + spec: &RelationshipSpec, + key_field_names: Vec, + index_options: &IndexOptions, + rel_value_fields_info: &[AnalyzedGraphFieldMapping], + src_label_info: &AnalyzedNodeLabelInfo, + tgt_label_info: &AnalyzedNodeLabelInfo, + ) -> Result { + let mut sub_components = vec![]; + sub_components.push(ComponentState { + object_label: ObjectLabel::Relationship(spec.rel_type.clone()), + index_def: IndexDef::KeyConstraint { + field_names: key_field_names.clone(), + }, + }); + for index_def in index_options.vector_indexes.iter() { + sub_components.push(ComponentState { + object_label: ObjectLabel::Relationship(spec.rel_type.clone()), + index_def: IndexDef::from_vector_index_def( + index_def, + &rel_value_fields_info + .iter() + .find(|f| f.field_name == index_def.field_name) + .ok_or_else(|| { + api_error!( + "Unknown field name for vector index: {}", + index_def.field_name + ) + })? + .value_type, + )?, + }); + } + for (label, node) in spec.nodes.iter() { + sub_components.push(ComponentState { + object_label: ObjectLabel::Node(label.clone()), + index_def: IndexDef::KeyConstraint { + field_names: key_field_names.clone(), + }, + }); + for index_def in &node.index_options.vector_indexes { + sub_components.push(ComponentState { + object_label: ObjectLabel::Node(label.clone()), + index_def: IndexDef::from_vector_index_def( + index_def, + [src_label_info, tgt_label_info] + .into_iter() + .flat_map(|v| v.key_fields.iter().chain(v.value_fields.iter())) + .find(|f| f.field_name == index_def.field_name) + .map(|f| &f.value_type) + .ok_or_else(|| { + api_error!( + "Unknown field name for vector index: {}", + index_def.field_name + ) + })?, + )?, + }); + } + } + Ok(Self { + key_field_names, + node_labels: spec.nodes.keys().cloned().collect(), + sub_components, + }) + } + + fn check_compatible(&self, existing: &Self) -> SetupStateCompatibility { + if self.key_field_names == existing.key_field_names { + SetupStateCompatibility::Compatible + } else { + SetupStateCompatibility::NotCompatible + } + } +} + +impl IntoIterator for RelationshipSetupState { + type Item = ComponentState; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.sub_components.into_iter() + } +} +#[derive(Debug)] +struct DataClearAction { + rel_type: String, + node_labels: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum ComponentKind { + KeyConstraint, + VectorIndex, +} + +impl ComponentKind { + fn describe(&self) -> &str { + match self { + ComponentKind::KeyConstraint => "KEY CONSTRAINT", + ComponentKind::VectorIndex => "VECTOR INDEX", + } + } +} +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ComponentKey { + kind: ComponentKind, + name: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +enum ObjectLabel { + Node(String), + Relationship(String), +} + +impl ObjectLabel { + fn label(&self) -> &str { + match self { + ObjectLabel::Node(label) => label, + ObjectLabel::Relationship(label) => label, + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +enum IndexDef { + KeyConstraint { + field_names: Vec, + }, + VectorIndex { + field_name: String, + metric: spec::VectorSimilarityMetric, + vector_size: usize, + }, +} + +impl IndexDef { + fn from_vector_index_def( index_def: &spec::VectorIndexDef, field_typ: &schema::ValueType, ) -> Result { - Ok(Self { - label: label.to_string(), + Ok(Self::VectorIndex { field_name: index_def.field_name.clone(), vector_size: (match field_typ { schema::ValueType::Basic(schema::BasicValueType::Vector(schema)) => { @@ -563,174 +701,132 @@ impl VectorIndexState { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct NodeLabelSetupState { - key_field_names: Vec, - key_constraint_name: String, - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - vector_indexes: HashMap, +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +pub struct ComponentState { + object_label: ObjectLabel, + index_def: IndexDef, } -impl NodeLabelSetupState { - fn new( - label: &str, - spec: &RelationshipNodeSpec, - node_label_infos: &[&AnalyzedNodeLabelInfo], - ) -> Result { - let key_constraint_name = format!("n__{}__unique", label); - Ok(Self { - key_field_names: spec - .index_options - .primary_key_fields - .clone() - .unwrap_or_default(), - key_constraint_name, - vector_indexes: spec - .index_options - .vector_indexes - .iter() - .map(|v| -> Result<_> { - Ok(( - format!("n__{}__{}__{}", label, v.field_name.clone(), v.metric), - VectorIndexState::new( - label, - v, - node_label_infos - .iter() - .flat_map(|v| v.key_fields.iter().chain(v.value_fields.iter())) - .find(|f| f.field_name == v.field_name) - .map(|f| &f.value_type) - .ok_or_else(|| { - api_error!( - "Unknown field name for vector index: {}", - v.field_name - ) - })?, - )?, - )) - }) - .collect::>()?, - }) - } - - fn is_compatible(&self, other: &Self) -> bool { - self.key_field_names == other.key_field_names +impl components::State for ComponentState { + fn key(&self) -> ComponentKey { + let prefix = match &self.object_label { + ObjectLabel::Relationship(_) => "r", + ObjectLabel::Node(_) => "n", + }; + let label = self.object_label.label(); + match &self.index_def { + IndexDef::KeyConstraint { .. } => ComponentKey { + kind: ComponentKind::KeyConstraint, + name: format!("{prefix}__{label}__key"), + }, + IndexDef::VectorIndex { + field_name, metric, .. + } => ComponentKey { + kind: ComponentKind::VectorIndex, + name: format!("{prefix}__{label}__{field_name}__{metric}__vidx"), + }, + } } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RelationshipSetupState { - key_field_names: Vec, - key_constraint_name: String, - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - vector_indexes: HashMap, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - nodes: BTreeMap, + +struct SetupComponentOperator { + graph_pool: Arc, + conn_spec: ConnectionSpec, } -impl RelationshipSetupState { - fn new( - spec: &RelationshipSpec, - key_field_names: Vec, - index_options: &IndexOptions, - rel_value_fields_info: &[AnalyzedGraphFieldMapping], - src_label_info: &AnalyzedNodeLabelInfo, - tgt_label_info: &AnalyzedNodeLabelInfo, - ) -> Result { - Ok(Self { - key_field_names, - key_constraint_name: format!("r__{}__key", spec.rel_type), - vector_indexes: index_options - .vector_indexes - .iter() - .map(|v| -> Result<_> { - Ok(( - format!("r__{}__{}__{}", spec.rel_type, v.field_name, v.metric), - VectorIndexState::new( - &spec.rel_type, - v, - &rel_value_fields_info - .iter() - .find(|f| f.field_name == v.field_name) - .ok_or_else(|| { - api_error!( - "Unknown field name for vector index: {}", - v.field_name - ) - })? - .value_type, - )?, - )) - }) - .collect::>()?, - nodes: spec - .nodes - .iter() - .map(|(label, node)| -> Result<_> { - Ok(( - label.clone(), - NodeLabelSetupState::new(label, node, &[src_label_info, tgt_label_info])?, - )) - }) - .collect::>()?, - }) +#[async_trait] +impl components::Operator for SetupComponentOperator { + type Key = ComponentKey; + type State = ComponentState; + type SetupState = RelationshipSetupState; + + fn describe_key(&self, key: &Self::Key) -> String { + format!("{} {}", key.kind.describe(), key.name) } - fn check_compatible(&self, existing: &Self) -> SetupStateCompatibility { - if self.key_field_names != existing.key_field_names { - SetupStateCompatibility::NotCompatible - } else if existing.nodes.iter().any(|(label, existing_node)| { - !self - .nodes - .get(label).is_some_and(|node| node.is_compatible(existing_node)) - }) { - // If any node's key field change of some node label gone, we have to clear relationship. - SetupStateCompatibility::NotCompatible - } else { - SetupStateCompatibility::Compatible + fn describe_state(&self, state: &Self::State) -> String { + let key_desc = self.describe_key(&state.key()); + let label = state.object_label.label(); + match &state.index_def { + IndexDef::KeyConstraint { field_names } => { + format!("{key_desc} ON {label} (key: {})", field_names.join(", ")) + } + IndexDef::VectorIndex { + field_name, + metric, + vector_size, + } => { + format!("{key_desc} ON {label} (field_name: {field_name}, vector_size: {vector_size}, metric: {metric})",) + } } } -} -#[derive(Debug)] -struct DataClearAction { - rel_type: String, - node_labels: IndexSet, -} + fn is_up_to_date(&self, current: &ComponentState, desired: &ComponentState) -> bool { + current == desired + } -#[derive(Debug)] -struct KeyConstraint { - label: String, - field_names: Vec, -} + async fn create(&self, state: &ComponentState) -> Result<()> { + let graph = self.graph_pool.get_graph(&self.conn_spec).await?; + let key = state.key(); + let (matcher, qualifier) = match &state.object_label { + ObjectLabel::Relationship(label) => (format!("()-[r:{label}]->()"), "r"), + ObjectLabel::Node(label) => (format!("(n:{label})"), "n"), + }; + let query = neo4rs::query(&match &state.index_def { + IndexDef::KeyConstraint { field_names } => { + format!( + "CREATE CONSTRAINT {name} IF NOT EXISTS FOR {matcher} REQUIRE {field_names} IS UNIQUE", + name=key.name, + field_names=build_composite_field_names(qualifier, &field_names), + ) + } + IndexDef::VectorIndex { + field_name, + metric, + vector_size, + } => { + format!( + r#"CREATE VECTOR INDEX {name} IF NOT EXISTS FOR {matcher} ON {qualifier}.{field_name} OPTIONS + {{ indexConfig: {{`vector.dimensions`: {vector_size}, `vector.similarity_function`: '{metric}'}}}}"#, + name = key.name, + ) + } + }); + Ok(graph.run(query).await?) + } -impl KeyConstraint { - fn new(label: String, state: &NodeLabelSetupState) -> Self { - Self { - label, - field_names: state.key_field_names.clone(), - } + async fn delete(&self, key: &ComponentKey) -> Result<()> { + let graph = self.graph_pool.get_graph(&self.conn_spec).await?; + let query = neo4rs::query(&format!( + "DROP {kind} {name} IF EXISTS", + kind = match key.kind { + ComponentKind::KeyConstraint => "CONSTRAINT", + ComponentKind::VectorIndex => "INDEX", + }, + name = key.name, + )); + Ok(graph.run(query).await?) } } +fn build_composite_field_names(qualifier: &str, field_names: &[String]) -> String { + let strs = field_names + .iter() + .map(|name| format!("{qualifier}.{name}")) + .join(", "); + if field_names.len() == 1 { + strs + } else { + format!("({})", strs) + } +} #[derive(Derivative)] #[derivative(Debug)] struct SetupStatusCheck { #[derivative(Debug = "ignore")] graph_pool: Arc, conn_spec: ConnectionSpec, - data_clear: Option, - - rel_constraint_to_delete: IndexSet, - rel_constraint_to_create: IndexMap, - node_constraint_to_delete: IndexSet, - node_constraint_to_create: IndexMap, - - rel_index_to_delete: IndexSet, - rel_index_to_create: IndexMap, - node_index_to_delete: IndexSet, - node_index_to_create: IndexMap, - change_type: SetupChangeType, } @@ -739,8 +835,8 @@ impl SetupStatusCheck { key: GraphRelationship, graph_pool: Arc, conn_spec: ConnectionSpec, - desired_state: Option, - existing: CombinedState, + desired_state: Option<&RelationshipSetupState>, + existing: &CombinedState, ) -> Self { let data_clear = existing .current @@ -753,127 +849,26 @@ impl SetupStatusCheck { }) .map(|existing_current| DataClearAction { rel_type: key.relationship.clone(), - node_labels: existing_current.nodes.keys().cloned().collect(), + node_labels: existing_current.node_labels.clone(), }); - let mut old_rel_constraints = IndexSet::new(); - let mut old_node_constraints = IndexSet::new(); - let mut old_rel_indexes = IndexSet::new(); - let mut old_node_indexes = IndexSet::new(); - - for existing_version in existing.possible_versions() { - old_rel_constraints.insert(existing_version.key_constraint_name.clone()); - old_rel_indexes.extend(existing_version.vector_indexes.keys().cloned()); - for (_, node) in existing_version.nodes.iter() { - old_node_constraints.insert(node.key_constraint_name.clone()); - old_node_indexes.extend(node.vector_indexes.keys().cloned()); - } - } - - let mut rel_constraint_to_create = IndexMap::new(); - let mut node_constraint_to_create = IndexMap::new(); - let mut rel_index_to_create = IndexMap::new(); - let mut node_index_to_create = IndexMap::new(); - - if let Some(desired_state) = desired_state { - let rel_constraint = KeyConstraint { - label: key.relationship.clone(), - field_names: desired_state.key_field_names.clone(), - }; - old_rel_constraints.shift_remove(&desired_state.key_constraint_name); - if !existing - .current - .as_ref() - .map(|c| rel_constraint.field_names == c.key_field_names) - .unwrap_or(false) - { - rel_constraint_to_create.insert(desired_state.key_constraint_name, rel_constraint); - } - - for (index_name, vector_index) in desired_state.vector_indexes.into_iter() { - old_rel_indexes.shift_remove(&index_name); - if !existing.current.as_ref().is_some_and(|c| { - Some(&vector_index) == c.vector_indexes.get(&index_name) - }) { - rel_index_to_create.insert(index_name, vector_index); - } - } - - for (label, node) in desired_state.nodes.into_iter() { - old_node_constraints.shift_remove(&node.key_constraint_name); - if !existing - .current - .as_ref() - .map(|c| { - c.nodes - .get(&label).is_some_and(|existing_node| node.is_compatible(existing_node)) - }) - .unwrap_or(false) - { - node_constraint_to_create.insert( - node.key_constraint_name.clone(), - KeyConstraint::new(label.clone(), &node), - ); - } - - for (index_name, vector_index) in node.vector_indexes.into_iter() { - old_node_indexes.shift_remove(&index_name); - if !existing.current.as_ref().is_some_and(|c| { - c.nodes.get(&label).is_some_and(|n| { - Some(&vector_index) == n.vector_indexes.get(&index_name) - }) - }) { - node_index_to_create.insert(index_name, vector_index); - } + let change_type = match (desired_state, existing.possible_versions().next()) { + (Some(_), Some(_)) => { + if data_clear.is_none() { + SetupChangeType::NoChange + } else { + SetupChangeType::Update } } - } - - let rel_constraint_to_delete = old_rel_constraints; - let node_constraint_to_delete = old_node_constraints; - let rel_index_to_delete = old_rel_indexes; - let node_index_to_delete = old_node_indexes; - - let change_type = if data_clear.is_none() - && rel_constraint_to_delete.is_empty() - && rel_constraint_to_create.is_empty() - && node_constraint_to_delete.is_empty() - && node_constraint_to_create.is_empty() - && rel_index_to_delete.is_empty() - && rel_index_to_create.is_empty() - && node_index_to_delete.is_empty() - && node_index_to_create.is_empty() - { - SetupChangeType::NoChange - } else if data_clear.is_none() - && rel_constraint_to_delete.is_empty() - && node_constraint_to_delete.is_empty() - && rel_index_to_delete.is_empty() - && node_index_to_delete.is_empty() - { - SetupChangeType::Create - } else if rel_constraint_to_create.is_empty() - && node_constraint_to_create.is_empty() - && rel_index_to_create.is_empty() - && node_index_to_create.is_empty() - { - SetupChangeType::Delete - } else { - SetupChangeType::Update + (Some(_), None) => SetupChangeType::Create, + (None, Some(_)) => SetupChangeType::Delete, + (None, None) => SetupChangeType::NoChange, }; Self { graph_pool, conn_spec, data_clear, - rel_constraint_to_delete, - rel_constraint_to_create, - node_constraint_to_delete, - node_constraint_to_create, - rel_index_to_delete, - rel_index_to_create, - node_index_to_delete, - node_index_to_create, change_type, } } @@ -890,47 +885,6 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { data_clear.node_labels.iter().join(", "), )); } - for name in &self.rel_constraint_to_delete { - result.push(format!("Delete relationship constraint {}", name)); - } - for (name, rel_constraint) in self.rel_constraint_to_create.iter() { - result.push(format!( - "Create KEY CONSTRAINT {} ON RELATIONSHIP {} (key: {})", - name, - rel_constraint.label, - rel_constraint.field_names.join(", "), - )); - } - for name in &self.node_constraint_to_delete { - result.push(format!("Delete node constraint {}", name)); - } - for (name, node_constraint) in self.node_constraint_to_create.iter() { - result.push(format!( - "Create KEY CONSTRAINT {} ON NODE {} (key: {})", - name, - node_constraint.label, - node_constraint.field_names.join(", "), - )); - } - for name in &self.rel_index_to_delete { - result.push(format!("Delete relationship index {}", name)); - } - for (name, vector_index) in self.rel_index_to_create.iter() { - result.push(format!( - "Create VECTOR INDEX {} (vector_size: {}, metric: {}) ON RELATIONSHIP {}", - name, vector_index.vector_size, vector_index.metric, vector_index.label - )); - } - for name in &self.node_index_to_delete { - result.push(format!("Delete node index {}", name)); - } - for (name, vector_index) in self.node_index_to_create.iter() { - result.push(format!( - "Create VECTOR INDEX {} (vector_size: {}, metric: {}) ON NODE {}", - name, vector_index.vector_size, vector_index.metric, vector_index.label - )); - } - result } @@ -939,20 +893,7 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { } async fn apply_change(&self) -> Result<()> { - let build_composite_field_names = |qualifier: &str, field_names: &[String]| -> String { - let strs = field_names - .iter() - .map(|name| format!("{qualifier}.{name}")) - .join(", "); - if field_names.len() == 1 { - strs - } else { - format!("({})", strs) - } - }; - let graph = self.graph_pool.get_graph(&self.conn_spec).await?; - if let Some(data_clear) = &self.data_clear { let delete_rel_query = neo4rs::query(&format!( r#" @@ -980,94 +921,6 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { graph.run(delete_node_query).await?; } } - - for name in - (self.rel_constraint_to_delete.iter()).chain(self.node_constraint_to_delete.iter()) - { - graph - .run(neo4rs::query(&format!("DROP CONSTRAINT {name} IF EXISTS"))) - .await?; - } - for name in (self.rel_index_to_delete.iter()).chain(self.node_index_to_delete.iter()) { - graph - .run(neo4rs::query(&format!("DROP INDEX {name} IF EXISTS"))) - .await?; - } - - for (name, constraint) in self.node_constraint_to_create.iter() { - graph - .run(neo4rs::query(&format!("DROP CONSTRAINT {name} IF EXISTS"))) - .await?; - graph - .run(neo4rs::query(&format!( - "CREATE CONSTRAINT {name} IF NOT EXISTS FOR (n:{label}) REQUIRE {field_names} IS UNIQUE", - label = constraint.label, - field_names = build_composite_field_names("n", &constraint.field_names) - ))) - .await?; - } - - for (name, constraint) in self.rel_constraint_to_create.iter() { - graph - .run(neo4rs::query(&format!("DROP CONSTRAINT {name} IF EXISTS"))) - .await?; - graph - .run(neo4rs::query(&format!( - "CREATE CONSTRAINT {name} IF NOT EXISTS FOR ()-[e:{label}]-() REQUIRE {field_names} IS UNIQUE", - label = constraint.label, - field_names = build_composite_field_names("e", &constraint.field_names) - ))) - .await?; - } - - let build_create_vector_index_query = |name: &str, - index_state: &VectorIndexState, - matcher: &str, - arg_name: &str| - -> Result { - let metric = match index_state.metric { - spec::VectorSimilarityMetric::CosineSimilarity => "cosine", - spec::VectorSimilarityMetric::L2Distance => "euclidean", - _ => api_bail!( - "Unsupported vector similarity metric in Neo4j: {}", - index_state.metric - ), - }; - let query = format!( - r#"CREATE VECTOR INDEX {name} IF NOT EXISTS FOR {matcher} ON {arg_name}.{field_name} OPTIONS - {{ indexConfig: {{`vector.dimensions`: {vector_size}, `vector.similarity_function`: '{metric}'}}}}"#, - field_name = index_state.field_name, - vector_size = index_state.vector_size, - ); - Ok(query) - }; - for (name, vector_index) in self.rel_index_to_create.iter() { - graph - .run(neo4rs::query(&format!("DROP INDEX {name} IF EXISTS"))) - .await?; - graph - .run(neo4rs::query(&build_create_vector_index_query( - name, - vector_index, - &format!("()-[r:{}]-()", vector_index.label), - "r", - )?)) - .await?; - } - for (name, vector_index) in self.node_index_to_create.iter() { - graph - .run(neo4rs::query(&format!("DROP INDEX {name} IF EXISTS"))) - .await?; - graph - .run(neo4rs::query(&build_create_vector_index_query( - name, - vector_index, - &format!("(n:{})", vector_index.label), - "n", - )?)) - .await?; - } - Ok(()) } } @@ -1240,13 +1093,22 @@ impl StorageFactoryBase for RelationshipFactory { auth_registry: &Arc, ) -> Result { let conn_spec = auth_registry.get::(&key.connection)?; - Ok(SetupStatusCheck::new( + let base = SetupStatusCheck::new( key, self.graph_pool.clone(), - conn_spec, + conn_spec.clone(), + desired.as_ref(), + &existing, + ); + let comp = components::StatusCheck::create( + SetupComponentOperator { + graph_pool: self.graph_pool.clone(), + conn_spec: conn_spec.clone(), + }, desired, existing, - )) + )?; + Ok(components::combine_status_checks(base, comp)) } fn check_state_compatibility( diff --git a/src/setup/components.rs b/src/setup/components.rs new file mode 100644 index 000000000..8d07ce429 --- /dev/null +++ b/src/setup/components.rs @@ -0,0 +1,203 @@ +use super::{CombinedState, ResourceSetupStatusCheck, SetupChangeType, StateChange}; +use crate::prelude::*; +use std::fmt::Debug; + +pub trait State: Debug + Send + Sync { + fn key(&self) -> Key; +} + +#[async_trait] +pub trait Operator { + type Key: Debug + Hash + Eq + Clone + Send + Sync; + type State: State; + type SetupState: Send + Sync + IntoIterator; + + fn describe_key(&self, key: &Self::Key) -> String; + + fn describe_state(&self, state: &Self::State) -> String; + + fn is_up_to_date(&self, current: &Self::State, desired: &Self::State) -> bool; + + async fn create(&self, state: &Self::State) -> Result<()>; + + async fn delete(&self, key: &Self::Key) -> Result<()>; + + async fn update(&self, state: &Self::State) -> Result<()> { + self.delete(&state.key()).await?; + self.create(state).await + } +} + +#[derive(Debug)] +struct CompositeStateUpsert { + state: S, + already_exists: bool, +} + +#[derive(Derivative)] +#[derivative(Debug)] +pub struct StatusCheck { + #[derivative(Debug = "ignore")] + desc: D, + keys_to_delete: IndexSet, + states_to_upsert: Vec>, +} + +impl StatusCheck { + pub fn create( + desc: D, + desired: Option, + existing: CombinedState, + ) -> Result { + let existing_component_states = CombinedState { + current: existing.current.map(|s| { + s.into_iter() + .map(|s| (s.key(), s)) + .collect::>() + }), + staging: existing + .staging + .into_iter() + .map(|s| match s { + StateChange::Delete => StateChange::Delete, + StateChange::Upsert(s) => { + StateChange::Upsert(s.into_iter().map(|s| (s.key(), s)).collect()) + } + }) + .collect(), + }; + let mut keys_to_delete = IndexSet::new(); + let mut states_to_upsert = vec![]; + + // Collect all existing component keys + for c in existing_component_states.possible_versions() { + keys_to_delete.extend(c.keys().cloned()); + } + + if let Some(desired_state) = desired { + for desired_comp_state in desired_state { + let key = desired_comp_state.key(); + + // Remove keys that should be kept from deletion list + keys_to_delete.shift_remove(&key); + + // Add components that need to be updated + let is_up_to_date = existing_component_states.always_exists() + && existing_component_states.possible_versions().all(|v| { + v.get(&key) + .map_or(false, |s| desc.is_up_to_date(s, &desired_comp_state)) + }); + if !is_up_to_date { + let already_exists = existing_component_states + .possible_versions() + .any(|v| v.contains_key(&key)); + states_to_upsert.push(CompositeStateUpsert { + state: desired_comp_state, + already_exists, + }); + } + } + } + + Ok(Self { + desc, + keys_to_delete, + states_to_upsert, + }) + } +} + +#[async_trait] +impl ResourceSetupStatusCheck for StatusCheck { + fn describe_changes(&self) -> Vec { + let mut result = vec![]; + + for key in &self.keys_to_delete { + result.push(format!("Delete {}", self.desc.describe_key(key))); + } + + for state in &self.states_to_upsert { + result.push(format!( + "{} {}", + if state.already_exists { + "Update" + } else { + "Create" + }, + self.desc.describe_state(&state.state) + )); + } + + result + } + + fn change_type(&self) -> SetupChangeType { + if self.keys_to_delete.is_empty() && self.states_to_upsert.is_empty() { + SetupChangeType::NoChange + } else if self.keys_to_delete.is_empty() { + SetupChangeType::Create + } else if self.states_to_upsert.is_empty() { + SetupChangeType::Delete + } else { + SetupChangeType::Update + } + } + + async fn apply_change(&self) -> Result<()> { + // First delete components that need to be removed + for key in &self.keys_to_delete { + self.desc.delete(key).await?; + } + + // Then upsert components that need to be updated + for state in &self.states_to_upsert { + if state.already_exists { + self.desc.update(&state.state).await?; + } else { + self.desc.create(&state.state).await?; + } + } + + Ok(()) + } +} + +#[derive(Debug)] +struct CombinedStatusCheck { + a: A, + b: B, +} + +#[async_trait] +impl ResourceSetupStatusCheck + for CombinedStatusCheck +{ + fn describe_changes(&self) -> Vec { + let mut result = vec![]; + result.extend(self.a.describe_changes()); + result.extend(self.b.describe_changes()); + result + } + + fn change_type(&self) -> SetupChangeType { + match (self.a.change_type(), self.b.change_type()) { + (SetupChangeType::Invalid, _) | (_, SetupChangeType::Invalid) => { + SetupChangeType::Invalid + } + (SetupChangeType::NoChange, b) => b, + (a, _) => a, + } + } + + async fn apply_change(&self) -> Result<()> { + self.a.apply_change().await?; + self.b.apply_change().await + } +} + +pub fn combine_status_checks( + a: A, + b: B, +) -> impl ResourceSetupStatusCheck { + CombinedStatusCheck { a, b } +} diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 7aeeab876..6acf7a3c3 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -3,6 +3,8 @@ mod db_metadata; mod driver; mod states; +pub mod components; + pub use auth_registry::AuthRegistry; pub use driver::*; pub use states::*; diff --git a/src/setup/states.rs b/src/setup/states.rs index a97a6da64..6e546b52d 100644 --- a/src/setup/states.rs +++ b/src/setup/states.rs @@ -38,12 +38,12 @@ impl StateMode for DesiredMode { } #[derive(Debug, Clone)] -pub struct CombinedState { +pub struct CombinedState { pub current: Option, pub staging: Vec>, } -impl CombinedState { +impl CombinedState { pub fn possible_versions(&self) -> impl Iterator { self.current .iter() @@ -293,7 +293,8 @@ impl std::fmt::Display for ResourceSetupInfo< impl ResourceSetupInfo { pub fn is_up_to_date(&self) -> bool { self.status_check - .as_ref().is_none_or(|c| c.change_type() == SetupChangeType::NoChange) + .as_ref() + .is_none_or(|c| c.change_type() == SetupChangeType::NoChange) } }