[BUG] Mem leak handles in scheduler #5590
Changes from all commits
@@ -1,30 +1,63 @@
use parking_lot::RwLock;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Weak};
use std::time::Duration;
use std::{collections::HashMap, fmt::Debug};
use tokio::select;
use tracing::Span;

use super::{Component, ComponentContext, Handler, Message};

#[derive(Debug)]
pub(crate) struct SchedulerTaskHandle {
    join_handle: Option<tokio::task::JoinHandle<()>>,
    cancel: tokio_util::sync::CancellationToken,
}

impl Debug for SchedulerTaskHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SchedulerTaskHandle").finish()
    }
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct TaskId(u64);

pub(crate) struct HandleGuard {
    weak_handles: Weak<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
    task_id: TaskId,
}

impl Drop for HandleGuard {
    fn drop(&mut self) {
        if let Some(handles) = self.weak_handles.upgrade() {
            let mut handles = handles.write();
            handles.remove(&self.task_id);
        }
    }
}

#[derive(Clone, Debug)]
pub struct Scheduler {
    handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>,
    handles: Arc<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
    next_id: Arc<AtomicU64>,
}

impl Scheduler {
    pub(crate) fn new() -> Scheduler {
        Scheduler {
            handles: Arc::new(RwLock::new(Vec::new())),
            handles: Arc::new(RwLock::new(HashMap::new())),
            next_id: Arc::new(AtomicU64::new(1)),
        }
    }

    /// Allocate the next task ID.
    fn allocate_id(&self) -> TaskId {
        let id = self
            .next_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        TaskId(id)
    }

    /// Schedule a message to be sent to the component after the specified duration.
    ///
    /// `span_factory` is called immediately before sending the scheduled message to the component.

@@ -40,9 +73,17 @@ impl Scheduler {
        M: Message,
        S: (Fn() -> Option<Span>) + Send + Sync + 'static,
    {
        let id = self.allocate_id();
        let handles_weak = Arc::downgrade(&self.handles);

        let cancel = ctx.cancellation_token.clone();
        let sender = ctx.receiver().clone();
        let handle = tokio::spawn(async move {
            let _guard = HandleGuard {
                weak_handles: handles_weak,
                task_id: id,
            };

            select! {
                _ = cancel.cancelled() => {}
                _ = tokio::time::sleep(duration) => {

@@ -61,7 +102,7 @@ impl Scheduler {
            join_handle: Some(handle),
            cancel: ctx.cancellation_token.clone(),
        };
        self.handles.write().push(handle);
        self.handles.write().insert(id, handle);
    }

    /// Schedule a message to be sent to the component at a regular interval.

@@ -80,11 +121,16 @@ impl Scheduler {
        M: Message + Clone,
        S: (Fn() -> Option<Span>) + Send + Sync + 'static,
    {
        let id = self.allocate_id();
        let handles_weak = Arc::downgrade(&self.handles);
        let cancel = ctx.cancellation_token.clone();

        let sender = ctx.receiver().clone();

        let handle = tokio::spawn(async move {
            let _guard = HandleGuard {
                weak_handles: handles_weak,
                task_id: id,
            };
            let mut counter = 0;
            while Self::should_continue(num_times, counter) {
                select! {

@@ -109,7 +155,7 @@ impl Scheduler {
            join_handle: Some(handle),
            cancel: ctx.cancellation_token.clone(),
        };
        self.handles.write().push(handle);
        self.handles.write().insert(id, handle);
    }

    #[cfg(test)]

@@ -132,7 +178,7 @@ impl Scheduler {
            let mut handles = self.handles.write();
            handles
                .iter_mut()
                .flat_map(|h| h.join_handle.take())
                .flat_map(|(_, h)| h.join_handle.take())
                .collect::<Vec<_>>()
        };
        for join_handle in handles.iter_mut() {

@@ -148,7 +194,7 @@ impl Scheduler {
    pub(crate) fn stop(&self) {
        let handles = self.handles.read();
        for handle in handles.iter() {
            handle.cancel.cancel();
            handle.1.cancel.cancel();
        }
    }
}

@@ -157,45 +203,43 @@ impl Scheduler {
mod tests {
    use super::*;
    use crate::system::System;

    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug)]
    struct TestComponent {
    struct SimpleScheduleIntervalComponent {
        queue_size: usize,
        counter: Arc<AtomicUsize>,
    }

    #[derive(Clone, Debug)]
    struct ScheduleMessage {}

    impl TestComponent {
    impl SimpleScheduleIntervalComponent {
        fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self {
            TestComponent {
            SimpleScheduleIntervalComponent {
                queue_size,
                counter,
            }
        }
    }
    #[async_trait]
    impl Handler<ScheduleMessage> for TestComponent {
    impl Handler<ScheduleMessage> for SimpleScheduleIntervalComponent {
        type Result = ();

        async fn handle(
            &mut self,
            _message: ScheduleMessage,
            _ctx: &ComponentContext<TestComponent>,
            _ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
        ) {
            self.counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[async_trait]
    impl Component for TestComponent {
    impl Component for SimpleScheduleIntervalComponent {
        fn get_name() -> &'static str {
            "Test component"
        }

@@ -204,7 +248,10 @@ mod tests {
            self.queue_size
        }

        async fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () {
        async fn on_start(
            &mut self,
            ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
        ) -> () {
            let duration = Duration::from_millis(100);
            ctx.scheduler
                .schedule(ScheduleMessage {}, duration, ctx, || None);

@@ -224,12 +271,81 @@ mod tests {
    async fn test_schedule() {
        let system = System::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let component = TestComponent::new(10, counter.clone());
        let component = SimpleScheduleIntervalComponent::new(10, counter.clone());
        let _handle = system.start_component(component);
        // yield to allow the component to process the messages
        tokio::task::yield_now().await;
        // We should have scheduled the message once
        system.join().await;
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[derive(Debug)]
    struct OneMessageComponent {
        queue_size: usize,
        counter: Arc<AtomicUsize>,
        handles_empty_after: Arc<AtomicBool>,
    }

    impl OneMessageComponent {
        fn new(
            queue_size: usize,
            counter: Arc<AtomicUsize>,
            handles_empty_after: Arc<AtomicBool>,
        ) -> Self {
            OneMessageComponent {
                queue_size,
                counter,
                handles_empty_after,
            }
        }
    }

    #[async_trait]
    impl Component for OneMessageComponent {
        fn get_name() -> &'static str {
            "OneMessageComponent"
        }

        fn queue_size(&self) -> usize {
            self.queue_size
        }

        async fn on_start(&mut self, ctx: &ComponentContext<OneMessageComponent>) -> () {
            let duration = Duration::from_millis(100);
            ctx.scheduler
                .schedule(ScheduleMessage {}, duration, ctx, || None);
        }
    }

    #[async_trait]
    impl Handler<ScheduleMessage> for OneMessageComponent {
        type Result = ();

        async fn handle(
            &mut self,
            _message: ScheduleMessage,
            ctx: &ComponentContext<OneMessageComponent>,
        ) {
            self.counter.fetch_add(1, Ordering::SeqCst);
            self.handles_empty_after
                .store(ctx.scheduler.handles.read().is_empty(), Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn test_handle_cleaned_up() {
        let system = System::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let handles_empty_after = Arc::new(AtomicBool::new(false));
        let component = OneMessageComponent::new(10, counter.clone(), handles_empty_after.clone());
        let _handle = system.start_component(component);
        // Wait for the 100ms schedule to trigger
        tokio::time::sleep(Duration::from_millis(500)).await;
        // yield to allow the component to process the messages
        tokio::task::yield_now().await;
        assert!(handles_empty_after.load(Ordering::SeqCst));
        // We should have scheduled the message once
        system.join().await;
    }
}
Comment on lines +337 to +350
[CriticalError] There's a potential race condition in this test. The test verifies that the handles map is empty from inside the message handler, via the handles_empty_after flag. However, the spawned scheduler task only drops its HandleGuard (and removes its map entry) after the scheduled message has been sent, so the handler can run before that cleanup happens and observe a non-empty map. A more robust approach would be to check the state of the scheduler's handles map directly from the test, after the task has had time to finish. With the component handle exposing the scheduler, that could look like:

#[tokio::test]
async fn test_handle_cleaned_up() {
    let system = System::new();
    let counter = Arc::new(AtomicUsize::new(0));
    // The OneMessageComponent can be simplified to not need `handles_empty_after`
    let component = OneMessageComponent::new(10, counter.clone());
    let handle = system.start_component(component);
    // Allow on_start to run and schedule the task.
    tokio::task::yield_now().await;
    assert_eq!(handle.ctx.scheduler.handles.read().len(), 1, "Handle should be present after scheduling");
    // Wait for the schedule to trigger and the task to be cleaned up.
    tokio::time::sleep(Duration::from_millis(500)).await;
    assert!(handle.ctx.scheduler.handles.read().is_empty(), "Handles map should be empty after task completion");
    assert_eq!(counter.load(Ordering::SeqCst), 1, "Message should have been handled once");
    system.join().await;
}

This would also allow simplifying OneMessageComponent.

I'll leave it to you to decide if this is worth the lift.
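Stepping back from the review thread: the heart of the fix in this diff is a drop-guard. Each spawned task owns a HandleGuard holding a Weak reference to the shared handle map, and the guard's Drop removes the task's own entry when the future finishes or is cancelled. Below is a minimal, self-contained sketch of that pattern, not the PR's actual code; the names (Guard, HandleMap, the () placeholder value) are made up for the example, and it assumes the tokio (rt, macros, time features) and parking_lot crates.

```rust
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
use std::time::Duration;

// Illustrative stand-ins for the PR's TaskId / handle map (the real map
// stores SchedulerTaskHandle values; `()` keeps the sketch small).
type TaskId = u64;
type HandleMap = Arc<RwLock<HashMap<TaskId, ()>>>;

// Owned by the spawned task; dropped when the task ends for any reason
// (normal completion, cancellation, or panic), so its entry never leaks.
struct Guard {
    map: Weak<RwLock<HashMap<TaskId, ()>>>,
    id: TaskId,
}

impl Drop for Guard {
    fn drop(&mut self) {
        // If the whole map is already gone, there is nothing left to clean up.
        if let Some(map) = self.map.upgrade() {
            map.write().remove(&self.id);
        }
    }
}

#[tokio::main]
async fn main() {
    let handles: HandleMap = Arc::new(RwLock::new(HashMap::new()));

    // Register the task, then hand the spawned future a weak reference plus a guard.
    let id: TaskId = 1;
    handles.write().insert(id, ());
    let weak = Arc::downgrade(&handles);

    let task = tokio::spawn(async move {
        let _guard = Guard { map: weak, id };
        // Stand-in for the scheduled delay before the message is sent.
        tokio::time::sleep(Duration::from_millis(10)).await;
        // `_guard` is dropped here, removing the entry from the map.
    });

    task.await.unwrap();
    assert!(handles.read().is_empty(), "entry removed when the task finished");
}
```

One likely reason for Weak rather than Arc inside the task (an assumption about intent, not stated in the PR): a strong Arc captured by every spawned future would keep the handle map alive for as long as any task is still pending, even after the owning Scheduler has been dropped.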
I'd recommend not using RwLock here unless we have a heavy read path. An RwLock is typically more expensive per access if it's 100% write, and I only see read calls from tests.
Yeah, that's fine, can do; I was just leaving it as it was before to minimize churn.
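For reference, a rough sketch of what that swap might look like, assuming parking_lot::Mutex as the replacement (illustrative only, not part of the PR): the handles field and the guard's Weak both switch from RwLock to Mutex, and every .read()/.write() call site becomes .lock(). Only the guard and a small driver are sketched here, with () standing in for SchedulerTaskHandle.

```rust
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::{Arc, Weak};

type TaskId = u64;

// Hypothetical variant of the PR's guard with the handle map behind a Mutex
// instead of an RwLock.
struct HandleGuard {
    weak_handles: Weak<Mutex<HashMap<TaskId, ()>>>,
    task_id: TaskId,
}

impl Drop for HandleGuard {
    fn drop(&mut self) {
        if let Some(handles) = self.weak_handles.upgrade() {
            // A Mutex exposes a single lock(); RwLock's read()/write() split goes away.
            handles.lock().remove(&self.task_id);
        }
    }
}

fn main() {
    let handles = Arc::new(Mutex::new(HashMap::new()));
    handles.lock().insert(1u64, ());

    // Dropping the guard removes the entry, exactly as in the RwLock version.
    drop(HandleGuard {
        weak_handles: Arc::downgrade(&handles),
        task_id: 1,
    });
    assert!(handles.lock().is_empty());
}
```

Since every access here is a short exclusive insert or remove, an uncontended Mutex should be at least as cheap as an RwLock, which matches the reviewer's point about a write-only path.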