diff --git a/src/execution/live_updater.rs b/src/execution/live_updater.rs
index 1e6632a4a..fa3a2358a 100644
--- a/src/execution/live_updater.rs
+++ b/src/execution/live_updater.rs
@@ -31,6 +31,39 @@ struct StatsReportState {
 const MIN_REPORT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);
 const REPORT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
 
+struct SharedAckFn {
+    count: usize,
+    ack_fn: Option<Box<dyn FnOnce() -> BoxFuture<'static, Result<()>> + Send + Sync>>,
+}
+
+impl SharedAckFn {
+    fn new(
+        count: usize,
+        ack_fn: Box<dyn FnOnce() -> BoxFuture<'static, Result<()>> + Send + Sync>,
+    ) -> Self {
+        Self {
+            count,
+            ack_fn: Some(ack_fn),
+        }
+    }
+
+    async fn ack(v: &Mutex<Self>) -> Result<()> {
+        let ack_fn = {
+            let mut v = v.lock().unwrap();
+            v.count -= 1;
+            if v.count > 0 {
+                None
+            } else {
+                v.ack_fn.take()
+            }
+        };
+        if let Some(ack_fn) = ack_fn {
+            ack_fn().await?;
+        }
+        Ok(())
+    }
+}
+
 async fn update_source(
     flow_ctx: Arc<FlowContext>,
     plan: Arc<plan::ExecutionPlan>,
@@ -92,13 +125,50 @@ async fn update_source(
         futs.push(
             async move {
                 let mut change_stream = change_stream;
-                while let Some(change) = change_stream.next().await {
-                    tokio::spawn(source_context.clone().process_source_key(
-                        change.key,
-                        change.data,
-                        source_update_stats.clone(),
-                        pool.clone(),
-                    ));
+                let retry_options = retriable::RetryOptions {
+                    max_retries: None,
+                    initial_backoff: std::time::Duration::from_secs(5),
+                    max_backoff: std::time::Duration::from_secs(60),
+                };
+                loop {
+                    // Workaround as AsyncFnMut isn't mature yet.
+                    // Should be changed to use AsyncFnMut once it is.
+                    let change_stream = tokio::sync::Mutex::new(&mut change_stream);
+                    let change_msg = retriable::run(
+                        || async {
+                            let mut change_stream = change_stream.lock().await;
+                            change_stream
+                                .next()
+                                .await
+                                .transpose()
+                                .map_err(retriable::Error::always_retryable)
+                        },
+                        &retry_options,
+                    )
+                    .await?;
+                    let change_msg = if let Some(change_msg) = change_msg {
+                        change_msg
+                    } else {
+                        break;
+                    };
+                    let ack_fn = change_msg.ack_fn.map(|ack_fn| {
+                        Arc::new(Mutex::new(SharedAckFn::new(
+                            change_msg.changes.len(),
+                            ack_fn,
+                        )))
+                    });
+                    for change in change_msg.changes {
+                        let ack_fn = ack_fn.clone();
+                        tokio::spawn(source_context.clone().process_source_key(
+                            change.key,
+                            change.data,
+                            source_update_stats.clone(),
+                            ack_fn.map(|ack_fn| {
+                                move || async move { SharedAckFn::ack(&ack_fn).await }
+                            }),
+                            pool.clone(),
+                        ));
+                    }
                 }
                 Ok(())
             }
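Note on SharedAckFn above: one queue message can fan out into several row-level tasks, and the underlying message may only be acknowledged after the last of them finishes. Storing the ack closure in an Option and take()-ing it on the final decrement guarantees the ack runs exactly once, even if two tasks race on the lock. A minimal sketch of the same counting pattern (hypothetical names, synchronous for brevity; not the patch's actual type):

    use std::sync::{Arc, Mutex};

    struct SharedAck {
        remaining: usize,
        on_done: Option<Box<dyn FnOnce() + Send>>,
    }

    impl SharedAck {
        fn done(slot: &Mutex<SharedAck>) {
            // Decide under the lock whether we are the last finisher.
            let on_done = {
                let mut s = slot.lock().unwrap();
                s.remaining -= 1;
                if s.remaining == 0 { s.on_done.take() } else { None }
            };
            // Run the ack outside the lock; take() makes it impossible to run twice.
            if let Some(f) = on_done {
                f();
            }
        }
    }

    fn main() {
        let ack = Arc::new(Mutex::new(SharedAck {
            remaining: 3,
            on_done: Some(Box::new(|| println!("message acked"))),
        }));
        for _ in 0..3 {
            SharedAck::done(&ack); // prints once, on the third call
        }
    }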
diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs
index 4f646fac0..2d8f0ac7f 100644
--- a/src/execution/source_indexer.rs
+++ b/src/execution/source_indexer.rs
@@ -1,5 +1,6 @@
 use crate::prelude::*;
 
+use futures::future::Ready;
 use sqlx::PgPool;
 use std::collections::{hash_map, HashMap};
 use tokio::{sync::Semaphore, task::JoinSet};
@@ -36,6 +37,8 @@ pub struct SourceIndexingContext {
     state: Mutex<SourceIndexingState>,
 }
 
+pub const NO_ACK: Option<fn() -> Ready<Result<()>>> = None;
+
 impl SourceIndexingContext {
     pub async fn load(
         flow: Arc<builder::AnalyzedFlow>,
@@ -79,11 +82,15 @@ impl SourceIndexingContext {
         })
     }
 
-    pub async fn process_source_key(
+    pub async fn process_source_key<
+        AckFut: Future<Output = Result<()>> + Send + 'static,
+        AckFn: FnOnce() -> AckFut,
+    >(
         self: Arc<Self>,
         key: value::KeyValue,
         source_data: Option<SourceData>,
         update_stats: Arc<stats::UpdateStats>,
+        ack_fn: Option<AckFn>,
         pool: PgPool,
     ) {
         let process = async {
@@ -173,11 +180,20 @@ impl SourceIndexingContext {
                 }
             }
             drop(permit);
+            if let Some(ack_fn) = ack_fn {
+                ack_fn().await?;
+            }
             anyhow::Ok(())
         };
         if let Err(e) = process.await {
             update_stats.num_errors.inc(1);
-            error!("{:?}", e.context("Error in processing a source row"));
+            error!(
+                "{:?}",
+                e.context(format!(
+                    "Error in processing row from source `{source}` with key: {key}",
+                    source = self.flow.flow_instance.import_ops[self.source_idx].name
+                ))
+            );
         }
     }
 
@@ -203,7 +219,7 @@ impl SourceIndexingContext {
         }
         Some(
             self.clone()
-                .process_source_key(key, None, update_stats.clone(), pool.clone()),
+                .process_source_key(key, None, update_stats.clone(), NO_ACK, pool.clone()),
         )
     }
 
@@ -269,6 +285,7 @@ impl SourceIndexingContext {
                         key,
                         source_data,
                         update_stats.clone(),
+                        NO_ACK,
                         pool.clone(),
                     ));
                 }
diff --git a/src/ops/interface.rs b/src/ops/interface.rs
index 93022664a..cae5776f2 100644
--- a/src/ops/interface.rs
+++ b/src/ops/interface.rs
@@ -95,6 +95,11 @@ pub struct SourceChange {
     pub data: Option<SourceData>,
 }
 
+pub struct SourceChangeMessage {
+    pub changes: Vec<SourceChange>,
+    pub ack_fn: Option<Box<dyn FnOnce() -> BoxFuture<'static, Result<()>> + Send + Sync>>,
+}
+
 #[derive(Debug, Default)]
 pub struct SourceExecutorListOptions {
     pub include_ordinal: bool,
@@ -141,7 +146,9 @@ pub trait SourceExecutor: Send + Sync {
         options: &SourceExecutorGetOptions,
     ) -> Result<PartialSourceRowData>;
 
-    async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
+    async fn change_stream(
+        &self,
+    ) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
         Ok(None)
     }
 }
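The reworked change_stream contract yields Result<SourceChangeMessage> items instead of bare SourceChange values: each message groups the changes that arrived together with an optional ack_fn that the caller is expected to invoke once all of them are processed. A consumer-side sketch with local stand-in types (ChangeMessage here mirrors SourceChangeMessage; assumes the tokio, anyhow, and futures crates):

    use anyhow::Result;
    use futures::future::BoxFuture;
    use futures::FutureExt;

    // Local stand-in for the boxed ack callback in SourceChangeMessage.
    type AckFn = Box<dyn FnOnce() -> BoxFuture<'static, Result<()>> + Send + Sync>;

    struct ChangeMessage {
        changes: Vec<String>, // stand-in for Vec<SourceChange>
        ack_fn: Option<AckFn>,
    }

    #[tokio::main]
    async fn main() -> Result<()> {
        let msg = ChangeMessage {
            changes: vec!["key-1".into(), "key-2".into()],
            ack_fn: Some(Box::new(|| async { Ok(()) }.boxed())),
        };
        for key in &msg.changes {
            println!("process {key}");
        }
        // Ack only after every change in the message has been handled.
        if let Some(ack) = msg.ack_fn {
            ack().await?;
        }
        Ok(())
    }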
diff --git a/src/ops/sources/amazon_s3.rs b/src/ops/sources/amazon_s3.rs
index ffdb5a11d..c9719b949 100644
--- a/src/ops/sources/amazon_s3.rs
+++ b/src/ops/sources/amazon_s3.rs
@@ -3,7 +3,6 @@ use async_stream::try_stream;
 use aws_config::BehaviorVersion;
 use aws_sdk_s3::Client;
 use globset::{Glob, GlobSet, GlobSetBuilder};
-use log::warn;
 use std::sync::Arc;
 
 use crate::base::field_attrs;
@@ -23,6 +22,19 @@ struct SqsContext {
     client: aws_sdk_sqs::Client,
     queue_url: String,
 }
+
+impl SqsContext {
+    async fn delete_message(&self, receipt_handle: String) -> Result<()> {
+        self.client
+            .delete_message()
+            .queue_url(&self.queue_url)
+            .receipt_handle(receipt_handle)
+            .send()
+            .await?;
+        Ok(())
+    }
+}
+
 struct Executor {
     client: Client,
     bucket_name: String,
@@ -152,24 +164,26 @@ impl SourceExecutor for Executor {
         Ok(PartialSourceRowData { value, ordinal })
     }
 
-    async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
+    async fn change_stream(
+        &self,
+    ) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
         let sqs_context = if let Some(sqs_context) = &self.sqs_context {
             sqs_context
         } else {
             return Ok(None);
         };
         let stream = stream! {
             loop {
-                let changes = match self.poll_sqs(&sqs_context).await {
-                    Ok(changes) => changes,
+                match self.poll_sqs(&sqs_context).await {
+                    Ok(messages) => {
+                        for message in messages {
+                            yield Ok(message);
+                        }
+                    }
                     Err(e) => {
-                        warn!("Failed to poll SQS: {}", e);
-                        continue;
+                        yield Err(e);
                     }
-                };
-                for change in changes {
-                    yield change;
-                }
+                };
             }
         };
         Ok(Some(stream.boxed()))
@@ -206,7 +220,7 @@ pub struct S3Object {
 }
 
 impl Executor {
-    async fn poll_sqs(&self, sqs_context: &Arc<SqsContext>) -> Result<Vec<SourceChange>> {
+    async fn poll_sqs(&self, sqs_context: &Arc<SqsContext>) -> Result<Vec<SourceChangeMessage>> {
         let resp = sqs_context
             .client
             .receive_message()
@@ -220,36 +234,53 @@ impl Executor {
             messages
         } else {
             return Ok(Vec::new());
         };
-        let mut changes = vec![];
-        for message in messages.into_iter().filter_map(|m| m.body) {
-            let notification: S3EventNotification = serde_json::from_str(&message)?;
-            for record in notification.records {
-                let s3 = if let Some(s3) = record.s3 {
-                    s3
-                } else {
-                    continue;
-                };
-                if s3.bucket.name != self.bucket_name {
-                    continue;
-                }
-                if !self
-                    .prefix
-                    .as_ref()
-                    .map_or(true, |prefix| s3.object.key.starts_with(prefix))
-                {
-                    continue;
-                }
-                if record.event_name.starts_with("ObjectCreated:")
-                    || record.event_name.starts_with("ObjectDeleted:")
-                {
-                    changes.push(SourceChange {
-                        key: KeyValue::Str(s3.object.key.into()),
-                        data: None,
-                    });
-                }
-            }
-        }
-        Ok(changes)
+        let mut change_messages = vec![];
+        for message in messages.into_iter() {
+            if let Some(body) = message.body {
+                let notification: S3EventNotification = serde_json::from_str(&body)?;
+                let mut changes = vec![];
+                for record in notification.records {
+                    let s3 = if let Some(s3) = record.s3 {
+                        s3
+                    } else {
+                        continue;
+                    };
+                    if s3.bucket.name != self.bucket_name {
+                        continue;
+                    }
+                    if !self
+                        .prefix
+                        .as_ref()
+                        .map_or(true, |prefix| s3.object.key.starts_with(prefix))
+                    {
+                        continue;
+                    }
+                    if record.event_name.starts_with("ObjectCreated:")
+                        || record.event_name.starts_with("ObjectDeleted:")
+                    {
+                        changes.push(SourceChange {
+                            key: KeyValue::Str(s3.object.key.into()),
+                            data: None,
+                        });
+                    }
+                }
+                if let Some(receipt_handle) = message.receipt_handle {
+                    if !changes.is_empty() {
+                        let sqs_context = sqs_context.clone();
+                        change_messages.push(SourceChangeMessage {
+                            changes,
+                            ack_fn: Some(Box::new(move || {
+                                async move { sqs_context.delete_message(receipt_handle).await }
+                                    .boxed()
+                            })),
+                        });
+                    } else {
+                        sqs_context.delete_message(receipt_handle).await?;
+                    }
+                }
+            }
+        }
+        Ok(change_messages)
     }
 }
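For SQS, the ack function is a delete_message call: a message whose notification produced relevant changes is deleted only after every derived row has been processed, while a message with no matching changes is deleted immediately so unrelated notifications do not linger in the queue and get redelivered. A standalone sketch of the same delete call (queue URL and receipt handle are placeholders; assumes the aws-config, aws-sdk-sqs, and anyhow crates):

    use aws_config::BehaviorVersion;

    // Acknowledge one SQS message by deleting it, mirroring
    // SqsContext::delete_message in the patch above.
    async fn ack_message(queue_url: &str, receipt_handle: &str) -> anyhow::Result<()> {
        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
        let client = aws_sdk_sqs::Client::new(&config);
        client
            .delete_message()
            .queue_url(queue_url)
            .receipt_handle(receipt_handle)
            .send()
            .await?;
        Ok(())
    }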
diff --git a/src/ops/sources/google_drive.rs b/src/ops/sources/google_drive.rs
index 2d901d518..703f0ae49 100644
--- a/src/ops/sources/google_drive.rs
+++ b/src/ops/sources/google_drive.rs
@@ -190,7 +190,7 @@ impl Executor {
     async fn get_recent_updates(
         &self,
         cutoff_time: &mut DateTime<Utc>,
-    ) -> Result<Vec<SourceChange>> {
+    ) -> Result<SourceChangeMessage> {
         let mut page_size: i32 = 10;
         let mut next_page_token: Option<String> = None;
         let mut changes = Vec::new();
@@ -234,7 +234,10 @@ impl Executor {
             page_size = 100;
         }
         *cutoff_time = Self::make_cutoff_time(most_recent_modified_time, start_time);
-        Ok(changes)
+        Ok(SourceChangeMessage {
+            changes,
+            ack_fn: None,
+        })
     }
 
     async fn is_file_covered(&self, file_id: &str) -> Result<bool> {
@@ -416,7 +419,9 @@ impl SourceExecutor for Executor {
         Ok(PartialSourceRowData { value, ordinal })
     }
 
-    async fn change_stream(&self) -> Result<Option<BoxStream<'async_trait, SourceChange>>> {
+    async fn change_stream(
+        &self,
+    ) -> Result<Option<BoxStream<'async_trait, Result<SourceChangeMessage>>>> {
         let poll_interval = if let Some(poll_interval) = self.recent_updates_poll_interval {
             poll_interval
         } else {
@@ -428,17 +433,7 @@ impl SourceExecutor for Executor {
         let stream = stream! {
             loop {
                 interval.tick().await;
-                let changes = self.get_recent_updates(&mut cutoff_time).await;
-                match changes {
-                    Ok(changes) => {
-                        for change in changes {
-                            yield change;
-                        }
-                    }
-                    Err(e) => {
-                        error!("Error getting recent updates: {e}");
-                    }
-                }
+                yield self.get_recent_updates(&mut cutoff_time).await;
             }
         };
         Ok(Some(stream.boxed()))
diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs
index 4084aed88..5e4416ede 100644
--- a/src/ops/storages/neo4j.rs
+++ b/src/ops/storages/neo4j.rs
@@ -1289,11 +1289,12 @@ impl StorageFactoryBase for Factory {
                 .or_insert_with(Vec::new)
                 .push(mut_with_ctx);
         }
+        let retry_options = retriable::RetryOptions::default();
         for muts in muts_by_graph.values_mut() {
             muts.sort_by_key(|m| m.export_context.create_order);
             let graph = &muts[0].export_context.graph;
             retriable::run(
-                || async {
+                async || {
                     let mut queries = vec![];
                     for mut_with_ctx in muts.iter() {
                         let export_ctx = &mut_with_ctx.export_context;
@@ -1312,7 +1313,7 @@ impl StorageFactoryBase for Factory {
                     txn.commit().await?;
                     retriable::Ok(())
                 },
-                retriable::RunOptions::default(),
+                &retry_options,
             )
             .await
             .map_err(Into::<anyhow::Error>::into)?
diff --git a/src/utils/retriable.rs b/src/utils/retriable.rs
index 8c25eb892..f8115f948 100644
--- a/src/utils/retriable.rs
+++ b/src/utils/retriable.rs
@@ -28,6 +28,15 @@ impl IsRetryable for Error {
     }
 }
 
+impl Error {
+    pub fn always_retryable(error: anyhow::Error) -> Self {
+        Self {
+            error,
+            is_retryable: true,
+        }
+    }
+}
+
 impl From<anyhow::Error> for Error {
     fn from(error: anyhow::Error) -> Self {
         Self {
@@ -59,16 +68,16 @@ pub fn Ok<T>(value: T) -> Result<T> {
     Result::Ok(value)
 }
 
-pub struct RunOptions {
-    pub max_retries: usize,
+pub struct RetryOptions {
+    pub max_retries: Option<usize>,
     pub initial_backoff: Duration,
     pub max_backoff: Duration,
 }
 
-impl Default for RunOptions {
+impl Default for RetryOptions {
     fn default() -> Self {
         Self {
-            max_retries: 5,
+            max_retries: Some(5),
             initial_backoff: Duration::from_millis(100),
             max_backoff: Duration::from_secs(10),
         }
@@ -82,15 +91,19 @@ pub async fn run<
     F: Fn() -> Fut,
 >(
     f: F,
-    options: RunOptions,
+    options: &RetryOptions,
 ) -> Result<T, E> {
     let mut retries = 0;
     let mut backoff = options.initial_backoff;
     loop {
         match f().await {
             Result::Ok(result) => return Result::Ok(result),
             Result::Err(err) => {
-                if !err.is_retryable() || retries >= options.max_retries {
+                if !err.is_retryable()
+                    || options
+                        .max_retries
+                        .map_or(false, |max_retries| retries >= max_retries)
+                {
                     return Result::Err(err);
                 }
                 retries += 1;
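With RunOptions renamed to RetryOptions and max_retries widened to Option<usize>, None now means "retry indefinitely" (what the live updater uses for its change stream) while Some(n) keeps a bound, and passing the options by reference lets one config be reused across calls. A hedged usage sketch of the reworked API; fetch_once is a hypothetical fallible operation standing in for real work:

    use std::time::Duration;

    use crate::utils::retriable;

    // Hypothetical fallible async operation.
    async fn fetch_once() -> anyhow::Result<String> {
        Ok("payload".to_string())
    }

    async fn fetch_with_retries() -> Result<String, retriable::Error> {
        let retry_options = retriable::RetryOptions {
            max_retries: Some(3), // None would retry forever
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
        };
        retriable::run(
            || async {
                fetch_once()
                    .await
                    // Mark every failure as retryable, as the live updater
                    // does for its change stream.
                    .map_err(retriable::Error::always_retryable)
            },
            &retry_options,
        )
        .await
    }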