diff --git a/Cargo.toml b/Cargo.toml index 0b8e116..c2f5c67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/astraly-labs/pragma-common" license = "MIT" [features] +default = [] serde = ["dep:serde"] borsh = ["dep:borsh"] proto = ["dep:prost"] diff --git a/src/services.rs b/src/services.rs index a3686e9..b5c3428 100644 --- a/src/services.rs +++ b/src/services.rs @@ -2,7 +2,7 @@ /// use std::{panic, time::Duration}; -use anyhow::Context; +use anyhow::{anyhow, Context}; use futures::Future; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -71,7 +71,7 @@ pub trait Service: 'static + Send + Sync { let runner = ServiceRunner::new(ctx, &mut join_set); self.start(runner).await.context("Starting service")?; - drive_joinset(join_set).await + drive_critical_joinset(join_set).await } } @@ -109,28 +109,59 @@ impl<'a> ServiceRunner<'a> { /// A group of services that can be started together #[derive(Default)] pub struct ServiceGroup { - services: Vec>, - join_set: Option>>, + critical_services: Vec>, + auxiliary_services: Vec>, + critical_join_set: Option>>, + auxiliary_join_set: Option>>, } impl ServiceGroup { - pub fn new(services: Vec>) -> Self { + pub fn new( + critical_services: Vec>, + auxiliary_services: Vec>, + ) -> Self { + let has_critical_services = !critical_services.is_empty(); + let has_auxiliary_services = !auxiliary_services.is_empty(); + Self { - services, - join_set: Some(JoinSet::default()), + critical_services, + auxiliary_services, + critical_join_set: if has_critical_services { + Some(JoinSet::default()) + } else { + None + }, + auxiliary_join_set: if has_auxiliary_services { + Some(JoinSet::default()) + } else { + None + }, } } - pub fn push(&mut self, service: impl Service) { - if self.join_set.is_none() { - self.join_set = Some(JoinSet::default()); + pub fn push_critical(&mut self, service: impl Service) { + if self.critical_join_set.is_none() { + self.critical_join_set = Some(JoinSet::default()); } - self.services.push(Box::new(service)); + self.critical_services.push(Box::new(service)); + } + + pub fn push_auxiliary(&mut self, service: impl Service) { + if self.auxiliary_join_set.is_none() { + self.auxiliary_join_set = Some(JoinSet::default()); + } + self.auxiliary_services.push(Box::new(service)); + } + + #[must_use] + pub fn with_critical(mut self, service: impl Service) -> Self { + self.push_critical(service); + self } #[must_use] - pub fn with(mut self, service: impl Service) -> Self { - self.push(service); + pub fn with_auxiliary(mut self, service: impl Service) -> Self { + self.push_auxiliary(service); self } } @@ -138,25 +169,52 @@ impl ServiceGroup { #[async_trait::async_trait] impl Service for ServiceGroup { async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { - let mut own_join_set = self - .join_set + if self.critical_services.is_empty() { + return Err(anyhow!("ServiceGroup started without any critical service")); + } + + let mut own_critical_join_set = self + .critical_join_set .take() - .expect("Service has already been started"); + .context("ServiceGroup has already been started")?; - for service in &mut self.services { + for service in &mut self.critical_services { let ctx = runner.ctx.clone(); service - .start(ServiceRunner::new(ctx, &mut own_join_set)) + .start(ServiceRunner::new(ctx, &mut own_critical_join_set)) .await - .context("Starting service")?; + .context("Starting critical service")?; } - runner.join_set.spawn(drive_joinset(own_join_set)); + if !self.auxiliary_services.is_empty() { + let mut own_auxiliary_join_set = self + .auxiliary_join_set + .take() + .context("ServiceGroup has already been started")?; + + for service in &mut self.auxiliary_services { + let ctx = runner.ctx.clone(); + // Ignore start result for auxiliary services + let _ = service + .start(ServiceRunner::new(ctx, &mut own_auxiliary_join_set)) + .await; + } + + runner.join_set.spawn(drive_critical_and_auxiliary_joinsets( + own_critical_join_set, + own_auxiliary_join_set, + )); + } else { + runner + .join_set + .spawn(drive_critical_joinset(own_critical_join_set)); + }; + Ok(()) } } -async fn drive_joinset(mut join_set: JoinSet>) -> anyhow::Result<()> { +async fn drive_critical_joinset(mut join_set: JoinSet>) -> anyhow::Result<()> { while let Some(result) = join_set.join_next().await { match result { Ok(result) => result?, @@ -166,5 +224,22 @@ async fn drive_joinset(mut join_set: JoinSet>) -> anyhow::Res Err(_) => {} } } + + Ok(()) +} + +async fn drive_critical_and_auxiliary_joinsets( + critical_join_set: JoinSet>, + mut auxiliary_join_set: JoinSet>, +) -> anyhow::Result<()> { + let (res_critical, _ret_auxiliary) = futures::future::join( + drive_critical_joinset(critical_join_set), + // Ignore result for auxiliary services + async { while let Some(_result) = auxiliary_join_set.join_next().await {} }, + ) + .await; + + res_critical?; + Ok(()) } diff --git a/tests/test_services.rs b/tests/test_services.rs index 4fcc67b..b8b9532 100644 --- a/tests/test_services.rs +++ b/tests/test_services.rs @@ -222,7 +222,9 @@ mod test_services { should_panic: false, }; - let mut group = ServiceGroup::default().with(service1).with(service2); + let mut group = ServiceGroup::default() + .with_critical(service1) + .with_critical(service2); let ctx = ServiceContext::new(); let mut join_set = JoinSet::new(); @@ -266,6 +268,214 @@ mod test_services { assert_eq!(count2_before, count2_after, "Service 2 should have stopped"); } + #[tokio::test] + async fn test_empty_service_group() { + let mut group = ServiceGroup::default(); + + let ctx = ServiceContext::new(); + let mut join_set = JoinSet::new(); + let runner = ServiceRunner::new(ctx.clone(), &mut join_set); + + // Start service group + assert!(group.start(runner).await.is_err()); + } + + #[tokio::test] + async fn test_aux_only_service_group() { + let counter1 = Arc::new(Mutex::new(0)); + let counter2 = Arc::new(Mutex::new(0)); + + let service1 = TestService { + counter: counter1.clone(), + sleep_duration: Some(Duration::from_millis(50)), + should_panic: false, + }; + + let service2 = TestService { + counter: counter2.clone(), + sleep_duration: Some(Duration::from_millis(30)), + should_panic: false, + }; + + let mut group = ServiceGroup::default() + .with_auxiliary(service1) + .with_auxiliary(service2); + + let ctx = ServiceContext::new(); + let mut join_set = JoinSet::new(); + let runner = ServiceRunner::new(ctx.clone(), &mut join_set); + + // Start service group + assert!(group.start(runner).await.is_err()); + } + + #[tokio::test] + async fn test_mixed_service_group() { + let counter1 = Arc::new(Mutex::new(0)); + let counter2 = Arc::new(Mutex::new(0)); + + let service1 = TestService { + counter: counter1.clone(), + sleep_duration: Some(Duration::from_millis(50)), + should_panic: false, + }; + + let service2 = TestService { + counter: counter2.clone(), + sleep_duration: Some(Duration::from_millis(30)), + should_panic: false, + }; + + let mut group = ServiceGroup::default() + .with_critical(service1) + .with_auxiliary(service2); + + let ctx = ServiceContext::new(); + let mut join_set = JoinSet::new(); + let runner = ServiceRunner::new(ctx.clone(), &mut join_set); + + // Start service group + group.start(runner).await.unwrap(); + + // Let services run + sleep(Duration::from_millis(200)).await; + + // Verify both services are running + let count1 = *counter1.lock().unwrap(); + let count2 = *counter2.lock().unwrap(); + + assert!(count1 > 0, "Service 1 should have incremented counter"); + assert!(count2 > 0, "Service 2 should have incremented counter"); + assert!( + count2 > count1, + "Service 2 should increment faster than Service 1" + ); + + // Cancel all services + ctx.cancel(); + + // Wait for all services to complete + while let Some(result) = join_set.join_next().await { + result.unwrap().unwrap(); + } + + // Verify all services stopped + let count1_before = *counter1.lock().unwrap(); + let count2_before = *counter2.lock().unwrap(); + + sleep(Duration::from_millis(200)).await; + + let count1_after = *counter1.lock().unwrap(); + let count2_after = *counter2.lock().unwrap(); + + assert_eq!(count1_before, count1_after, "Service 1 should have stopped"); + assert_eq!(count2_before, count2_after, "Service 2 should have stopped"); + } + + #[tokio::test] + async fn test_auxiliary_service_failure() { + let counter1 = Arc::new(Mutex::new(0)); + let counter2 = Arc::new(Mutex::new(0)); + + let service1 = TestService { + counter: counter1.clone(), + sleep_duration: Some(Duration::from_millis(50)), + should_panic: false, + }; + + let service2 = TestService { + counter: counter2.clone(), + sleep_duration: Some(Duration::from_millis(30)), + should_panic: true, + }; + + let mut group = ServiceGroup::default() + .with_critical(service1) + .with_auxiliary(service2); + + let ctx = ServiceContext::new(); + let mut join_set = JoinSet::new(); + let runner = ServiceRunner::new(ctx.clone(), &mut join_set); + + // Start service group + group.start(runner).await.unwrap(); + + // Let services run + sleep(Duration::from_millis(200)).await; + + // Verify both services are running + let count1 = *counter1.lock().unwrap(); + let count2 = *counter2.lock().unwrap(); + + assert!(count1 > 0, "Service 1 should have incremented counter"); + assert!(count2 == 0, "Service 2 should not have incremented counter"); + + // Cancel all services + ctx.cancel(); + + // Wait for all services to complete + while let Some(result) = join_set.join_next().await { + result.unwrap().unwrap(); + } + + // Verify all services stopped + let count1_before = *counter1.lock().unwrap(); + + sleep(Duration::from_millis(200)).await; + + let count1_after = *counter1.lock().unwrap(); + + assert_eq!(count1_before, count1_after, "Service 1 should have stopped"); + } + + #[tokio::test] + #[should_panic(expected = "Service panic as requested")] + async fn test_critical_service_failure() { + let counter1 = Arc::new(Mutex::new(0)); + let counter2 = Arc::new(Mutex::new(0)); + + let service1 = TestService { + counter: counter1.clone(), + sleep_duration: Some(Duration::from_millis(50)), + should_panic: true, + }; + + let service2 = TestService { + counter: counter2.clone(), + sleep_duration: Some(Duration::from_millis(30)), + should_panic: false, + }; + + let mut group = ServiceGroup::default() + .with_critical(service1) + .with_critical(service2); + + let ctx = ServiceContext::new(); + let mut join_set = JoinSet::new(); + let runner = ServiceRunner::new(ctx.clone(), &mut join_set); + + // Start service group + group.start(runner).await.unwrap(); + + // Let services run + sleep(Duration::from_millis(200)).await; + + // Verify both services are running + let count1 = *counter1.lock().unwrap(); + let count2 = *counter2.lock().unwrap(); + + assert!(count1 == 0, "Service 1 should not have incremented counter"); + assert!(count2 > 0, "Service 2 should have incremented counter"); + + // Cancel all services + ctx.cancel(); + + // Wait for all services to complete + while let Some(result) = join_set.join_next().await { + result.unwrap().unwrap(); + } + } + #[tokio::test] async fn test_service_lifecycle_with_controlled_shutdown() { // Create a counter to track executions