diff --git a/src/execution/db_tracking.rs b/src/execution/db_tracking.rs
index 871a9201c..bc7fc3ef4 100644
--- a/src/execution/db_tracking.rs
+++ b/src/execution/db_tracking.rs
@@ -3,12 +3,80 @@
 use crate::prelude::*;
 
 use super::{db_tracking_setup::TrackingTableSetupState, memoization::StoredMemoizationInfo};
 use crate::utils::{db::WriteAction, fingerprint::Fingerprint};
 use futures::Stream;
+use serde::de::{self, Deserializer, SeqAccess, Visitor};
+use serde::ser::SerializeSeq;
 use sqlx::PgPool;
+use std::fmt;
+
+#[derive(Debug, Clone)]
+pub struct TrackedTargetKeyInfo {
+    pub key: serde_json::Value,
+    pub additional_key: serde_json::Value,
+    pub process_ordinal: i64,
+    // None means deletion.
+    pub fingerprint: Option<Fingerprint>,
+}
+
+impl Serialize for TrackedTargetKeyInfo {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        let mut seq = serializer.serialize_seq(None)?;
+        seq.serialize_element(&self.key)?;
+        seq.serialize_element(&self.process_ordinal)?;
+        seq.serialize_element(&self.fingerprint)?;
+        if !self.additional_key.is_null() {
+            seq.serialize_element(&self.additional_key)?;
+        }
+        seq.end()
+    }
+}
+
+impl<'de> serde::Deserialize<'de> for TrackedTargetKeyInfo {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct TrackedTargetKeyVisitor;
+
+        impl<'de> Visitor<'de> for TrackedTargetKeyVisitor {
+            type Value = TrackedTargetKeyInfo;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a sequence of 3 or 4 elements for TrackedTargetKey")
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: SeqAccess<'de>,
+            {
+                let target_key: serde_json::Value = seq
+                    .next_element()?
+                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
+                let process_ordinal: i64 = seq
+                    .next_element()?
+                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
+                let fingerprint: Option<Fingerprint> = seq
+                    .next_element()?
+                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+                let additional_key: Option<serde_json::Value> = seq.next_element()?;
+
+                Ok(TrackedTargetKeyInfo {
+                    key: target_key,
+                    process_ordinal,
+                    fingerprint,
+                    additional_key: additional_key.unwrap_or(serde_json::Value::Null),
+                })
+            }
+        }
+
+        deserializer.deserialize_seq(TrackedTargetKeyVisitor)
+    }
+}
 
-/// (target_key, process_ordinal, fingerprint)
-pub type TrackedTargetKey = (serde_json::Value, i64, Option<Fingerprint>);
 /// (source_id, target_key)
-pub type TrackedTargetKeyForSource = Vec<(i32, Vec<TrackedTargetKey>)>;
+pub type TrackedTargetKeyForSource = Vec<(i32, Vec<TrackedTargetKeyInfo>)>;
 
 #[derive(sqlx::FromRow, Debug)]
 pub struct SourceTrackingInfoForProcessing {
@@ -80,7 +148,8 @@ pub async fn precommit_source_tracking_info(
     let query_str = match action {
         WriteAction::Insert => format!(
            "INSERT INTO {} (source_id, source_key, max_process_ordinal, staging_target_keys, memoization_info) VALUES ($1, $2, $3, $4, $5)",
-            db_setup.table_name),
+            db_setup.table_name
+        ),
         WriteAction::Update => format!(
            "UPDATE {} SET max_process_ordinal = $3, staging_target_keys = $4, memoization_info = $5 WHERE source_id = $1 AND source_key = $2",
            db_setup.table_name
@@ -205,9 +274,9 @@ impl ListTrackedSourceKeyMetadataState {
         pool: &'a PgPool,
     ) -> impl Stream<Item = Result<TrackedSourceKeyMetadata, sqlx::Error>> + 'a {
         self.query_str = format!(
-              "SELECT source_key, processed_source_ordinal, process_logic_fingerprint FROM {} WHERE source_id = $1",
-              db_setup.table_name
-          );
+            "SELECT source_key, processed_source_ordinal, process_logic_fingerprint FROM {} WHERE source_id = $1",
+            db_setup.table_name
+        );
         sqlx::query_as(&self.query_str).bind(source_id).fetch(pool)
     }
 }
diff --git a/src/execution/row_indexer.rs b/src/execution/row_indexer.rs
index d37ac5842..fb2f655b1 100644
--- a/src/execution/row_indexer.rs
+++ b/src/execution/row_indexer.rs
@@ -4,7 +4,7 @@
 use futures::future::try_join_all;
 use sqlx::PgPool;
 use std::collections::{HashMap, HashSet};
-use super::db_tracking::{self, TrackedTargetKey, read_source_tracking_info_for_processing};
+use super::db_tracking::{self, TrackedTargetKeyInfo, read_source_tracking_info_for_processing};
 use super::db_tracking_setup;
 use super::evaluator::{
     EvaluateSourceEntryOutput, SourceRowEvaluationContext, evaluate_source_entry,
@@ -119,6 +119,12 @@ pub enum SkippedOr {
     Skipped(SourceVersion),
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+struct TargetKeyPair {
+    pub key: serde_json::Value,
+    pub additional_key: serde_json::Value,
+}
+
 #[derive(Default)]
 struct TrackingInfoForTarget<'a> {
     export_op: Option<&'a AnalyzedExportOp>,
@@ -126,11 +132,11 @@ struct TrackingInfoForTarget<'a> {
     // Existing keys info. Keyed by target key.
     // Will be removed after new rows for the same key are added into `new_staging_keys_info` and `mutation.upserts`,
     // hence all remaining ones are to be deleted.
-    existing_staging_keys_info: HashMap<serde_json::Value, Vec<(i64, Option<Fingerprint>)>>,
-    existing_keys_info: HashMap<serde_json::Value, Vec<(i64, Option<Fingerprint>)>>,
+    existing_staging_keys_info: HashMap<TargetKeyPair, Vec<(i64, Option<Fingerprint>)>>,
+    existing_keys_info: HashMap<TargetKeyPair, Vec<(i64, Option<Fingerprint>)>>,
 
     // New keys info for staging.
-    new_staging_keys_info: Vec<TrackedTargetKey>,
+    new_staging_keys_info: Vec<TrackedTargetKeyInfo>,
 
     // Mutation to apply to the target storage.
     mutation: ExportTargetMutation,
@@ -208,9 +214,12 @@ async fn precommit_source_tracking_info(
             for key_info in keys_info.into_iter() {
                 target_info
                     .existing_staging_keys_info
-                    .entry(key_info.0)
+                    .entry(TargetKeyPair {
+                        key: key_info.key,
+                        additional_key: key_info.additional_key,
+                    })
                     .or_default()
-                    .push((key_info.1, key_info.2));
+                    .push((key_info.process_ordinal, key_info.fingerprint));
             }
         }
 
@@ -220,9 +229,12 @@ async fn precommit_source_tracking_info(
             for key_info in keys_info.into_iter() {
                 target_info
                     .existing_keys_info
-                    .entry(key_info.0)
+                    .entry(TargetKeyPair {
+                        key: key_info.key,
+                        additional_key: key_info.additional_key,
+                    })
                     .or_default()
-                    .push((key_info.1, key_info.2));
+                    .push((key_info.process_ordinal, key_info.fingerprint));
             }
         }
     }
@@ -249,22 +261,24 @@ async fn precommit_source_tracking_info(
                 .fields
                 .push(value.fields[*field as usize].clone());
         }
-        let existing_target_keys = target_info.existing_keys_info.remove(&primary_key_json);
+        let additional_key = export_op.export_target_factory.extract_additional_key(
+            &primary_key,
+            &field_values,
+            export_op.export_context.as_ref(),
+        )?;
+        let target_key_pair = TargetKeyPair {
+            key: primary_key_json,
+            additional_key,
+        };
+        let existing_target_keys = target_info.existing_keys_info.remove(&target_key_pair);
         let existing_staging_target_keys = target_info
             .existing_staging_keys_info
-            .remove(&primary_key_json);
+            .remove(&target_key_pair);
 
-        let upsert_entry = export_op.export_target_factory.prepare_upsert_entry(
-            ExportTargetUpsertEntry {
-                key: primary_key,
-                value: field_values,
-            },
-            export_op.export_context.as_ref(),
-        )?;
         let curr_fp = if !export_op.value_stable {
             Some(
                 Fingerprinter::default()
-                    .with(&upsert_entry.value)?
+                    .with(&field_values)?
                     .into_fingerprint(),
             )
         } else {
@@ -285,16 +299,29 @@
                 .into_iter()
                 .next()
                 .ok_or_else(invariance_violation)?;
-            keys_info.push((primary_key_json, existing_ordinal, existing_fp));
+            keys_info.push(TrackedTargetKeyInfo {
+                key: target_key_pair.key,
+                additional_key: target_key_pair.additional_key,
+                process_ordinal: existing_ordinal,
+                fingerprint: existing_fp,
+            });
         } else {
             // Entry with new value. Needs to be upserted.
-            target_info.mutation.upserts.push(upsert_entry);
-            target_info.new_staging_keys_info.push((
-                primary_key_json.clone(),
+            let tracked_target_key = TrackedTargetKeyInfo {
+                key: target_key_pair.key.clone(),
+                additional_key: target_key_pair.additional_key.clone(),
                 process_ordinal,
-                curr_fp,
-            ));
-            keys_info.push((primary_key_json, process_ordinal, curr_fp));
+                fingerprint: curr_fp,
+            };
+            target_info.mutation.upserts.push(ExportTargetUpsertEntry {
+                key: primary_key,
+                additional_key: target_key_pair.additional_key,
+                value: field_values,
+            });
+            target_info
+                .new_staging_keys_info
+                .push(tracked_target_key.clone());
+            keys_info.push(tracked_target_key);
         }
     }
     new_target_keys_info.push((export_op.target_id, keys_info));
@@ -304,7 +331,7 @@ async fn precommit_source_tracking_info(
     let mut new_staging_target_keys = db_tracking::TrackedTargetKeyForSource::default();
     let mut target_mutations = HashMap::with_capacity(export_ops.len());
     for (target_id, target_tracking_info) in tracking_info_for_targets.into_iter() {
-        let legacy_keys: HashSet<serde_json::Value> = target_tracking_info
+        let legacy_keys: HashSet<TargetKeyPair> = target_tracking_info
             .existing_keys_info
             .into_keys()
             .chain(target_tracking_info.existing_staging_keys_info.into_keys())
@@ -312,24 +339,27 @@ async fn precommit_source_tracking_info(
         let mut new_staging_keys_info = target_tracking_info.new_staging_keys_info;
         // Add tracking info for deletions.
-        new_staging_keys_info.extend(
-            legacy_keys
-                .iter()
-                .map(|key| ((*key).clone(), process_ordinal, None)),
-        );
+        new_staging_keys_info.extend(legacy_keys.iter().map(|key| TrackedTargetKeyInfo {
+            key: key.key.clone(),
+            additional_key: key.additional_key.clone(),
+            process_ordinal,
+            fingerprint: None,
+        }));
         new_staging_target_keys.push((target_id, new_staging_keys_info));
 
         if let Some(export_op) = target_tracking_info.export_op {
             let mut mutation = target_tracking_info.mutation;
-            mutation.delete_keys.reserve(legacy_keys.len());
+            mutation.deletes.reserve(legacy_keys.len());
             for legacy_key in legacy_keys.into_iter() {
-                mutation.delete_keys.push(
-                    value::Value::<value::ScopeValue>::from_json(
-                        legacy_key,
-                        &export_op.primary_key_type,
-                    )?
-                    .as_key()?,
-                );
+                let key = value::Value::<value::ScopeValue>::from_json(
+                    legacy_key.key,
+                    &export_op.primary_key_type,
+                )?
+                .as_key()?;
+                mutation.deletes.push(interface::ExportTargetDeleteEntry {
+                    key,
+                    additional_key: legacy_key.additional_key,
+                });
             }
             target_mutations.insert(target_id, mutation);
         }
@@ -398,9 +428,10 @@ async fn commit_source_tracking_info(
         .filter_map(|(target_id, target_keys)| {
             let cleaned_target_keys: Vec<_> = target_keys
                 .into_iter()
-                .filter(|(_, ordinal, _)| {
-                    Some(*ordinal) > precommit_metadata.existing_process_ordinal
-                        && *ordinal != precommit_metadata.process_ordinal
+                .filter(|key_info| {
+                    Some(key_info.process_ordinal)
+                        > precommit_metadata.existing_process_ordinal
+                        && key_info.process_ordinal != precommit_metadata.process_ordinal
                 })
                 .collect();
             if !cleaned_target_keys.is_empty() {
diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs
index ef06252bd..c7b9567ae 100644
--- a/src/ops/factory_bases.rs
+++ b/src/ops/factory_bases.rs
@@ -327,12 +327,13 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
     fn describe_resource(&self, key: &Self::Key) -> Result<String>;
 
-    fn prepare_upsert_entry<'ctx>(
+    fn extract_additional_key<'ctx>(
         &self,
-        entry: ExportTargetUpsertEntry,
+        _key: &value::KeyValue,
+        _value: &value::FieldValues,
         _export_context: &'ctx Self::ExportContext,
-    ) -> Result<ExportTargetUpsertEntry> {
-        Ok(entry)
+    ) -> Result<serde_json::Value> {
+        Ok(serde_json::Value::Null)
     }
 
     fn register(self, registry: &mut ExecutorFactoryRegistry) -> Result<()>
@@ -459,14 +460,16 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
         Ok(result)
     }
 
-    fn prepare_upsert_entry<'ctx>(
+    fn extract_additional_key<'ctx>(
         &self,
-        entry: ExportTargetUpsertEntry,
+        key: &value::KeyValue,
+        value: &value::FieldValues,
         export_context: &'ctx (dyn Any + Send + Sync),
-    ) -> Result<ExportTargetUpsertEntry> {
-        StorageFactoryBase::prepare_upsert_entry(
+    ) -> Result<serde_json::Value> {
+        StorageFactoryBase::extract_additional_key(
             self,
-            entry,
+            key,
+            value,
             export_context
                 .downcast_ref::<T::ExportContext>()
                 .ok_or_else(invariance_violation)?,
diff --git a/src/ops/interface.rs b/src/ops/interface.rs
index 69cef0327..e935064f9 100644
--- a/src/ops/interface.rs
+++ b/src/ops/interface.rs
@@ -195,18 +195,25 @@ pub trait SimpleFunctionFactory {
 #[derive(Debug)]
 pub struct ExportTargetUpsertEntry {
     pub key: KeyValue,
+    pub additional_key: serde_json::Value,
     pub value: FieldValues,
 }
 
+#[derive(Debug)]
+pub struct ExportTargetDeleteEntry {
+    pub key: KeyValue,
+    pub additional_key: serde_json::Value,
+}
+
 #[derive(Debug, Default)]
 pub struct ExportTargetMutation {
     pub upserts: Vec<ExportTargetUpsertEntry>,
-    pub delete_keys: Vec<KeyValue>,
+    pub deletes: Vec<ExportTargetDeleteEntry>,
 }
 
 impl ExportTargetMutation {
     pub fn is_empty(&self) -> bool {
-        self.upserts.is_empty() && self.delete_keys.is_empty()
+        self.upserts.is_empty() && self.deletes.is_empty()
     }
 }
 
@@ -286,11 +293,12 @@ pub trait ExportTargetFactory: Send + Sync {
     fn describe_resource(&self, key: &serde_json::Value) -> Result<String>;
 
-    fn prepare_upsert_entry<'ctx>(
+    fn extract_additional_key<'ctx>(
         &self,
-        entry: ExportTargetUpsertEntry,
+        key: &KeyValue,
+        value: &FieldValues,
         export_context: &'ctx (dyn Any + Send + Sync),
-    ) -> Result<ExportTargetUpsertEntry>;
+    ) -> Result<serde_json::Value>;
 
     async fn apply_mutation(
         &self,
diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs
index 8b1915aef..0311e9c64 100644
--- a/src/ops/storages/neo4j.rs
+++ b/src/ops/storages/neo4j.rs
@@ -1056,8 +1056,8 @@ impl StorageFactoryBase for Factory {
         }
         for mut_with_ctx in muts.iter().rev() {
             let export_ctx = &mut_with_ctx.export_context;
-            for delete_key in mut_with_ctx.mutation.delete_keys.iter() {
-                export_ctx.add_delete_queries(delete_key, &mut queries)?;
+            for deletion in mut_with_ctx.mutation.deletes.iter() {
+                export_ctx.add_delete_queries(&deletion.key, &mut queries)?;
             }
         }
         let mut txn = graph.start_txn().await?;
diff --git a/src/ops/storages/postgres.rs b/src/ops/storages/postgres.rs
index 11da51e3f..46bafedb5 100644
--- a/src/ops/storages/postgres.rs
+++ b/src/ops/storages/postgres.rs
@@ -388,17 +388,17 @@ impl ExportContext {
     async fn delete(
         &self,
-        delete_keys: &[KeyValue],
+        deletions: &[interface::ExportTargetDeleteEntry],
         txn: &mut sqlx::PgTransaction<'_>,
     ) -> Result<()> {
         // TODO: Find a way to batch delete.
-        for delete_key in delete_keys.iter() {
+        for deletion in deletions.iter() {
             let mut query_builder = sqlx::QueryBuilder::new("");
             query_builder.push(&self.delete_sql_prefix);
             for (i, (schema, value)) in self
                 .key_fields_schema
                 .iter()
-                .zip(key_value_fields_iter(&self.key_fields_schema, delete_key)?.iter())
+                .zip(key_value_fields_iter(&self.key_fields_schema, &deletion.key)?.iter())
                 .enumerate()
             {
                 if i > 0 {
@@ -912,7 +912,7 @@ impl StorageFactoryBase for Factory {
         for mut_group in mut_groups.iter() {
             mut_group
                 .export_context
-                .delete(&mut_group.mutation.delete_keys, &mut txn)
+                .delete(&mut_group.mutation.deletes, &mut txn)
                 .await?;
         }
         txn.commit().await?;
diff --git a/src/ops/storages/qdrant.rs b/src/ops/storages/qdrant.rs
index 978a93774..a140a85be 100644
--- a/src/ops/storages/qdrant.rs
+++ b/src/ops/storages/qdrant.rs
@@ -78,9 +78,9 @@ impl ExportContext {
         }
         let ids = mutation
-            .delete_keys
+            .deletes
             .iter()
-            .map(key_to_point_id)
+            .map(|deletion| key_to_point_id(&deletion.key))
             .collect::<Result<Vec<_>>>()?;
 
         if !ids.is_empty() {
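Note: the sketch below is not part of the patch. It is a minimal, self-contained mirror of the wire format that the manual Serialize/Deserialize impls for TrackedTargetKeyInfo above are written to preserve: legacy rows stored as 3-element arrays (the old TrackedTargetKey tuple) must keep deserializing, while rows with a non-null additional_key append a 4th element. The crate's Fingerprint type is simplified to String here so the example compiles with only serde and serde_json.

use serde::de::{self, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeSeq, Serializer};
use std::fmt;

#[derive(Debug)]
struct TrackedKey {
    key: serde_json::Value,
    additional_key: serde_json::Value,
    process_ordinal: i64,
    fingerprint: Option<String>, // stand-in for the crate's Fingerprint
}

impl Serialize for TrackedKey {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Length is None because the element count varies (3 or 4).
        let mut seq = serializer.serialize_seq(None)?;
        seq.serialize_element(&self.key)?;
        seq.serialize_element(&self.process_ordinal)?;
        seq.serialize_element(&self.fingerprint)?;
        // The 4th element is written only when present, so old rows and new
        // rows without an additional key share the same representation.
        if !self.additional_key.is_null() {
            seq.serialize_element(&self.additional_key)?;
        }
        seq.end()
    }
}

impl<'de> serde::Deserialize<'de> for TrackedKey {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct V;
        impl<'de> Visitor<'de> for V {
            type Value = TrackedKey;
            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a sequence of 3 or 4 elements")
            }
            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<TrackedKey, A::Error> {
                let key = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let process_ordinal = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
                let fingerprint = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(2, &self))?;
                // The 4th element is optional; legacy 3-element rows fall back to Null.
                let additional_key: Option<serde_json::Value> = seq.next_element()?;
                Ok(TrackedKey {
                    key,
                    additional_key: additional_key.unwrap_or(serde_json::Value::Null),
                    process_ordinal,
                    fingerprint,
                })
            }
        }
        deserializer.deserialize_seq(V)
    }
}

fn main() -> serde_json::Result<()> {
    // A legacy 3-element row still deserializes; additional_key becomes Null.
    let legacy: TrackedKey = serde_json::from_str(r#"["doc-1", 42, "fp"]"#)?;
    assert!(legacy.additional_key.is_null());

    // A row with an additional key round-trips as a 4-element array.
    let new_row = TrackedKey {
        key: serde_json::json!("doc-1"),
        additional_key: serde_json::json!({"collection": "docs"}),
        process_ordinal: 43,
        fingerprint: None, // None marks a deletion tombstone, as in the patch
    };
    assert_eq!(
        serde_json::to_string(&new_row)?,
        r#"["doc-1",43,null,{"collection":"docs"}]"#
    );
    Ok(())
}

This is also why serialize_seq is called with None rather than a fixed length: the element count depends on whether the target supplied an additional key via extract_additional_key, and omitting the 4th element keeps the new format byte-compatible with tracking rows written before this change.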