Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/ops/factory_bases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@ pub struct TypedExportDataCollectionSpec<F: StorageFactoryBase + ?Sized> {
pub index_options: IndexOptions,
}

pub struct TypedResourceSetupChangeItem<'a, F: StorageFactoryBase + ?Sized> {
pub key: F::Key,
pub setup_status: &'a F::SetupStatus,
}

#[async_trait]
pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
type Spec: DeserializeOwned + Send + Sync;
Expand Down Expand Up @@ -339,7 +344,8 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {

async fn apply_setup_changes(
&self,
setup_status: Vec<&'async_trait Self::SetupStatus>,
setup_status: Vec<TypedResourceSetupChangeItem<'async_trait, Self>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()>;
}

Expand Down Expand Up @@ -466,18 +472,25 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {

async fn apply_setup_changes(
&self,
setup_status: Vec<&'async_trait dyn ResourceSetupStatus>,
setup_status: Vec<ResourceSetupChangeItem<'async_trait>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()> {
StorageFactoryBase::apply_setup_changes(
self,
setup_status
.into_iter()
.map(|s| -> anyhow::Result<_> {
Ok(s.as_any()
.downcast_ref::<T::SetupStatus>()
.ok_or_else(|| anyhow!("Unexpected setup status type"))?)
.map(|item| -> anyhow::Result<_> {
Ok(TypedResourceSetupChangeItem {
key: serde_json::from_value(item.key.clone())?,
setup_status: item
.setup_status
.as_any()
.downcast_ref::<T::SetupStatus>()
.ok_or_else(|| anyhow!("Unexpected setup status type"))?,
})
})
.collect::<Result<Vec<_>>>()?,
auth_registry,
)
.await
}
Expand Down
8 changes: 7 additions & 1 deletion src/ops/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ pub struct ExportTargetMutationWithContext<'ctx, T: ?Sized + Send + Sync> {
pub export_context: &'ctx T,
}

pub struct ResourceSetupChangeItem<'a> {
pub key: &'a serde_json::Value,
pub setup_status: &'a dyn setup::ResourceSetupStatus,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetupStateCompatibility {
/// The resource is fully compatible with the desired state.
Expand Down Expand Up @@ -288,7 +293,8 @@ pub trait ExportTargetFactory: Send + Sync {

async fn apply_setup_changes(
&self,
setup_status: Vec<&'async_trait dyn setup::ResourceSetupStatus>,
setup_status: Vec<ResourceSetupChangeItem<'async_trait>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()>;
}

Expand Down
84 changes: 45 additions & 39 deletions src/ops/storages/neo4j.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct GraphPool {
}

impl GraphPool {
pub async fn get_graph(&self, spec: &ConnectionSpec) -> Result<Arc<Graph>> {
async fn get_graph(&self, spec: &ConnectionSpec) -> Result<Arc<Graph>> {
let graph_key = GraphKey::from_spec(spec);
let cell = {
let mut graphs = self.graphs.lock().unwrap();
Expand All @@ -87,6 +87,15 @@ impl GraphPool {
.await?;
Ok(graph.clone())
}

async fn get_graph_for_key(
&self,
key: &Neo4jGraphElement,
auth_registry: &AuthRegistry,
) -> Result<Arc<Graph>> {
let spec = auth_registry.get::<ConnectionSpec>(&key.connection)?;
self.get_graph(&spec).await
}
}

pub struct ExportContext {
Expand Down Expand Up @@ -798,19 +807,12 @@ fn build_composite_field_names(qualifier: &str, field_names: &[String]) -> Strin
}
#[derive(Debug)]
pub struct GraphElementDataSetupStatus {
key: Neo4jGraphElement,
conn_spec: ConnectionSpec,
data_clear: Option<DataClearAction>,
change_type: SetupChangeType,
}

impl GraphElementDataSetupStatus {
fn new(
key: Neo4jGraphElement,
conn_spec: ConnectionSpec,
desired_state: Option<&SetupState>,
existing: &CombinedState<SetupState>,
) -> Self {
fn new(desired_state: Option<&SetupState>, existing: &CombinedState<SetupState>) -> Self {
let mut data_clear: Option<DataClearAction> = None;
for v in existing.possible_versions() {
if desired_state.as_ref().is_none_or(|desired| {
Expand All @@ -837,8 +839,6 @@ impl GraphElementDataSetupStatus {
};

Self {
key,
conn_spec,
data_clear,
change_type,
}
Expand Down Expand Up @@ -1024,8 +1024,7 @@ impl StorageFactoryBase for Factory {
auth_registry: &Arc<AuthRegistry>,
) -> Result<Self::SetupStatus> {
let conn_spec = auth_registry.get::<ConnectionSpec>(&key.connection)?;
let data_status =
GraphElementDataSetupStatus::new(key, conn_spec.clone(), desired.as_ref(), &existing);
let data_status = GraphElementDataSetupStatus::new(desired.as_ref(), &existing);
let components = components::SetupStatus::create(
SetupComponentOperator {
graph_pool: self.graph_pool.clone(),
Expand Down Expand Up @@ -1094,51 +1093,58 @@ impl StorageFactoryBase for Factory {

async fn apply_setup_changes(
&self,
changes: Vec<&'async_trait Self::SetupStatus>,
changes: Vec<TypedResourceSetupChangeItem<'async_trait, Self>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()> {
let (data_statuses, components): (Vec<_>, Vec<_>) =
changes.into_iter().map(|c| (&c.0, &c.1)).unzip();

// Relationships first, then nodes, as relationships need to be deleted before nodes they referenced.
let mut relationship_types = IndexMap::<&Neo4jGraphElement, &ConnectionSpec>::new();
let mut node_labels = IndexMap::<&Neo4jGraphElement, &ConnectionSpec>::new();
let mut dependent_node_labels = IndexMap::<Neo4jGraphElement, &ConnectionSpec>::new();
for data_status in data_statuses.iter() {
if let Some(data_clear) = &data_status.data_clear {
match &data_status.key.typ {
let mut relationship_types = IndexSet::<&Neo4jGraphElement>::new();
let mut node_labels = IndexSet::<&Neo4jGraphElement>::new();
let mut dependent_node_labels = IndexSet::<Neo4jGraphElement>::new();

let mut components = vec![];
for change in changes.iter() {
if let Some(data_clear) = &change.setup_status.0.data_clear {
match &change.key.typ {
ElementType::Relationship(_) => {
relationship_types.insert(&data_status.key, &data_status.conn_spec);
relationship_types.insert(&change.key);
for label in &data_clear.dependent_node_labels {
dependent_node_labels.insert(
Neo4jGraphElement {
connection: data_status.key.connection.clone(),
typ: ElementType::Node(label.clone()),
},
&data_status.conn_spec,
);
dependent_node_labels.insert(Neo4jGraphElement {
connection: change.key.connection.clone(),
typ: ElementType::Node(label.clone()),
});
}
}
ElementType::Node(_) => {
node_labels.insert(&data_status.key, &data_status.conn_spec);
node_labels.insert(&change.key);
}
}
}
components.push(&change.setup_status.1);
}

// Relationships have no dependency, so can be cleared first.
for (rel_type, conn_spec) in relationship_types.iter() {
let graph = self.graph_pool.get_graph(conn_spec).await?;
for rel_type in relationship_types.into_iter() {
let graph = self
.graph_pool
.get_graph_for_key(rel_type, auth_registry)
.await?;
clear_graph_element_data(&graph, rel_type, true).await?;
}
// Clear standalone nodes, which is simpler than dependent nodes.
for (node_label, conn_spec) in node_labels.iter() {
let graph = self.graph_pool.get_graph(conn_spec).await?;
for node_label in node_labels.iter() {
let graph = self
.graph_pool
.get_graph_for_key(node_label, auth_registry)
.await?;
clear_graph_element_data(&graph, node_label, true).await?;
}
// Clear dependent nodes if they're not covered by standalone nodes.
for (node_label, conn_spec) in dependent_node_labels.iter() {
if !node_labels.contains_key(node_label) {
let graph = self.graph_pool.get_graph(conn_spec).await?;
for node_label in dependent_node_labels.iter() {
if !node_labels.contains(node_label) {
let graph = self
.graph_pool
.get_graph_for_key(node_label, auth_registry)
.await?;
clear_graph_element_data(&graph, node_label, false).await?;
}
}
Expand Down
51 changes: 20 additions & 31 deletions src/ops/storages/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,20 +601,12 @@ pub struct TableSetupAction {

#[derive(Debug)]
pub struct SetupStatus {
db_pool: PgPool,
table_name: String,

create_pgvector_extension: bool,
actions: TableSetupAction,
}

impl SetupStatus {
fn new(
db_pool: PgPool,
table_name: String,
desired_state: Option<SetupState>,
existing: setup::CombinedState<SetupState>,
) -> Self {
fn new(desired_state: Option<SetupState>, existing: setup::CombinedState<SetupState>) -> Self {
let table_action = TableMainSetupAction::from_states(desired_state.as_ref(), &existing);
let (indexes_to_delete, indexes_to_create) = desired_state
.as_ref()
Expand Down Expand Up @@ -647,8 +639,6 @@ impl SetupStatus {
&& !existing.current.map(|s| s.uses_pgvector()).unwrap_or(false);

Self {
db_pool,
table_name,
create_pgvector_extension,
actions: TableSetupAction {
table_action,
Expand Down Expand Up @@ -726,21 +716,20 @@ impl setup::ResourceSetupStatus for SetupStatus {
}

impl SetupStatus {
async fn apply_change(&self) -> Result<()> {
let table_name = &self.table_name;
async fn apply_change(&self, db_pool: &PgPool, table_name: &str) -> Result<()> {
if self.actions.table_action.drop_existing {
sqlx::query(&format!("DROP TABLE IF EXISTS {table_name}"))
.execute(&self.db_pool)
.execute(db_pool)
.await?;
}
if self.create_pgvector_extension {
sqlx::query("CREATE EXTENSION IF NOT EXISTS vector;")
.execute(&self.db_pool)
.execute(db_pool)
.await?;
}
for index_name in self.actions.indexes_to_delete.iter() {
let sql = format!("DROP INDEX IF EXISTS {}", index_name);
sqlx::query(&sql).execute(&self.db_pool).await?;
sqlx::query(&sql).execute(db_pool).await?;
}
if let Some(table_upsertion) = &self.actions.table_action.table_upsertion {
match table_upsertion {
Expand All @@ -752,7 +741,7 @@ impl SetupStatus {
fields.join(", "),
keys.keys().join(", ")
);
sqlx::query(&sql).execute(&self.db_pool).await?;
sqlx::query(&sql).execute(db_pool).await?;
}
TableUpsertionAction::Update {
columns_to_delete,
Expand All @@ -762,13 +751,13 @@ impl SetupStatus {
let sql = format!(
"ALTER TABLE {table_name} DROP COLUMN IF EXISTS {column_name}",
);
sqlx::query(&sql).execute(&self.db_pool).await?;
sqlx::query(&sql).execute(db_pool).await?;
}
for (column_name, column_type) in columns_to_upsert.iter() {
let sql = format!(
"ALTER TABLE {table_name} DROP COLUMN IF EXISTS {column_name}, ADD COLUMN {column_name} {column_type}"
);
sqlx::query(&sql).execute(&self.db_pool).await?;
sqlx::query(&sql).execute(db_pool).await?;
}
}
}
Expand All @@ -778,7 +767,7 @@ impl SetupStatus {
"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} {}",
to_index_spec_sql(index_spec)
);
sqlx::query(&sql).execute(&self.db_pool).await?;
sqlx::query(&sql).execute(db_pool).await?;
}
Ok(())
}
Expand Down Expand Up @@ -873,17 +862,12 @@ impl StorageFactoryBase for Factory {

async fn check_setup_status(
&self,
key: TableId,
_key: TableId,
desired: Option<SetupState>,
existing: setup::CombinedState<SetupState>,
auth_registry: &Arc<AuthRegistry>,
_auth_registry: &Arc<AuthRegistry>,
) -> Result<SetupStatus> {
Ok(SetupStatus::new(
get_db_pool(key.database.as_ref(), auth_registry).await?,
key.table_name,
desired,
existing,
))
Ok(SetupStatus::new(desired, existing))
}

fn check_state_compatibility(
Expand Down Expand Up @@ -938,10 +922,15 @@ impl StorageFactoryBase for Factory {

async fn apply_setup_changes(
&self,
setup_status: Vec<&'async_trait Self::SetupStatus>,
changes: Vec<TypedResourceSetupChangeItem<'async_trait, Self>>,
auth_registry: &Arc<AuthRegistry>,
) -> Result<()> {
for setup_status in setup_status.iter() {
setup_status.apply_change().await?;
for change in changes.iter() {
let db_pool = get_db_pool(change.key.database.as_ref(), &auth_registry).await?;
change
.setup_status
.apply_change(&db_pool, &change.key.table_name)
.await?;
}
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion src/ops/storages/qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ impl StorageFactoryBase for Arc<Factory> {

async fn apply_setup_changes(
&self,
_setup_status: Vec<&'async_trait Self::SetupStatus>,
_setup_status: Vec<TypedResourceSetupChangeItem<'async_trait, Self>>,
_auth_registry: &Arc<AuthRegistry>,
) -> Result<()> {
Err(anyhow!("Qdrant does not support setup changes"))
}
Expand Down
Loading