diff --git a/.gitignore b/.gitignore
index c33cf137..e9044e9b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,8 @@
 # Editors
 .DS_Store
+
+# Sqlite artifacts
+*.sqlite
+*.sqlite-shm
+*.sqlite-wal
diff --git a/Cargo.lock b/Cargo.lock
index b69dd523..13e5b678 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -40,9 +40,9 @@ dependencies = [
 
 [[package]]
 name = "allocator-api2"
-version = "0.2.18"
+version = "0.2.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
+checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
 
 [[package]]
 name = "android-tzdata"
@@ -110,9 +110,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.92"
+version = "1.0.93"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13"
+checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
 
 [[package]]
 name = "async-stream"
@@ -304,9 +304,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.1.34"
+version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9"
+checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8"
 dependencies = [
  "shlex",
 ]
@@ -419,9 +419,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
 
 [[package]]
 name = "cpufeatures"
-version = "0.2.14"
+version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
+checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6"
 dependencies = [
  "libc",
 ]
@@ -583,9 +583,9 @@ dependencies = [
 
 [[package]]
 name = "fastrand"
-version = "2.1.1"
+version = "2.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
+checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
 
 [[package]]
 name = "figment"
@@ -662,6 +662,21 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -706,6 +721,17 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
 
+[[package]]
+name = "futures-macro"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
+
 [[package]]
 name = "futures-sink"
 version = "0.3.31"
@@ -724,8 +750,10 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
 dependencies = [
+ "futures-channel",
  "futures-core",
  "futures-io",
+ "futures-macro",
  "futures-sink",
  "futures-task",
  "memchr",
@@ -1223,9 +1251,9 @@ dependencies = [
 
 [[package]]
 name = "libc"
-version = "0.2.161"
+version = "0.2.162"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1"
+checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398"
 
 [[package]]
 name = "libm"
@@ -1904,9 +1932,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.8"
+version = "0.4.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
+checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -1996,9 +2024,9 @@ dependencies = [
 
 [[package]]
 name = "rustix"
-version = "0.38.38"
+version = "0.38.40"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a"
+checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0"
 dependencies = [
  "bitflags",
  "errno",
@@ -2188,9 +2216,9 @@ dependencies = [
 
 [[package]]
 name = "sentry_protos"
-version = "0.1.33"
+version = "0.1.34"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3481ffe84b71db2796128c7177480ba1b8616dfe8ec714a8a7f98d3a721f523"
+checksum = "b7add659cc42d7ba16389b8d67606663a47a8753b836149b9682554d14063ab2"
 dependencies = [
  "glob",
  "prost",
@@ -2203,18 +2231,18 @@ dependencies = [
 
 [[package]]
 name = "serde"
-version = "1.0.214"
+version = "1.0.215"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5"
+checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
 dependencies = [
  "serde_derive",
 ]
 
 [[package]]
 name = "serde_derive"
-version = "1.0.214"
+version = "1.0.215"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766"
+checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2651,9 +2679,11 @@ name = "taskbroker"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-stream",
  "chrono",
  "clap",
  "figment",
+ "futures",
  "metrics",
  "metrics-exporter-statsd",
  "prost",
@@ -2666,15 +2696,17 @@ dependencies = [
  "serde_yaml",
  "sqlx",
  "tokio",
+ "tokio-stream",
+ "tokio-util",
  "tracing",
  "tracing-subscriber",
 ]
 
 [[package]]
 name = "tempfile"
-version = "3.13.0"
+version = "3.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
+checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
 dependencies = [
  "cfg-if",
  "fastrand",
@@ -2685,18 +2717,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.68"
+version = "1.0.69"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892"
+checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.68"
+version = "1.0.69"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e"
+checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
 dependencies = [
  "proc-macro2",
  "quote",
"1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -2817,6 +2849,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1e293c7c..d0b17d76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,21 +6,25 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -sentry_protos = "0.1.33" +sentry_protos = "0.1.34" +prost = "0.13" +prost-types = "0.13.3" anyhow = "1.0.92" chrono = { version = "0.4.26" } -sqlx = { version = "0.8.2", features = ["sqlite", "runtime-tokio", "chrono"] } -prost = "0.13" tokio = { version = "1.41.0", features = ["full"] } -prost-types = "0.13.3" +tokio-util = "0.7.12" +tokio-stream = { version = "0.1.16", features = ["full"] } +async-stream = "0.3.5" +futures = "0.3.31" rdkafka = { version = "0.36.2", features = ["cmake-build"] } serde = "1.0.214" serde_yaml = "0.9.34" figment = { version = "0.10.19", features = ["env", "yaml", "test"] } +sqlx = { version = "0.8.2", features = ["sqlite", "runtime-tokio", "chrono"] } clap = { version = "4.5.20", features = ["derive"] } sentry = { version = "0.34.0", features = ["tracing"] } -tracing-subscriber = { version = "0.3.18", features = ["json"] } tracing = "0.1.40" +tracing-subscriber = { version = "0.3.18", features = ["json"] } metrics-exporter-statsd = "0.9.0" metrics = "0.24.0" diff --git a/migrations/0001_create_inflight_taskactivations.sql b/migrations/0001_create_inflight_taskactivations.sql index 9ef43c77..3e2c3dc6 100644 --- a/migrations/0001_create_inflight_taskactivations.sql +++ b/migrations/0001_create_inflight_taskactivations.sql @@ -1,6 +1,7 @@ CREATE TABLE IF NOT EXISTS inflight_taskactivations ( id UUID NOT NULL PRIMARY KEY, activation BLOB NOT NULL, + partition INTEGER NOT NULL, offset BIGINTEGER NOT NULL, added_at DATETIME NOT NULL, deadletter_at DATETIME, diff --git a/src/consumer/deserialize_activation.rs b/src/consumer/deserialize_activation.rs new file mode 100644 index 00000000..b2063a3b --- /dev/null +++ b/src/consumer/deserialize_activation.rs @@ -0,0 +1,33 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::{anyhow, Error}; +use chrono::Utc; +use prost::Message as _; +use rdkafka::{message::OwnedMessage, Message}; +use sentry_protos::sentry::v1::TaskActivation; + +use crate::inflight_activation_store::{InflightActivation, TaskActivationStatus}; + +pub struct Config { + pub deadletter_duration: Option, +} + +pub fn new(config: Config) -> impl Fn(Arc) -> Result { + move |msg: Arc| { + let Some(payload) = msg.payload() else { + return Err(anyhow!("Message has no payload")); + }; + let activation = TaskActivation::decode(payload)?; + Ok(InflightActivation { + activation, + status: TaskActivationStatus::Pending, + partition: msg.partition(), + offset: msg.offset(), + added_at: Utc::now(), + deadletter_at: config + .deadletter_duration + .map(|duration| Utc::now() + duration), + processing_deadline: None, + }) + } +} diff --git a/src/consumer/inflight_activation_writer.rs b/src/consumer/inflight_activation_writer.rs new file mode 100644 index 00000000..98f57d62 --- /dev/null +++ 
+++ b/src/consumer/inflight_activation_writer.rs
@@ -0,0 +1,84 @@
+use std::{mem::replace, sync::Arc, time::Duration};
+
+use tracing::info;
+
+use crate::inflight_activation_store::{InflightActivation, InflightActivationStore};
+
+use super::kafka::{
+    ReduceConfig, ReduceShutdownBehaviour, ReduceShutdownCondition, Reducer,
+    ReducerWhenFullBehaviour,
+};
+
+pub struct Config {
+    pub max_buf_len: usize,
+    pub max_pending_activations: usize,
+    pub flush_interval: Option<Duration>,
+    pub when_full_behaviour: ReducerWhenFullBehaviour,
+    pub shutdown_behaviour: ReduceShutdownBehaviour,
+}
+
+pub struct InflightActivationWriter {
+    store: Arc<InflightActivationStore>,
+    buffer: Vec<InflightActivation>,
+    config: Config,
+}
+
+impl InflightActivationWriter {
+    pub fn new(store: Arc<InflightActivationStore>, config: Config) -> Self {
+        Self {
+            store,
+            buffer: Vec::with_capacity(config.max_buf_len),
+            config,
+        }
+    }
+}
+
+impl Reducer for InflightActivationWriter {
+    type Input = InflightActivation;
+
+    type Output = ();
+
+    async fn reduce(&mut self, t: Self::Input) -> Result<(), anyhow::Error> {
+        self.buffer.push(t);
+        Ok(())
+    }
+
+    async fn flush(&mut self) -> Result<Self::Output, anyhow::Error> {
+        if self.buffer.is_empty() {
+            return Ok(());
+        }
+        let res = self
+            .store
+            .store(replace(
+                &mut self.buffer,
+                Vec::with_capacity(self.config.max_buf_len),
+            ))
+            .await?;
+        info!("Inserted {:?} entries", res.rows_affected);
+        Ok(())
+    }
+
+    fn reset(&mut self) {
+        self.buffer.clear();
+    }
+
+    async fn is_full(&self) -> bool {
+        self.buffer.len() >= self.config.max_buf_len
+            || self
+                .store
+                .count_pending_activations()
+                .await
+                .expect("Error communicating with activation store")
+                + self.buffer.len()
+                >= self.config.max_pending_activations
+    }
+
+    fn get_reduce_config(&self) -> ReduceConfig {
+        ReduceConfig {
+            shutdown_condition: ReduceShutdownCondition::Signal,
+            shutdown_behaviour: ReduceShutdownBehaviour::Flush,
+            when_full_behaviour: self.config.when_full_behaviour,
+            flush_interval: self.config.flush_interval,
+        }
+    }
+}
diff --git a/src/consumer/kafka.rs b/src/consumer/kafka.rs
new file mode 100644
index 00000000..02468df1
--- /dev/null
+++ b/src/consumer/kafka.rs
@@ -0,0 +1,1759 @@
+use anyhow::{anyhow, Error};
+use futures::{
+    future::{self},
+    pin_mut, Stream, StreamExt,
+};
+use rdkafka::{
+    consumer::{
+        stream_consumer::StreamPartitionQueue, Consumer, ConsumerContext, Rebalance, StreamConsumer,
+    },
+    error::{KafkaError, KafkaResult},
+    message::{BorrowedMessage, OwnedMessage},
+    ClientConfig, ClientContext, Message, Offset, TopicPartitionList,
+};
+use std::{
+    cmp,
+    collections::{BTreeSet, HashMap},
+    fmt::Debug,
+    future::Future,
+    iter,
+    mem::take,
+    sync::{
+        mpsc::{sync_channel, SyncSender},
+        Arc,
+    },
+    time::Duration,
+};
+use tokio::{
+    select, signal,
+    sync::{
+        mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender},
+        oneshot,
+    },
+    task::JoinSet,
+    time::{self, sleep, MissedTickBehavior},
+};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_util::{either::Either, sync::CancellationToken};
+use tracing::{debug, error, info, instrument};
+
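+/// Starts the consumer: creates the Kafka client with a context that forwards
+/// rebalance callbacks as events, subscribes to `topics`, wires up OS signal
+/// handling, and runs the event loop until shutdown. `spawn_actors` is invoked
+/// on each assignment to build the processing actors for the assigned
+/// topic partitions.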
creation failed"), + ); + + consumer + .subscribe(topics) + .expect("Can't subscribe to specified topics"); + + handle_os_signals(event_sender.clone()); + handle_consumer_client(consumer.clone(), client_shutdown_receiver); + handle_events( + consumer, + event_receiver, + client_shutdown_sender, + spawn_actors, + ) + .await +} + +pub fn handle_os_signals(event_sender: UnboundedSender<(Event, SyncSender<()>)>) { + tokio::spawn(async move { + let _ = signal::ctrl_c().await; + let (rendezvous_sender, _) = sync_channel(0); + let _ = event_sender.send((Event::Shutdown, rendezvous_sender)); + }); +} + +#[instrument(skip(consumer, shutdown))] +pub fn handle_consumer_client( + consumer: Arc>, + shutdown: oneshot::Receiver<()>, +) { + tokio::spawn(async move { + select! { + biased; + _ = shutdown => { + debug!("Received shutdown signal, commiting state in sync mode..."); + let _ = consumer.commit_consumer_state(rdkafka::consumer::CommitMode::Sync); + } + msg = consumer.recv() => { + error!("Got unexpected message from consumer client: {:?}", msg); + } + } + debug!("Shutdown complete"); + }); +} + +#[derive(Debug)] +pub struct KafkaContext { + event_sender: UnboundedSender<(Event, SyncSender<()>)>, +} + +impl KafkaContext { + pub fn new(event_sender: UnboundedSender<(Event, SyncSender<()>)>) -> Self { + Self { event_sender } + } +} + +impl ClientContext for KafkaContext {} + +impl ConsumerContext for KafkaContext { + #[instrument(skip(self, rebalance))] + fn pre_rebalance(&self, rebalance: &Rebalance) { + let (rendezvous_sender, rendezvous_receiver) = sync_channel(0); + match rebalance { + Rebalance::Assign(tpl) => { + debug!("Got pre-rebalance callback, kind: Assign"); + let _ = self.event_sender.send(( + Event::Assign(tpl.to_topic_map().keys().cloned().collect()), + rendezvous_sender, + )); + info!("Partition assignment event sent, waiting for rendezvous..."); + let _ = rendezvous_receiver.recv(); + info!("Rendezvous complete"); + } + Rebalance::Revoke(tpl) => { + debug!("Got pre-rebalance callback, kind: Revoke"); + let _ = self.event_sender.send(( + Event::Revoke(tpl.to_topic_map().keys().cloned().collect()), + rendezvous_sender, + )); + info!("Parition assignment event sent, waiting for rendezvous..."); + let _ = rendezvous_receiver.recv(); + info!("Rendezvous complete"); + } + Rebalance::Error(err) => { + debug!("Got pre-rebalance callback, kind: Error"); + error!("Got rebalance error: {}", err); + } + } + } + + #[instrument(skip(self))] + fn commit_callback(&self, result: KafkaResult<()>, _offsets: &TopicPartitionList) { + debug!("Got commit callback"); + } +} + +#[derive(Debug)] +pub enum Event { + Assign(BTreeSet<(String, i32)>), + Revoke(BTreeSet<(String, i32)>), + Shutdown, +} + +#[derive(Debug)] +pub struct ActorHandles { + pub join_set: JoinSet>, + pub shutdown: CancellationToken, + pub rendezvous: oneshot::Receiver<()>, +} + +impl ActorHandles { + #[instrument(skip(self))] + async fn shutdown(mut self, deadline: Duration) { + debug!("Signaling shutdown to actors..."); + self.shutdown.cancel(); + info!("Actor shutdown signaled, waiting for rendezvous..."); + + select! { + _ = self.rendezvous => { + info!("Rendezvous complete within callback deadline."); + } + _ = sleep(deadline) => { + error!( + "Unable to rendezvous within callback deadline, \ + aborting all tasks within JoinSet" + ); + self.join_set.abort_all(); + } + } + } +} + +#[macro_export] +macro_rules! 
processing_strategy { + ( + @reducers, + ($reduce:expr), + $prev_receiver:ident, + $err_sender:ident, + $shutdown_signal:ident, + $handles:ident, + ) => {{ + let (commit_sender, commit_receiver) = tokio::sync::mpsc::channel(1); + + $handles.spawn($crate::consumer::kafka::reduce( + $reduce, + $prev_receiver, + commit_sender.clone(), + $err_sender.clone(), + $shutdown_signal.clone(), + )); + + (commit_sender, commit_receiver) + }}; + ( + @reducers, + ($reduce_first:expr $(,$reduce_rest:expr)+), + $prev_receiver:ident, + $err_sender:ident, + $shutdown_signal:ident, + $handles:ident, + ) => {{ + let (sender, receiver) = tokio::sync::mpsc::channel(1); + + $handles.spawn($crate::reduce( + $reduce_first, + $prev_receiver, + sender.clone(), + $err_sender.clone(), + $shutdown_signal.clone(), + )); + + processing_strategy!( + @reducers, + ($($reduce_rest),+), + receiver, + $err_sender, + $shutdown_signal, + $handles, + ) + }}; + ( + { + map: $map_fn:expr, + reduce: $reduce_first:expr $(=> $reduce_rest:expr)*, + err: $reduce_err:expr, + } + ) => {{ + |consumer: Arc>, + tpl: &std::collections::BTreeSet<(String, i32)>| + -> $crate::consumer::kafka::ActorHandles { + let start = std::time::Instant::now(); + + let mut handles = tokio::task::JoinSet::new(); + let shutdown_signal = tokio_util::sync::CancellationToken::new(); + + let (rendezvous_sender, rendezvous_receiver) = tokio::sync::oneshot::channel(); + + const CHANNEL_BUFF_SIZE: usize = 1024; + let (map_sender, reduce_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFF_SIZE); + let (err_sender, err_receiver) = tokio::sync::mpsc::channel(CHANNEL_BUFF_SIZE); + + for (topic, partition) in tpl.iter() { + let queue = consumer + .split_partition_queue(topic, *partition) + .expect("Unable to split topic by parition"); + + handles.spawn($crate::consumer::kafka::map( + queue, + $map_fn, + map_sender.clone(), + err_sender.clone(), + shutdown_signal.clone(), + )); + } + + let (commit_sender, commit_receiver) = $crate::processing_strategy!( + @reducers, + ($reduce_first $(,$reduce_rest)*), + reduce_receiver, + err_sender, + shutdown_signal, + handles, + ); + + handles.spawn($crate::consumer::kafka::commit( + commit_receiver, + consumer.clone(), + rendezvous_sender, + )); + + handles.spawn($crate::consumer::kafka::reduce_err( + $reduce_err, + err_receiver, + commit_sender.clone(), + shutdown_signal.clone(), + )); + + tracing::debug!("Creating actors took {:?}", start.elapsed()); + + $crate::consumer::kafka::ActorHandles { + join_set: handles, + shutdown: shutdown_signal, + rendezvous: rendezvous_receiver, + } + } + }}; +} + +#[derive(Debug)] +enum ConsumerState { + Ready, + Consuming(ActorHandles, BTreeSet<(String, i32)>), + Stopped, +} + +#[instrument(skip(consumer, events, shutdown_client, spawn_actors))] +pub async fn handle_events( + consumer: Arc>, + events: UnboundedReceiver<(Event, SyncSender<()>)>, + shutdown_client: oneshot::Sender<()>, + mut spawn_actors: impl FnMut( + Arc>, + &BTreeSet<(String, i32)>, + ) -> ActorHandles, +) -> Result<(), anyhow::Error> { + const CALLBACK_DURATION: Duration = Duration::from_secs(4); + + let mut shutdown_client = Some(shutdown_client); + let mut events_stream = UnboundedReceiverStream::new(events); + + let mut state = ConsumerState::Ready; + + while let ConsumerState::Ready { .. } | ConsumerState::Consuming { .. 
} = state { + let Some((event, _rendezvous_guard)) = events_stream.next().await else { + unreachable!("Unexpected end to event stream") + }; + info!("Recieved event: {:?}", event); + state = match (state, event) { + (ConsumerState::Ready, Event::Assign(tpl)) => { + ConsumerState::Consuming(spawn_actors(consumer.clone(), &tpl), tpl) + } + (ConsumerState::Ready, Event::Revoke(_)) => { + unreachable!("Got partition revocation before the consumer has started") + } + (ConsumerState::Ready, Event::Shutdown) => ConsumerState::Stopped, + (ConsumerState::Consuming(actor_handles, mut tpl), Event::Assign(mut assigned_tpl)) => { + assert!( + tpl.is_disjoint(&assigned_tpl), + "Newly assigned TPL should be disjoint from TPL we're consuming from" + ); + tpl.append(&mut assigned_tpl); + debug!( + "{} additional topic partitions added after assignment", + assigned_tpl.len() + ); + actor_handles.shutdown(CALLBACK_DURATION).await; + ConsumerState::Consuming(spawn_actors(consumer.clone(), &tpl), tpl) + } + (ConsumerState::Consuming(actor_handles, mut tpl), Event::Revoke(revoked_tpl)) => { + assert!( + tpl.is_subset(&revoked_tpl), + "Revoked TPL should be a subset of TPL we're consuming from" + ); + tpl.retain(|e| !revoked_tpl.contains(e)); + debug!("{} topic partitions remaining after revocation", tpl.len()); + actor_handles.shutdown(CALLBACK_DURATION).await; + if tpl.is_empty() { + ConsumerState::Ready + } else { + ConsumerState::Consuming(spawn_actors(consumer.clone(), &tpl), tpl) + } + } + (ConsumerState::Consuming(actor_handles, _), Event::Shutdown) => { + actor_handles.shutdown(CALLBACK_DURATION).await; + debug!("Signaling shutdown to client..."); + shutdown_client.take(); + ConsumerState::Stopped + } + (ConsumerState::Stopped, _) => { + unreachable!("Got event after consumer has stopped") + } + } + } + debug!("Shutdown complete"); + Ok(()) +} + +pub trait KafkaMessage { + fn detach(&self) -> Result; +} + +impl KafkaMessage for Result, KafkaError> { + fn detach(&self) -> Result { + match self { + Ok(borrowed_msg) => Ok(borrowed_msg.detach()), + Err(err) => Err(anyhow!( + "Cannot detach message, got error from kafka: {:?}", + err + )), + } + } +} + +pub trait MessageQueue { + fn stream(&self) -> impl Stream; +} + +impl MessageQueue for StreamPartitionQueue { + fn stream(&self) -> impl Stream { + self.stream() + } +} + +#[instrument(skip(queue, transform, ok, err, shutdown))] +pub async fn map( + queue: impl MessageQueue, + transform: impl Fn(Arc) -> Result, + ok: mpsc::Sender<(iter::Once, T)>, + err: mpsc::Sender, + shutdown: CancellationToken, +) -> Result<(), Error> { + let stream = queue.stream(); + pin_mut!(stream); + + loop { + select! 
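+/// Applies `transform` to each message from the partition queue, forwarding
+/// successfully transformed messages to `ok` and failed ones to `err`, until
+/// the stream ends or shutdown is signaled.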
+#[instrument(skip(queue, transform, ok, err, shutdown))]
+pub async fn map<T>(
+    queue: impl MessageQueue,
+    transform: impl Fn(Arc<OwnedMessage>) -> Result<T, Error>,
+    ok: mpsc::Sender<(iter::Once<OwnedMessage>, T)>,
+    err: mpsc::Sender<OwnedMessage>,
+    shutdown: CancellationToken,
+) -> Result<(), Error> {
+    let stream = queue.stream();
+    pin_mut!(stream);
+
+    loop {
+        select! {
+            biased;
+
+            _ = shutdown.cancelled() => {
+                debug!("Received shutdown signal, shutting down...");
+                break;
+            }
+
+            val = stream.next() => {
+                let Some(msg) = val else {
+                    break;
+                };
+                let msg = Arc::new(msg.detach()?);
+                match transform(msg.clone()) {
+                    Ok(transformed) => {
+                        if ok.send((
+                            iter::once(
+                                Arc::try_unwrap(msg)
+                                    .expect("msg should only have a single strong ref"),
+                            ),
+                            transformed,
+                        )).await.is_err() {
+                            debug!("Receive half of ok channel is closed, shutting down...");
+                            break;
+                        }
+                    }
+                    Err(e) => {
+                        error!(
+                            "Failed to map message at \
+                            (topic: {}, partition: {}, offset: {}), reason: {}",
+                            msg.topic(),
+                            msg.partition(),
+                            msg.offset(),
+                            e,
+                        );
+                        err.send(
+                            Arc::try_unwrap(msg).expect("msg should only have a single strong ref"),
+                        )
+                        .await
+                        .expect("reduce_err is not available");
+                    }
+                }
+            }
+        }
+    }
+    debug!("Shutdown complete");
+    Ok(())
+}
+
+#[derive(Debug, Clone)]
+pub struct ReduceConfig {
+    pub shutdown_condition: ReduceShutdownCondition,
+    pub shutdown_behaviour: ReduceShutdownBehaviour,
+    pub when_full_behaviour: ReducerWhenFullBehaviour,
+    pub flush_interval: Option<Duration>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceShutdownCondition {
+    Signal,
+    Drain,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceShutdownBehaviour {
+    Flush,
+    Drop,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReducerWhenFullBehaviour {
+    Flush,
+    Backpressure,
+}
+
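+/// A pipeline stage that accumulates inputs via `reduce` and emits them in
+/// `flush`. `ReduceConfig` controls when flushes happen (interval, fullness)
+/// and how the stage behaves on shutdown.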
+pub trait Reducer {
+    type Input;
+    type Output;
+
+    fn reduce(&mut self, t: Self::Input) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
+    fn flush(&mut self) -> impl Future<Output = Result<Self::Output, anyhow::Error>> + Send;
+    fn reset(&mut self);
+    fn is_full(&self) -> impl Future<Output = bool> + Send;
+    fn get_reduce_config(&self) -> ReduceConfig;
+}
+
+async fn handle_reducer_failure(
+    reducer: &mut impl Reducer,
+    inflight_msgs: &mut Vec<OwnedMessage>,
+    err: &mpsc::Sender<OwnedMessage>,
+) {
+    for msg in take(inflight_msgs).into_iter() {
+        err.send(msg).await.expect("reduce_err is not available");
+    }
+    reducer.reset();
+}
+
+#[instrument(skip(reducer, inflight_msgs, ok, err))]
+async fn flush_reducer<U>(
+    reducer: &mut impl Reducer<Output = U>,
+    inflight_msgs: &mut Vec<OwnedMessage>,
+    ok: &mpsc::Sender<(Vec<OwnedMessage>, U)>,
+    err: &mpsc::Sender<OwnedMessage>,
+) -> Result<(), Error> {
+    match reducer.flush().await {
+        Err(e) => {
+            error!("Failed to flush reducer, reason: {}", e);
+            handle_reducer_failure(reducer, inflight_msgs, err).await;
+        }
+        Ok(result) => {
+            if !inflight_msgs.is_empty() {
+                ok.send((take(inflight_msgs), result))
+                    .await
+                    .map_err(|err| anyhow!("{}", err))?;
+            }
+        }
+    }
+    Ok(())
+}
+
+#[instrument(skip(reducer, receiver, ok, err, shutdown))]
+pub async fn reduce<T, U>(
+    mut reducer: impl Reducer<Input = T, Output = U>,
+    mut receiver: mpsc::Receiver<(impl IntoIterator<Item = OwnedMessage>, T)>,
+    ok: mpsc::Sender<(Vec<OwnedMessage>, U)>,
+    err: mpsc::Sender<OwnedMessage>,
+    shutdown: CancellationToken,
+) -> Result<(), Error> {
+    let config = reducer.get_reduce_config();
+    let mut flush_timer = config.flush_interval.map(time::interval);
+    let mut loop_timer = time::interval(Duration::from_secs(1));
+    loop_timer.set_missed_tick_behavior(MissedTickBehavior::Delay);
+    let mut inflight_msgs = Vec::new();
+
+    loop {
+        select! {
+            biased;
+
+            _ = if config.shutdown_condition == ReduceShutdownCondition::Signal {
+                Either::Left(shutdown.cancelled())
+            } else {
+                Either::Right(future::pending::<_>())
+            } => {
+                match config.shutdown_behaviour {
+                    ReduceShutdownBehaviour::Flush => {
+                        debug!("Received shutdown signal, flushing reducer...");
+                        flush_reducer(&mut reducer, &mut inflight_msgs, &ok, &err).await?;
+                    }
+                    ReduceShutdownBehaviour::Drop => {
+                        debug!("Received shutdown signal, dropping reducer...");
+                        drop(reducer);
+                    }
+                };
+                break;
+            }
+
+            _ = if let Some(ref mut flush_timer) = flush_timer {
+                Either::Left(flush_timer.tick())
+            } else {
+                Either::Right(future::pending::<_>())
+            } => {
+                flush_reducer(&mut reducer, &mut inflight_msgs, &ok, &err).await?;
+            }
+
+            val = receiver.recv(), if !reducer.is_full().await => {
+                let Some((msg, value)) = val else {
+                    assert_eq!(
+                        config.shutdown_condition,
+                        ReduceShutdownCondition::Drain,
+                        "Got end of stream without shutdown signal"
+                    );
+                    match config.shutdown_behaviour {
+                        ReduceShutdownBehaviour::Flush => {
+                            debug!("Received end of stream, flushing reducer...");
+                            flush_reducer(&mut reducer, &mut inflight_msgs, &ok, &err).await?;
+                        }
+                        ReduceShutdownBehaviour::Drop => {
+                            debug!("Received end of stream, dropping reducer...");
+                            drop(reducer);
+                        }
+                    };
+                    break;
+                };
+
+                inflight_msgs.extend(msg);
+
+                if let Err(e) = reducer.reduce(value).await {
+                    error!(
+                        "Failed to reduce message at \
+                        (topic: {}, partition: {}, offset: {}), reason: {}",
+                        inflight_msgs.last().unwrap().topic(),
+                        inflight_msgs.last().unwrap().partition(),
+                        inflight_msgs.last().unwrap().offset(),
+                        e,
+                    );
+                    handle_reducer_failure(&mut reducer, &mut inflight_msgs, &err).await;
+                }
+
+                if config.when_full_behaviour == ReducerWhenFullBehaviour::Flush
+                    && reducer.is_full().await
+                {
+                    flush_reducer(&mut reducer, &mut inflight_msgs, &ok, &err).await?;
+                }
+            }
+
+            _ = loop_timer.tick() => { }
+        }
+    }
+
+    debug!("Shutdown complete");
+    Ok(())
+}
+
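+/// Variant of `reduce` for messages that failed an upstream stage; successful
+/// flushes forward the in-flight messages to the commit actor so their
+/// offsets still get stored.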
+#[instrument(skip(reducer, receiver, ok, shutdown))]
+pub async fn reduce_err(
+    mut reducer: impl Reducer<Input = OwnedMessage, Output = ()>,
+    mut receiver: mpsc::Receiver<OwnedMessage>,
+    ok: mpsc::Sender<(Vec<OwnedMessage>, ())>,
+    shutdown: CancellationToken,
+) -> Result<(), Error> {
+    let config = reducer.get_reduce_config();
+    let mut flush_timer = config.flush_interval.map(time::interval);
+    let mut inflight_msgs = Vec::new();
+
+    loop {
+        select! {
+            biased;
+
+            _ = shutdown.cancelled() => {
+                match config.shutdown_behaviour {
+                    ReduceShutdownBehaviour::Flush => {
+                        debug!("Received shutdown signal, flushing reducer...");
+                        reducer
+                            .flush()
+                            .await
+                            .expect("Failed to flush error reducer");
+                        if !inflight_msgs.is_empty() {
+                            ok.send((take(&mut inflight_msgs), ()))
+                                .await
+                                .map_err(|err| anyhow!("{}", err))?;
+                        }
+                    },
+                    ReduceShutdownBehaviour::Drop => {
+                        debug!("Received shutdown signal, dropping reducer...");
+                        drop(reducer);
+                    },
+                }
+                break;
+            }
+
+            _ = if let Some(ref mut flush_timer) = flush_timer {
+                Either::Left(flush_timer.tick())
+            } else {
+                Either::Right(future::pending::<_>())
+            } => {
+                reducer
+                    .flush()
+                    .await
+                    .expect("Failed to flush error reducer");
+                if !inflight_msgs.is_empty() {
+                    ok.send((take(&mut inflight_msgs), ()))
+                        .await
+                        .map_err(|err| anyhow!("{}", err))?;
+                }
+            }
+
+            val = receiver.recv(), if !reducer.is_full().await => {
+                let Some(msg) = val else {
+                    unreachable!("Received end of stream without shutdown signal");
+                };
+                inflight_msgs.push(msg.clone());
+
+                reducer
+                    .reduce(msg)
+                    .await
+                    .expect("Failed to reduce error reducer");
+
+                if matches!(config.when_full_behaviour, ReducerWhenFullBehaviour::Flush)
+                    && reducer.is_full().await
+                {
+                    reducer
+                        .flush()
+                        .await
+                        .expect("Failed to flush error reducer");
+
+                    if !inflight_msgs.is_empty() {
+                        ok.send((take(&mut inflight_msgs), ()))
+                            .await
+                            .map_err(|err| anyhow!("{}", err))?;
+                    }
+                }
+            }
+        }
+    }
+
+    debug!("Shutdown complete");
+    Ok(())
+}
+
+trait CommitClient {
+    fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()>;
+}
+
+impl<C: ConsumerContext> CommitClient for StreamConsumer<C> {
+    fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()> {
+        Consumer::store_offsets(self, tpl)
+    }
+}
+
+#[derive(Default)]
+struct HighwaterMark {
+    data: HashMap<(String, i32), i64>,
+}
+
+impl HighwaterMark {
+    fn new() -> Self {
+        Self {
+            data: HashMap::new(),
+        }
+    }
+
+    fn track(&mut self, msg: &OwnedMessage) {
+        let cur_offset = self
+            .data
+            .entry((msg.topic().to_string(), msg.partition()))
+            .or_insert(msg.offset() + 1);
+        *cur_offset = cmp::max(*cur_offset, msg.offset() + 1);
+    }
+
+    fn len(&self) -> usize {
+        self.data.len()
+    }
+}
+
+impl From<HighwaterMark> for TopicPartitionList {
+    fn from(val: HighwaterMark) -> Self {
+        let mut tpl = TopicPartitionList::with_capacity(val.len());
+        for ((topic, partition), offset) in val.data.iter() {
+            tpl.add_partition_offset(topic, *partition, Offset::Offset(*offset))
+                .expect("Invalid partition offset");
+        }
+        tpl
+    }
+}
+
+#[instrument(skip(receiver, consumer, _rendezvous_guard))]
+pub async fn commit(
+    mut receiver: mpsc::Receiver<(Vec<OwnedMessage>, ())>,
+    consumer: Arc<impl CommitClient>,
+    _rendezvous_guard: oneshot::Sender<()>,
+) -> Result<(), Error> {
+    while let Some(msgs) = receiver.recv().await {
+        debug!("Storing offsets");
+        let mut highwater_mark = HighwaterMark::new();
+        msgs.0.iter().for_each(|msg| highwater_mark.track(msg));
+        consumer.store_offsets(&highwater_mark.into()).unwrap();
+    }
+    debug!("Shutdown complete");
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{
+        collections::HashMap,
+        iter,
+        mem::take,
+        sync::{Arc, RwLock},
+        time::Duration,
+    };
+
+    use anyhow::{anyhow, Error};
+    use futures::Stream;
+    use rdkafka::{
+        error::{KafkaError, KafkaResult},
+        message::OwnedMessage,
+        Message, Offset, Timestamp, TopicPartitionList,
+    };
+    use tokio::{
+        sync::{broadcast, mpsc, oneshot},
+        time::sleep,
+    };
+    use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
+    use tokio_util::sync::CancellationToken;
+
+    use crate::consumer::kafka::{
+        commit, map, reduce, reduce_err, CommitClient, KafkaMessage, MessageQueue, ReduceConfig,
+        ReduceShutdownBehaviour, ReduceShutdownCondition, Reducer, ReducerWhenFullBehaviour,
+    };
+
+    struct MockCommitClient {
+        offsets: Arc<RwLock<Vec<TopicPartitionList>>>,
+    }
+
+    impl CommitClient for MockCommitClient {
+        fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()> {
+            self.offsets.write().unwrap().push(tpl.clone());
+            Ok(())
+        }
+    }
+
+    struct StreamingReducer<T> {
+        data: Option<T>,
+        pipe: Arc<RwLock<Vec<T>>>,
+        error_on_idx: Option<usize>,
+    }
+
+    impl<T> StreamingReducer<T> {
+        fn new(error_on_idx: Option<usize>) -> Self {
+            Self {
+                data: None,
+                pipe: Arc::new(RwLock::new(Vec::new())),
+                error_on_idx,
+            }
+        }
+
+        fn get_pipe(&self) -> Arc<RwLock<Vec<T>>> {
+            self.pipe.clone()
+        }
+    }
+
+    impl<T> Reducer for StreamingReducer<T>
+    where
+        T: Send + Sync + Clone,
+    {
+        type Input = T;
+
+        type Output = ();
+
+        async fn reduce(&mut self, t: Self::Input) -> Result<(), anyhow::Error> {
+            if let Some(idx) = self.error_on_idx {
+                if idx == self.pipe.read().unwrap().len() {
+                    self.error_on_idx.take();
+                    return Err(anyhow!("err"));
+                }
+            }
+            assert!(self.data.is_none());
+            self.data = Some(t);
+            Ok(())
+        }
+
+        async fn flush(&mut self) -> Result<(), anyhow::Error> {
+            self.pipe.write().unwrap().push(self.data.take().unwrap());
+            Ok(())
+        }
+
+        fn reset(&mut self) {
+            self.data.take();
+        }
+
+        async fn is_full(&self) -> bool {
+            self.data.is_some()
+        }
+
+        fn get_reduce_config(&self) -> ReduceConfig {
+            ReduceConfig {
+                shutdown_condition: ReduceShutdownCondition::Signal,
+                shutdown_behaviour: ReduceShutdownBehaviour::Drop,
+                when_full_behaviour: ReducerWhenFullBehaviour::Flush,
+                flush_interval: None,
+            }
+        }
+    }
+
+    struct BatchingReducer<T> {
+        buffer: Arc<RwLock<Vec<T>>>,
+        pipe: Arc<RwLock<Vec<T>>>,
+        error_on_nth_reduce: Option<usize>,
+        error_on_nth_flush: Option<usize>,
+        shutdown_condition: ReduceShutdownCondition,
+    }
+
+    impl<T> BatchingReducer<T> {
+        fn new(
+            error_on_reduce: Option<usize>,
+            error_on_flush: Option<usize>,
+            shutdown_condition: ReduceShutdownCondition,
+        ) -> Self {
+            Self {
+                buffer: Arc::new(RwLock::new(Vec::new())),
+                pipe: Arc::new(RwLock::new(Vec::new())),
+                error_on_nth_reduce: error_on_reduce,
+                error_on_nth_flush: error_on_flush,
+                shutdown_condition,
+            }
+        }
+
+        fn get_buffer(&self) -> Arc<RwLock<Vec<T>>> {
+            self.buffer.clone()
+        }
+
+        fn get_pipe(&self) -> Arc<RwLock<Vec<T>>> {
+            self.pipe.clone()
+        }
+    }
+
+    impl<T> Reducer for BatchingReducer<T>
+    where
+        T: Send + Sync + Clone,
+    {
+        type Input = T;
+        type Output = ();
+
+        async fn reduce(&mut self, t: Self::Input) -> Result<(), anyhow::Error> {
+            if let Some(idx) = self.error_on_nth_reduce {
+                if idx == 0 {
+                    self.error_on_nth_reduce.take();
+                    return Err(anyhow!("err"));
+                } else {
+                    self.error_on_nth_reduce = Some(idx - 1);
+                }
+            }
+            self.buffer.write().unwrap().push(t);
+            Ok(())
+        }
+
+        async fn flush(&mut self) -> Result<(), anyhow::Error> {
+            if let Some(idx) = self.error_on_nth_flush {
+                if idx == 0 {
+                    self.error_on_nth_flush.take();
+                    return Err(anyhow!("err"));
+                } else {
+                    self.error_on_nth_flush = Some(idx - 1);
+                }
+            }
+            self.pipe
+                .write()
+                .unwrap()
+                .extend(take(&mut self.buffer.write().unwrap() as &mut Vec<T>).into_iter());
+            Ok(())
+        }
+
+        fn reset(&mut self) {
+            self.buffer.write().unwrap().clear();
+        }
+
+        async fn is_full(&self) -> bool {
+            self.buffer.read().unwrap().len() >= 32
+        }
+
+        fn get_reduce_config(&self) -> ReduceConfig {
+            ReduceConfig {
+                shutdown_condition: self.shutdown_condition,
+                shutdown_behaviour: ReduceShutdownBehaviour::Flush,
+                when_full_behaviour: ReducerWhenFullBehaviour::Backpressure,
+                flush_interval: Some(Duration::from_secs(1)),
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_commit() {
+        let offsets = Arc::new(RwLock::new(Vec::new()));
+
+        let commit_client = Arc::new(MockCommitClient {
+            offsets: offsets.clone(),
+        });
+        let (sender, receiver) = mpsc::channel(1);
+        let (rendezvous_sender, rendezvous_receiver) = oneshot::channel();
+
+        let msg = vec![
+            OwnedMessage::new(
+                None,
+                None,
+                "topic".to_string(),
+                Timestamp::NotAvailable,
+                0,
+                1,
+                None,
+            ),
+            OwnedMessage::new(
+                None,
+                None,
+                "topic".to_string(),
+                Timestamp::NotAvailable,
+                1,
+                0,
+                None,
+            ),
+        ];
+
+        assert!(sender.send((msg.clone(), ())).await.is_ok());
+
+        tokio::spawn(commit(receiver, commit_client, rendezvous_sender));
+
+        drop(sender);
+        let _ = rendezvous_receiver.await;
+
+        assert_eq!(offsets.read().unwrap().len(), 1);
+        assert_eq!(
+            offsets.read().unwrap()[0],
+            TopicPartitionList::from_topic_map(&HashMap::from([
+                (("topic".to_string(), 0), Offset::Offset(2)),
+                (("topic".to_string(), 1), Offset::Offset(1))
+            ]))
+            .unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_reduce_err_without_flush_interval() {
+        let reducer = StreamingReducer::new(None);
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(1);
+        let (commit_sender, mut commit_receiver) = mpsc::channel(1);
+        let shutdown = CancellationToken::new();
+
+        let msg = OwnedMessage::new(
+            Some(vec![0, 1, 2, 3, 4, 5, 6, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+
+        tokio::spawn(reduce_err(
+            reducer,
+            receiver,
+            commit_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send(msg.clone()).await.is_ok());
+        assert_eq!(
+            commit_receiver.recv().await.unwrap().0[0].payload(),
+            msg.payload()
+        );
+        assert_eq!(
+            pipe.read().unwrap().last().unwrap().payload().unwrap(),
+            &[0, 1, 2, 3, 4, 5, 6, 7]
+        );
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(commit_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_reduce_without_flush_interval() {
+        let reducer = StreamingReducer::new(None);
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(2);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(2);
+        let (err_sender, err_receiver) = mpsc::channel(2);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 2, 4, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 3, 5, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 1)).await.is_ok());
+        assert!(sender.send((iter::once(msg_1.clone()), 2)).await.is_ok());
+
+        assert_eq!(
+            ok_receiver.recv().await.unwrap().0[0].payload(),
+            msg_0.payload()
+        );
+        assert_eq!(
+            ok_receiver.recv().await.unwrap().0[0].payload(),
+            msg_1.payload()
+        );
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1, 2]);
+        assert!(err_receiver.is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_closed());
+        assert!(err_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_fail_on_reduce_without_flush_interval() {
+        let reducer = StreamingReducer::new(Some(1));
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(2);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(2);
+        let (err_sender, mut err_receiver) = mpsc::channel(2);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 2, 4, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 3, 5, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+        let msg_2 = OwnedMessage::new(
+            Some(vec![0, 0, 0, 0]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            2,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 1)).await.is_ok());
+        assert_eq!(
+            ok_receiver.recv().await.unwrap().0[0].payload(),
+            msg_0.payload(),
+        );
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1]);
+
+        assert!(sender.send((iter::once(msg_1.clone()), 2)).await.is_ok());
+        assert_eq!(
+            err_receiver.recv().await.unwrap().payload(),
+            msg_1.payload()
+        );
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1]);
+
+        assert!(sender.send((iter::once(msg_2.clone()), 3)).await.is_ok());
+        assert_eq!(
+            ok_receiver.recv().await.unwrap().0[0].payload(),
+            msg_2.payload(),
+        );
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1, 3]);
+
+        assert!(ok_receiver.is_empty());
+        assert!(err_receiver.is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_closed());
+        assert!(err_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_reduce_err_with_flush_interval() {
+        let reducer = BatchingReducer::new(None, None, ReduceShutdownCondition::Signal);
+        let buffer = reducer.get_buffer();
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(1);
+        let (commit_sender, mut commit_receiver) = mpsc::channel(1);
+        let shutdown = CancellationToken::new();
+
+        let msg = OwnedMessage::new(
+            Some(vec![0, 1, 2, 3, 4, 5, 6, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+
+        tokio::spawn(reduce_err(
+            reducer,
+            receiver,
+            commit_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send(msg.clone()).await.is_ok());
+        assert_eq!(
+            commit_receiver.recv().await.unwrap().0[0].payload(),
+            msg.payload()
+        );
+        assert_eq!(pipe.read().unwrap()[0].payload(), msg.payload());
+        assert!(buffer.read().unwrap().is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(commit_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_reduce_with_flush_interval() {
+        let reducer = BatchingReducer::new(None, None, ReduceShutdownCondition::Signal);
+        let buffer = reducer.get_buffer();
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(2);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(2);
+        let (err_sender, err_receiver) = mpsc::channel(2);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 2, 4, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 3, 5, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 1)).await.is_ok());
+        assert!(sender.send((iter::once(msg_1.clone()), 2)).await.is_ok());
+
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 2);
+        assert_eq!(ok_msgs[0].payload(), msg_0.payload());
+        assert_eq!(ok_msgs[1].payload(), msg_1.payload());
+        assert!(buffer.read().unwrap().is_empty());
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1, 2]);
+        assert!(err_receiver.is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_closed());
+        assert!(err_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_fail_on_reduce_with_flush_interval() {
+        let reducer = BatchingReducer::new(Some(1), None, ReduceShutdownCondition::Signal);
+        let buffer = reducer.get_buffer();
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(3);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(3);
+        let (err_sender, mut err_receiver) = mpsc::channel(3);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 3, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 4, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+        let msg_2 = OwnedMessage::new(
+            Some(vec![2, 5, 8]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            2,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 0)).await.is_ok());
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 1);
+        assert_eq!(ok_msgs[0].payload(), msg_0.payload());
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0]);
+
+        assert!(sender.send((iter::once(msg_1.clone()), 1)).await.is_ok());
+        assert_eq!(
+            err_receiver.recv().await.unwrap().payload(),
+            msg_1.payload()
+        );
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0] as &[i32]);
+
+        assert!(sender.send((iter::once(msg_2.clone()), 2)).await.is_ok());
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 1);
+        assert_eq!(ok_msgs[0].payload(), msg_2.payload());
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0, 2] as &[i32]);
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_empty());
+        assert!(err_receiver.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_fail_on_flush() {
+        let reducer = BatchingReducer::new(None, Some(1), ReduceShutdownCondition::Signal);
+        let buffer = reducer.get_buffer();
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(1);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(1);
+        let (err_sender, mut err_receiver) = mpsc::channel(1);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 3, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 4, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+        let msg_2 = OwnedMessage::new(
+            Some(vec![2, 5, 8]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            2,
+            None,
+        );
+        let msg_3 = OwnedMessage::new(
+            Some(vec![0, 0, 0]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            3,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 0)).await.is_ok());
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 1);
+        assert_eq!(ok_msgs[0].payload(), msg_0.payload());
+
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0]);
+
+        assert!(sender.send((iter::once(msg_1.clone()), 1)).await.is_ok());
+        assert!(sender.send((iter::once(msg_2.clone()), 2)).await.is_ok());
+        assert_eq!(
+            err_receiver.recv().await.unwrap().payload(),
+            msg_1.payload()
+        );
+        assert_eq!(
+            err_receiver.recv().await.unwrap().payload(),
+            msg_2.payload()
+        );
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0]);
+
+        assert!(sender.send((iter::once(msg_3.clone()), 3)).await.is_ok());
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 1);
+        assert_eq!(ok_msgs[0].payload(), msg_3.payload());
+        assert_eq!(buffer.read().unwrap().as_slice(), &[] as &[i32]);
+        assert_eq!(pipe.read().unwrap().as_slice(), &[0, 3]);
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_empty());
+        assert!(err_receiver.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_sequential_reducers() {
+        let reducer_0 = BatchingReducer::new(None, None, ReduceShutdownCondition::Signal);
+        let buffer_0 = reducer_0.get_buffer();
+        let pipe_0 = reducer_0.get_pipe();
+
+        let reducer_1 = BatchingReducer::new(None, None, ReduceShutdownCondition::Signal);
+        let buffer_1 = reducer_1.get_buffer();
+        let pipe_1 = reducer_1.get_pipe();
+
+        let shutdown = CancellationToken::new();
+
+        let (sender, receiver) = mpsc::channel(1);
+        let (ok_sender_0, ok_receiver_0) = mpsc::channel(2);
+        let (err_sender_0, err_receiver_0) = mpsc::channel(1);
+
+        let (ok_sender_1, mut ok_receiver_1) = mpsc::channel(1);
+        let (err_sender_1, err_receiver_1) = mpsc::channel(1);
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 2, 4, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 3, 5, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer_0,
+            receiver,
+            ok_sender_0,
+            err_sender_0,
+            shutdown.clone(),
+        ));
+
+        tokio::spawn(reduce(
+            reducer_1,
+            ok_receiver_0,
+            ok_sender_1,
+            err_sender_1,
+            shutdown.clone(),
+        ));
+
+        assert!(sender.send((iter::once(msg_0.clone()), 1)).await.is_ok());
+        assert!(sender.send((iter::once(msg_1.clone()), 2)).await.is_ok());
+
+        let ok_msgs = ok_receiver_1.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 2);
+        assert_eq!(ok_msgs[0].payload(), msg_0.payload());
+        assert_eq!(ok_msgs[1].payload(), msg_1.payload());
+
+        assert!(buffer_0.read().unwrap().is_empty());
+        assert_eq!(pipe_0.read().unwrap().as_slice(), &[1, 2]);
+
+        assert!(buffer_1.read().unwrap().is_empty());
+        assert_eq!(pipe_1.read().unwrap().as_slice(), &[()]);
+
+        assert!(err_receiver_0.is_empty());
+        assert!(err_receiver_1.is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(err_receiver_0.is_closed());
+        assert!(ok_receiver_1.is_closed());
+        assert!(err_receiver_1.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_reduce_shutdown_from_drain() {
+        let reducer = BatchingReducer::new(None, None, ReduceShutdownCondition::Drain);
+        let buffer = reducer.get_buffer();
+        let pipe = reducer.get_pipe();
+
+        let (sender, receiver) = mpsc::channel(2);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(2);
+        let (err_sender, err_receiver) = mpsc::channel(2);
+        let shutdown = CancellationToken::new();
+
+        let msg_0 = OwnedMessage::new(
+            Some(vec![0, 2, 4, 6]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            0,
+            None,
+        );
+        let msg_1 = OwnedMessage::new(
+            Some(vec![1, 3, 5, 7]),
+            None,
+            "topic".to_string(),
+            Timestamp::now(),
+            0,
+            1,
+            None,
+        );
+
+        tokio::spawn(reduce(
+            reducer,
+            receiver,
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+
+        shutdown.cancel();
+
+        assert!(sender.send((iter::once(msg_0.clone()), 1)).await.is_ok());
+        assert!(sender.send((iter::once(msg_1.clone()), 2)).await.is_ok());
+
+        let ok_msgs = ok_receiver.recv().await.unwrap().0;
+        assert_eq!(ok_msgs.len(), 2);
+        assert_eq!(ok_msgs[0].payload(), msg_0.payload());
+        assert_eq!(ok_msgs[1].payload(), msg_1.payload());
+        assert!(buffer.read().unwrap().is_empty());
+        assert_eq!(pipe.read().unwrap().as_slice(), &[1, 2]);
+        assert!(err_receiver.is_empty());
+
+        drop(sender);
+        shutdown.cancel();
+
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_closed());
+        assert!(err_receiver.is_closed());
+    }
+
+    #[derive(Clone)]
+    struct MockMessage {
+        payload: Vec<u8>,
+        topic: String,
+        partition: i32,
+        offset: i64,
+    }
+
+    impl KafkaMessage for Result<Result<MockMessage, KafkaError>, BroadcastStreamRecvError> {
+        fn detach(&self) -> Result<OwnedMessage, Error> {
+            let clone = self.clone().unwrap().unwrap();
+            Ok(OwnedMessage::new(
+                Some(clone.payload),
+                None,
+                clone.topic,
+                Timestamp::now(),
+                clone.partition,
+                clone.offset,
+                None,
+            ))
+        }
+    }
+
+    impl MessageQueue for broadcast::Receiver<Result<MockMessage, KafkaError>> {
+        fn stream(&self) -> impl Stream<Item = impl KafkaMessage> {
+            BroadcastStream::new(self.resubscribe())
+        }
+    }
+
+    #[tokio::test]
+    async fn test_map() {
+        let (sender, receiver) = broadcast::channel(1);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(1);
+        let (err_sender, err_receiver) = mpsc::channel(1);
+        let shutdown = CancellationToken::new();
+
+        tokio::spawn(map(
+            receiver,
+            |msg| Ok(msg.payload().unwrap()[0] * 2),
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+        sleep(Duration::from_secs(1)).await;
+
+        let msg_0 = MockMessage {
+            payload: vec![0],
+            topic: "topic".to_string(),
+            partition: 0,
+            offset: 0,
+        };
+        let msg_1 = MockMessage {
+            payload: vec![1],
+            topic: "topic".to_string(),
+            partition: 0,
+            offset: 1,
+        };
+        assert!(sender.send(Ok(msg_0.clone())).is_ok());
+        assert!(err_receiver.is_empty());
+        let res = ok_receiver.recv().await.unwrap();
+        assert_eq!(
+            res.0.collect::<Vec<_>>()[0].payload(),
+            Some(msg_0.payload.clone()).as_deref()
+        );
+        assert_eq!(res.1, msg_0.payload[0] * 2);
+
+        assert!(sender.send(Ok(msg_1.clone())).is_ok());
+        assert!(err_receiver.is_empty());
+        let res = ok_receiver.recv().await.unwrap();
+        assert_eq!(
+            res.0.collect::<Vec<_>>()[0].payload(),
+            Some(msg_1.payload.clone()).as_deref()
+        );
+        assert_eq!(res.1, msg_1.payload[0] * 2);
+
+        shutdown.cancel();
+        sleep(Duration::from_secs(1)).await;
+        assert!(ok_receiver.is_closed());
+        assert!(err_receiver.is_closed());
+    }
+
+    #[tokio::test]
+    async fn test_fail_on_map() {
+        let (sender, receiver) = broadcast::channel(1);
+        let (ok_sender, mut ok_receiver) = mpsc::channel(1);
+        let (err_sender, mut err_receiver) = mpsc::channel(1);
+        let shutdown = CancellationToken::new();
+
+        tokio::spawn(map(
+            receiver,
+            |msg| {
+                if msg.payload().unwrap()[0] == 1 {
+                    Err(anyhow!("Oh no"))
+                } else {
+                    Ok(msg.payload().unwrap()[0] * 2)
+                }
+            },
+            ok_sender,
+            err_sender,
+            shutdown.clone(),
+        ));
+        sleep(Duration::from_secs(1)).await;
+
+        let msg_0 = MockMessage {
+            payload: vec![0],
+            topic: "topic".to_string(),
+            partition: 0,
+            offset: 0,
+        };
MockMessage { + payload: vec![1], + topic: "topic".to_string(), + partition: 0, + offset: 1, + }; + let msg_2 = MockMessage { + payload: vec![2], + topic: "topic".to_string(), + partition: 0, + offset: 2, + }; + + assert!(sender.send(Ok(msg_0.clone())).is_ok()); + assert!(err_receiver.is_empty()); + let res = ok_receiver.recv().await.unwrap(); + assert_eq!( + res.0.collect::>()[0].payload(), + Some(msg_0.payload).as_deref() + ); + assert_eq!(res.1, 0); + + assert!(sender.send(Ok(msg_1.clone())).is_ok()); + assert!(ok_receiver.is_empty()); + let res = err_receiver.recv().await.unwrap(); + assert_eq!(res.payload(), Some(msg_1.payload).as_deref()); + + assert!(sender.send(Ok(msg_2.clone())).is_ok()); + assert!(err_receiver.is_empty()); + let res = ok_receiver.recv().await.unwrap(); + assert_eq!( + res.0.collect::>()[0].payload(), + Some(msg_2.payload).as_deref() + ); + assert_eq!(res.1, 4); + + shutdown.cancel(); + sleep(Duration::from_secs(1)).await; + assert!(ok_receiver.is_closed()); + assert!(err_receiver.is_closed()); + } +} diff --git a/src/consumer/mod.rs b/src/consumer/mod.rs new file mode 100644 index 00000000..dc440b62 --- /dev/null +++ b/src/consumer/mod.rs @@ -0,0 +1,4 @@ +pub mod deserialize_activation; +pub mod inflight_activation_writer; +pub mod kafka; +pub mod os_stream_writer; diff --git a/src/consumer/os_stream_writer.rs b/src/consumer/os_stream_writer.rs new file mode 100644 index 00000000..d36c9afa --- /dev/null +++ b/src/consumer/os_stream_writer.rs @@ -0,0 +1,71 @@ +use crate::consumer::kafka::{ + ReduceConfig, ReduceShutdownBehaviour, ReduceShutdownCondition, Reducer, + ReducerWhenFullBehaviour, +}; +use std::{fmt::Debug, marker::PhantomData, time::Duration}; +use tokio::time::sleep; + +pub enum OsStream { + StdOut, + StdErr, +} + +pub struct OsStreamWriter { + data: Option, + print_duration: Duration, + os_stream: OsStream, + phantom: PhantomData, +} + +impl OsStreamWriter { + pub fn new(print_duration: Duration, os_stream: OsStream) -> Self { + Self { + data: None, + print_duration, + os_stream, + phantom: PhantomData::, + } + } +} + +impl Reducer for OsStreamWriter +where + T: Debug + Send + Sync, +{ + type Input = T; + type Output = (); + + async fn reduce(&mut self, t: Self::Input) -> Result<(), anyhow::Error> { + self.data = Some(t); + Ok(()) + } + + async fn flush(&mut self) -> Result<(), anyhow::Error> { + let Some(data) = self.data.take() else { + return Ok(()); + }; + match self.os_stream { + OsStream::StdOut => println!("{:?}", data), + OsStream::StdErr => eprintln!("{:?}", data), + } + sleep(self.print_duration).await; + Ok(()) + } + + fn reset(&mut self) { + self.data.take(); + } + + async fn is_full(&self) -> bool { + self.data.is_some() + } + + fn get_reduce_config(&self) -> ReduceConfig { + ReduceConfig { + shutdown_condition: ReduceShutdownCondition::Signal, + shutdown_behaviour: ReduceShutdownBehaviour::Flush, + when_full_behaviour: ReducerWhenFullBehaviour::Flush, + flush_interval: None, + } + } +} diff --git a/src/inflight_task_store.rs b/src/inflight_activation_store.rs similarity index 91% rename from src/inflight_task_store.rs rename to src/inflight_activation_store.rs index a5830f72..3d80086b 100644 --- a/src/inflight_task_store.rs +++ b/src/inflight_activation_store.rs @@ -1,12 +1,16 @@ +use std::str::FromStr; + use anyhow::Error; use chrono::{DateTime, Utc}; use prost::Message; use sentry_protos::sentry::v1::TaskActivation; use sqlx::{ - migrate::MigrateDatabase, sqlite::SqlitePool, FromRow, QueryBuilder, Row, Sqlite, Type, + 
+    migrate::MigrateDatabase,
+    sqlite::{SqliteConnectOptions, SqlitePool, SqliteQueryResult},
+    ConnectOptions, FromRow, QueryBuilder, Row, Sqlite, Type,
 };
 
-pub struct InflightTaskStore {
+pub struct InflightActivationStore {
     sqlite_pool: SqlitePool,
 }
@@ -23,16 +27,31 @@ pub enum TaskActivationStatus {
 pub struct InflightActivation {
     pub activation: TaskActivation,
     pub status: TaskActivationStatus,
+    pub partition: i32,
     pub offset: i64,
     pub added_at: DateTime<Utc>,
     pub deadletter_at: Option<DateTime<Utc>>,
     pub processing_deadline: Option<DateTime<Utc>>,
 }
 
+#[derive(Clone, Copy, Debug)]
+pub struct QueryResult {
+    pub rows_affected: u64,
+}
+
+impl From<SqliteQueryResult> for QueryResult {
+    fn from(value: SqliteQueryResult) -> Self {
+        Self {
+            rows_affected: value.rows_affected(),
+        }
+    }
+}
+
 #[derive(FromRow)]
 struct TableRow {
     id: String,
     activation: Vec<u8>,
+    partition: i32,
     offset: i64,
     added_at: DateTime<Utc>,
     deadletter_at: Option<DateTime<Utc>>,
@@ -48,6 +67,7 @@ impl TryFrom<InflightActivation> for TableRow {
         Ok(Self {
             id: value.activation.id.clone(),
             activation: value.activation.encode_to_vec(),
+            partition: value.partition,
             offset: value.offset,
             added_at: value.added_at,
             deadletter_at: value.deadletter_at,
@@ -65,6 +85,7 @@ impl From<TableRow> for InflightActivation {
                 "Decode should always be successful as we only store encoded data in this column",
             ),
             status: value.status,
+            partition: value.partition,
             offset: value.offset,
             added_at: value.added_at,
             deadletter_at: value.deadletter_at,
@@ -73,22 +94,27 @@
     }
 }
 
-impl InflightTaskStore {
+impl InflightActivationStore {
     pub async fn new(url: &str) -> Result<Self, Error> {
         if !Sqlite::database_exists(url).await? {
             Sqlite::create_database(url).await?
         }
-        let sqlite_pool = SqlitePool::connect(url).await?;
+        let conn_options = SqliteConnectOptions::from_str(url)?.disable_statement_logging();
+
+        let sqlite_pool = SqlitePool::connect_with(conn_options).await?;
         sqlx::migrate!("./migrations").run(&sqlite_pool).await?;
         Ok(Self { sqlite_pool })
     }
 
-    pub async fn store(&self, batch: Vec<InflightActivation>) -> Result<(), Error> {
+    pub async fn store(&self, batch: Vec<InflightActivation>) -> Result<QueryResult, Error> {
+        if batch.is_empty() {
+            return Ok(QueryResult { rows_affected: 0 });
+        }
         let mut query_builder = QueryBuilder::<Sqlite>::new(
             "INSERT INTO inflight_taskactivations \
-            (id, activation, offset, added_at, deadletter_at, processing_deadline_duration, status)",
+            (id, activation, partition, offset, added_at, deadletter_at, processing_deadline_duration, status)",
         );
         let rows = batch
             .into_iter()
@@ -98,6 +124,7 @@
             .push_values(rows, |mut b, row| {
                 b.push_bind(row.id);
                 b.push_bind(row.activation);
+                b.push_bind(row.partition);
                 b.push_bind(row.offset);
                 b.push_bind(row.added_at);
                 b.push_bind(row.deadletter_at);
@@ -105,8 +132,7 @@
                 b.push_bind(row.status);
             })
             .build();
-        query.execute(&self.sqlite_pool).await?;
-        Ok(())
+        Ok(query.execute(&self.sqlite_pool).await?.into())
     }
 
     pub async fn get_pending_activation(&self) -> Result<Option<InflightActivation>, Error> {
@@ -178,24 +204,12 @@
     pub async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error> {
         Ok(
-            sqlx::query("SELECT * FROM inflight_taskactivations WHERE status = $1")
+            sqlx::query_as("SELECT * FROM inflight_taskactivations WHERE status = $1")
                 .bind(TaskActivationStatus::Retry)
                 .fetch_all(&self.sqlite_pool)
                 .await?
                 .into_iter()
-                .map(|row| {
-                    TableRow {
-                        id: row.get("id"),
-                        activation: row.get("activation"),
-                        offset: row.get("offset"),
-                        added_at: row.get("added_at"),
-                        deadletter_at: row.get("deadletter_at"),
-                        processing_deadline_duration: row.get("processing_deadline_duration"),
-                        processing_deadline: row.get("processing_deadline"),
-                        status: row.get("status"),
-                    }
-                    .into()
-                })
+                .map(|row: TableRow| row.into())
                 .collect(),
         )
     }
@@ -210,7 +224,9 @@ mod tests {
     use sentry_protos::sentry::v1::TaskActivation;
     use sqlx::{Row, SqlitePool};
 
-    use crate::inflight_task_store::{InflightActivation, InflightTaskStore, TaskActivationStatus};
+    use crate::inflight_activation_store::{
+        InflightActivation, InflightActivationStore, TaskActivationStatus,
+    };
 
     fn generate_temp_filename() -> String {
         let mut rng = rand::thread_rng();
@@ -219,7 +235,7 @@
 
     #[tokio::test]
     async fn test_create_db() {
-        assert!(InflightTaskStore::new(&generate_temp_filename())
+        assert!(InflightActivationStore::new(&generate_temp_filename())
             .await
             .is_ok())
     }
@@ -227,7 +243,7 @@
     #[tokio::test]
     async fn test_store() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
 
         #[allow(deprecated)]
         let batch = vec![
@@ -248,6 +264,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 0,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -270,6 +287,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 1,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -292,7 +310,7 @@
     #[tokio::test]
     async fn test_get_pending_activation() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
         let added_at = Utc::now();
 
         #[allow(deprecated)]
@@ -313,6 +331,7 @@
                 expires: Some(1),
             },
             status: TaskActivationStatus::Pending,
+            partition: 0,
             offset: 0,
             added_at,
             deadletter_at: None,
@@ -339,6 +358,7 @@
                 expires: Some(1),
             },
             status: TaskActivationStatus::Processing,
+            partition: 0,
             offset: 0,
             added_at,
             deadletter_at: None,
@@ -354,7 +374,7 @@
     #[tokio::test]
     async fn test_count_pending_activations() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
 
         #[allow(deprecated)]
         let batch = vec![
@@ -375,6 +395,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 0,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -397,6 +418,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 1,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -411,7 +433,7 @@
     #[tokio::test]
     async fn set_activation_status() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
 
         #[allow(deprecated)]
         let batch = vec![
@@ -432,6 +454,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 0,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -454,6 +477,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 1,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -488,7 +512,7 @@
     #[tokio::test]
     async fn test_set_processing_deadline() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
 
         #[allow(deprecated)]
         let batch = vec![InflightActivation {
@@ -508,6 +532,7 @@
                 expires: Some(1),
             },
             status: TaskActivationStatus::Pending,
+            partition: 0,
             offset: 0,
             added_at: Utc::now(),
             deadletter_at: None,
@@ -540,7 +565,7 @@
     #[tokio::test]
     async fn test_delete_activation() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
 
         #[allow(deprecated)]
         let batch = vec![
@@ -561,6 +586,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 0,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -583,6 +609,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 1,
                 added_at: Utc::now(),
                 deadletter_at: None,
@@ -634,7 +661,7 @@
     #[tokio::test]
     async fn test_get_retry_activations() {
         let url = generate_temp_filename();
-        let store = InflightTaskStore::new(&url).await.unwrap();
+        let store = InflightActivationStore::new(&url).await.unwrap();
         let added_at = Utc::now();
 
         #[allow(deprecated)]
@@ -656,6 +683,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 0,
                 added_at,
                 deadletter_at: None,
@@ -678,6 +706,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Pending,
+                partition: 0,
                 offset: 1,
                 added_at,
                 deadletter_at: None,
@@ -716,6 +745,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Retry,
+                partition: 0,
                 offset: 0,
                 added_at,
                 deadletter_at: None,
@@ -738,6 +768,7 @@
                     expires: Some(1),
                 },
                 status: TaskActivationStatus::Retry,
+                partition: 0,
                 offset: 1,
                 added_at,
                 deadletter_at: None,
diff --git a/src/main.rs b/src/main.rs
index e690ca98..66295ad1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,11 +1,24 @@
 use anyhow::Error;
 use clap::Parser;
 use config::Config;
-use inflight_task_store::InflightTaskStore;
+use consumer::{
+    deserialize_activation::{self},
+    inflight_activation_writer::{self, InflightActivationWriter},
+    kafka::{start_consumer, ReduceShutdownBehaviour, ReducerWhenFullBehaviour},
+    os_stream_writer::{OsStream, OsStreamWriter},
+};
+use inflight_activation_store::InflightActivationStore;
+use rdkafka::{config::RDKafkaLogLevel, ClientConfig};
+use std::{sync::Arc, time::Duration};
+use tokio::{select, signal, time};
+use tracing::info;
 
+#[allow(dead_code)]
 mod config;
 #[allow(dead_code)]
-mod inflight_task_store;
+mod consumer;
+#[allow(dead_code)]
+mod inflight_activation_store;
 mod logging;
 mod metrics;
 
@@ -23,13 +36,65 @@ struct Args {
 
 #[tokio::main]
 async fn main() -> Result<(), Error> {
-    // Read command line options
     let args = Args::parse();
     let config = Config::from_args(&args)?;
 
     logging::init(logging::LoggingConfig::from_config(&config));
     metrics::init(metrics::MetricsConfig::from_config(&config));
 
-    InflightTaskStore::new(&config.db_path).await?;
-    Ok(())
+    let store = Arc::new(InflightActivationStore::new(&config.db_path).await?);
+    let rpc_store = store.clone();
+
+    tokio::spawn(async move {
+        let mut timer = time::interval(Duration::from_millis(200));
+        loop {
+            select! {
+                _ = signal::ctrl_c() => {
+                    break;
+                }
+                _ = timer.tick() => {
+                    let _ = rpc_store.get_pending_activation().await;
+                    info!(
+                        "Pending activation in store: {}",
+                        rpc_store.count_pending_activations().await.unwrap()
+                    );
+                }
+            }
+        }
+    });
+
+    start_consumer(
+        [&config.kafka_topic as &str].as_ref(),
+        ClientConfig::new()
+            .set("group.id", "test-taskworker-consumer")
+            .set("bootstrap.servers", "127.0.0.1:9092")
+            .set("enable.partition.eof", "false")
+            .set("session.timeout.ms", "6000")
+            .set("enable.auto.commit", "true")
+            .set("auto.commit.interval.ms", "5000")
+            .set("enable.auto.offset.store", "false")
+            .set_log_level(RDKafkaLogLevel::Debug),
+        processing_strategy!({
+            map: deserialize_activation::new(deserialize_activation::Config {
+                deadletter_duration: None,
+            }),
+
+            reduce: InflightActivationWriter::new(
+                store.clone(),
+                inflight_activation_writer::Config {
+                    max_buf_len: 128,
+                    max_pending_activations: 2048,
+                    flush_interval: None,
+                    when_full_behaviour: ReducerWhenFullBehaviour::Flush,
+                    shutdown_behaviour: ReduceShutdownBehaviour::Drop,
+                }
+            ),
+
+            err: OsStreamWriter::new(
+                Duration::from_secs(1),
+                OsStream::StdErr,
+            ),
+        }),
+    )
+    .await
 }