diff --git a/rust/system/src/scheduler.rs b/rust/system/src/scheduler.rs index a0a49a4e9dd..65100f44e50 100644 --- a/rust/system/src/scheduler.rs +++ b/rust/system/src/scheduler.rs @@ -1,30 +1,63 @@ use parking_lot::RwLock; -use std::fmt::Debug; -use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, Weak}; use std::time::Duration; +use std::{collections::HashMap, fmt::Debug}; use tokio::select; use tracing::Span; use super::{Component, ComponentContext, Handler, Message}; -#[derive(Debug)] pub(crate) struct SchedulerTaskHandle { join_handle: Option>, cancel: tokio_util::sync::CancellationToken, } +impl Debug for SchedulerTaskHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SchedulerTaskHandle").finish() + } +} + +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub(crate) struct TaskId(u64); + +pub(crate) struct HandleGuard { + weak_handles: Weak>>, + task_id: TaskId, +} + +impl Drop for HandleGuard { + fn drop(&mut self) { + if let Some(handles) = self.weak_handles.upgrade() { + let mut handles = handles.write(); + handles.remove(&self.task_id); + } + } +} + #[derive(Clone, Debug)] pub struct Scheduler { - handles: Arc>>, + handles: Arc>>, + next_id: Arc, } impl Scheduler { pub(crate) fn new() -> Scheduler { Scheduler { - handles: Arc::new(RwLock::new(Vec::new())), + handles: Arc::new(RwLock::new(HashMap::new())), + next_id: Arc::new(AtomicU64::new(1)), } } + /// Allocate the next task ID. + fn allocate_id(&self) -> TaskId { + let id = self + .next_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + TaskId(id) + } + /// Schedule a message to be sent to the component after the specified duration. /// /// `span_factory` is called immediately before sending the scheduled message to the component. @@ -40,9 +73,17 @@ impl Scheduler { M: Message, S: (Fn() -> Option) + Send + Sync + 'static, { + let id = self.allocate_id(); + let handles_weak = Arc::downgrade(&self.handles); + let cancel = ctx.cancellation_token.clone(); let sender = ctx.receiver().clone(); let handle = tokio::spawn(async move { + let _guard = HandleGuard { + weak_handles: handles_weak, + task_id: id, + }; + select! { _ = cancel.cancelled() => {} _ = tokio::time::sleep(duration) => { @@ -61,7 +102,7 @@ impl Scheduler { join_handle: Some(handle), cancel: ctx.cancellation_token.clone(), }; - self.handles.write().push(handle); + self.handles.write().insert(id, handle); } /// Schedule a message to be sent to the component at a regular interval. @@ -80,11 +121,16 @@ impl Scheduler { M: Message + Clone, S: (Fn() -> Option) + Send + Sync + 'static, { + let id = self.allocate_id(); + let handles_weak = Arc::downgrade(&self.handles); let cancel = ctx.cancellation_token.clone(); - let sender = ctx.receiver().clone(); let handle = tokio::spawn(async move { + let _guard = HandleGuard { + weak_handles: handles_weak, + task_id: id, + }; let mut counter = 0; while Self::should_continue(num_times, counter) { select! { @@ -109,7 +155,7 @@ impl Scheduler { join_handle: Some(handle), cancel: ctx.cancellation_token.clone(), }; - self.handles.write().push(handle); + self.handles.write().insert(id, handle); } #[cfg(test)] @@ -132,7 +178,7 @@ impl Scheduler { let mut handles = self.handles.write(); handles .iter_mut() - .flat_map(|h| h.join_handle.take()) + .flat_map(|(_, h)| h.join_handle.take()) .collect::>() }; for join_handle in handles.iter_mut() { @@ -148,7 +194,7 @@ impl Scheduler { pub(crate) fn stop(&self) { let handles = self.handles.read(); for handle in handles.iter() { - handle.cancel.cancel(); + handle.1.cancel.cancel(); } } } @@ -157,15 +203,13 @@ impl Scheduler { mod tests { use super::*; use crate::system::System; - use async_trait::async_trait; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; - use std::sync::atomic::{AtomicUsize, Ordering}; - #[derive(Debug)] - struct TestComponent { + struct SimpleScheduleIntervalComponent { queue_size: usize, counter: Arc, } @@ -173,29 +217,29 @@ mod tests { #[derive(Clone, Debug)] struct ScheduleMessage {} - impl TestComponent { + impl SimpleScheduleIntervalComponent { fn new(queue_size: usize, counter: Arc) -> Self { - TestComponent { + SimpleScheduleIntervalComponent { queue_size, counter, } } } #[async_trait] - impl Handler for TestComponent { + impl Handler for SimpleScheduleIntervalComponent { type Result = (); async fn handle( &mut self, _message: ScheduleMessage, - _ctx: &ComponentContext, + _ctx: &ComponentContext, ) { self.counter.fetch_add(1, Ordering::SeqCst); } } #[async_trait] - impl Component for TestComponent { + impl Component for SimpleScheduleIntervalComponent { fn get_name() -> &'static str { "Test component" } @@ -204,7 +248,10 @@ mod tests { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) -> () { + async fn on_start( + &mut self, + ctx: &ComponentContext, + ) -> () { let duration = Duration::from_millis(100); ctx.scheduler .schedule(ScheduleMessage {}, duration, ctx, || None); @@ -224,7 +271,7 @@ mod tests { async fn test_schedule() { let system = System::new(); let counter = Arc::new(AtomicUsize::new(0)); - let component = TestComponent::new(10, counter.clone()); + let component = SimpleScheduleIntervalComponent::new(10, counter.clone()); let _handle = system.start_component(component); // yield to allow the component to process the messages tokio::task::yield_now().await; @@ -232,4 +279,73 @@ mod tests { system.join().await; assert_eq!(counter.load(Ordering::SeqCst), 5); } + + #[derive(Debug)] + struct OneMessageComponent { + queue_size: usize, + counter: Arc, + handles_empty_after: Arc, + } + + impl OneMessageComponent { + fn new( + queue_size: usize, + counter: Arc, + handles_empty_after: Arc, + ) -> Self { + OneMessageComponent { + queue_size, + counter, + handles_empty_after, + } + } + } + + #[async_trait] + impl Component for OneMessageComponent { + fn get_name() -> &'static str { + "OneMessageComponent" + } + + fn queue_size(&self) -> usize { + self.queue_size + } + + async fn on_start(&mut self, ctx: &ComponentContext) -> () { + let duration = Duration::from_millis(100); + ctx.scheduler + .schedule(ScheduleMessage {}, duration, ctx, || None); + } + } + + #[async_trait] + impl Handler for OneMessageComponent { + type Result = (); + + async fn handle( + &mut self, + _message: ScheduleMessage, + ctx: &ComponentContext, + ) { + self.counter.fetch_add(1, Ordering::SeqCst); + self.handles_empty_after + .store(ctx.scheduler.handles.read().is_empty(), Ordering::SeqCst); + } + } + + #[tokio::test] + async fn test_handle_cleaned_up() { + let system = System::new(); + let counter = Arc::new(AtomicUsize::new(0)); + let handles_empty_after = Arc::new(AtomicBool::new(false)); + let component = OneMessageComponent::new(10, counter.clone(), handles_empty_after.clone()); + let _handle = system.start_component(component); + // Wait for the 100ms schedule to trigger + tokio::time::sleep(Duration::from_millis(500)).await; + // yield to allow the component to process the messages + tokio::task::yield_now().await; + assert!(handles_empty_after.load(Ordering::SeqCst)); + // We should have scheduled the message once + system.join().await; + } }