-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[BUG] Mem leak handles in scheduler #5590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,63 @@ | ||
use parking_lot::RwLock; | ||
use std::fmt::Debug; | ||
use std::sync::Arc; | ||
use std::sync::atomic::AtomicU64; | ||
use std::sync::{Arc, Weak}; | ||
use std::time::Duration; | ||
use std::{collections::HashMap, fmt::Debug}; | ||
use tokio::select; | ||
use tracing::Span; | ||
|
||
use super::{Component, ComponentContext, Handler, Message}; | ||
|
||
#[derive(Debug)] | ||
pub(crate) struct SchedulerTaskHandle { | ||
join_handle: Option<tokio::task::JoinHandle<()>>, | ||
cancel: tokio_util::sync::CancellationToken, | ||
} | ||
|
||
impl Debug for SchedulerTaskHandle { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
f.debug_struct("SchedulerTaskHandle").finish() | ||
} | ||
} | ||
|
||
/// Process-unique identifier for a scheduled task; used as the key in the
/// scheduler's handle map so a finished task can remove its own entry.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(u64);
|
||
/// RAII guard owned by each spawned scheduler task. When the task finishes
/// (or is cancelled) the guard drops and removes the task's entry from the
/// scheduler's handle map — this is the fix for the handle leak (#5590).
pub struct HandleGuard {
    // Weak so the guard does not keep the scheduler's handle map alive after
    // the Scheduler itself has been dropped.
    weak_handles: Weak<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
    task_id: TaskId,
}
|
||
impl Drop for HandleGuard { | ||
fn drop(&mut self) { | ||
if let Some(handles) = self.weak_handles.upgrade() { | ||
let mut handles = handles.write(); | ||
handles.remove(&self.task_id); | ||
} | ||
} | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct Scheduler { | ||
handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>, | ||
handles: Arc<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd recommend not using RwLock here unless we have a heavy read path. An RwLock is typically more expensive per access if it's 100% write, and I only see read calls from tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah thats fine, can do, was just leaving it as it was before to minimize churn. |
||
next_id: Arc<AtomicU64>, | ||
} | ||
|
||
impl Scheduler { | ||
pub(crate) fn new() -> Scheduler { | ||
Scheduler { | ||
handles: Arc::new(RwLock::new(Vec::new())), | ||
handles: Arc::new(RwLock::new(HashMap::new())), | ||
next_id: Arc::new(AtomicU64::new(1)), | ||
} | ||
} | ||
|
||
/// Allocate the next task ID. | ||
fn allocate_id(&self) -> TaskId { | ||
let id = self | ||
.next_id | ||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed); | ||
TaskId(id) | ||
} | ||
|
||
/// Schedule a message to be sent to the component after the specified duration. | ||
/// | ||
/// `span_factory` is called immediately before sending the scheduled message to the component. | ||
|
@@ -40,9 +73,17 @@ | |
M: Message, | ||
S: (Fn() -> Option<Span>) + Send + Sync + 'static, | ||
{ | ||
let id = self.allocate_id(); | ||
let handles_weak = Arc::downgrade(&self.handles); | ||
|
||
let cancel = ctx.cancellation_token.clone(); | ||
let sender = ctx.receiver().clone(); | ||
let handle = tokio::spawn(async move { | ||
let _guard = HandleGuard { | ||
weak_handles: handles_weak, | ||
task_id: id, | ||
}; | ||
|
||
select! { | ||
_ = cancel.cancelled() => {} | ||
_ = tokio::time::sleep(duration) => { | ||
|
@@ -61,7 +102,7 @@ | |
join_handle: Some(handle), | ||
cancel: ctx.cancellation_token.clone(), | ||
}; | ||
self.handles.write().push(handle); | ||
self.handles.write().insert(id, handle); | ||
} | ||
|
||
/// Schedule a message to be sent to the component at a regular interval. | ||
|
@@ -80,11 +121,16 @@ | |
M: Message + Clone, | ||
S: (Fn() -> Option<Span>) + Send + Sync + 'static, | ||
{ | ||
let id = self.allocate_id(); | ||
let handles_weak = Arc::downgrade(&self.handles); | ||
let cancel = ctx.cancellation_token.clone(); | ||
|
||
let sender = ctx.receiver().clone(); | ||
|
||
let handle = tokio::spawn(async move { | ||
let _guard = HandleGuard { | ||
weak_handles: handles_weak, | ||
task_id: id, | ||
}; | ||
let mut counter = 0; | ||
while Self::should_continue(num_times, counter) { | ||
select! { | ||
|
@@ -109,7 +155,7 @@ | |
join_handle: Some(handle), | ||
cancel: ctx.cancellation_token.clone(), | ||
}; | ||
self.handles.write().push(handle); | ||
self.handles.write().insert(id, handle); | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -132,7 +178,7 @@ | |
let mut handles = self.handles.write(); | ||
handles | ||
.iter_mut() | ||
.flat_map(|h| h.join_handle.take()) | ||
.flat_map(|(_, h)| h.join_handle.take()) | ||
.collect::<Vec<_>>() | ||
}; | ||
for join_handle in handles.iter_mut() { | ||
|
@@ -148,7 +194,7 @@ | |
pub(crate) fn stop(&self) { | ||
let handles = self.handles.read(); | ||
for handle in handles.iter() { | ||
handle.cancel.cancel(); | ||
handle.1.cancel.cancel(); | ||
} | ||
} | ||
} | ||
|
@@ -157,45 +203,43 @@ | |
mod tests { | ||
use super::*; | ||
use crate::system::System; | ||
|
||
use async_trait::async_trait; | ||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; | ||
use std::sync::Arc; | ||
use std::time::Duration; | ||
|
||
use std::sync::atomic::{AtomicUsize, Ordering}; | ||
|
||
#[derive(Debug)] | ||
struct TestComponent { | ||
struct SimpleScheduleIntervalComponent { | ||
queue_size: usize, | ||
counter: Arc<AtomicUsize>, | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
struct ScheduleMessage {} | ||
|
||
impl TestComponent { | ||
impl SimpleScheduleIntervalComponent { | ||
fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self { | ||
TestComponent { | ||
SimpleScheduleIntervalComponent { | ||
queue_size, | ||
counter, | ||
} | ||
} | ||
} | ||
#[async_trait] | ||
impl Handler<ScheduleMessage> for TestComponent { | ||
impl Handler<ScheduleMessage> for SimpleScheduleIntervalComponent { | ||
type Result = (); | ||
|
||
async fn handle( | ||
&mut self, | ||
_message: ScheduleMessage, | ||
_ctx: &ComponentContext<TestComponent>, | ||
_ctx: &ComponentContext<SimpleScheduleIntervalComponent>, | ||
) { | ||
self.counter.fetch_add(1, Ordering::SeqCst); | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl Component for TestComponent { | ||
impl Component for SimpleScheduleIntervalComponent { | ||
fn get_name() -> &'static str { | ||
"Test component" | ||
} | ||
|
@@ -204,7 +248,10 @@ | |
self.queue_size | ||
} | ||
|
||
async fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () { | ||
async fn on_start( | ||
&mut self, | ||
ctx: &ComponentContext<SimpleScheduleIntervalComponent>, | ||
) -> () { | ||
let duration = Duration::from_millis(100); | ||
ctx.scheduler | ||
.schedule(ScheduleMessage {}, duration, ctx, || None); | ||
|
@@ -224,12 +271,81 @@ | |
async fn test_schedule() { | ||
let system = System::new(); | ||
let counter = Arc::new(AtomicUsize::new(0)); | ||
let component = TestComponent::new(10, counter.clone()); | ||
let component = SimpleScheduleIntervalComponent::new(10, counter.clone()); | ||
let _handle = system.start_component(component); | ||
// yield to allow the component to process the messages | ||
tokio::task::yield_now().await; | ||
// We should have scheduled the message once | ||
system.join().await; | ||
assert_eq!(counter.load(Ordering::SeqCst), 5); | ||
} | ||
|
||
/// Test component that schedules exactly one delayed message and records
/// whether the scheduler's handle map was empty when the message was handled.
#[derive(Debug)]
struct OneMessageComponent {
    queue_size: usize,
    // Incremented once per handled ScheduleMessage.
    counter: Arc<AtomicUsize>,
    // Written from the handler: true iff scheduler.handles was empty then.
    handles_empty_after: Arc<AtomicBool>,
}
|
||
impl OneMessageComponent { | ||
fn new( | ||
queue_size: usize, | ||
counter: Arc<AtomicUsize>, | ||
handles_empty_after: Arc<AtomicBool>, | ||
) -> Self { | ||
OneMessageComponent { | ||
queue_size, | ||
counter, | ||
handles_empty_after: handles_empty_after, | ||
} | ||
} | ||
} | ||
|
||
#[async_trait]
impl Component for OneMessageComponent {
    fn get_name() -> &'static str {
        "OneMessageComponent"
    }

    fn queue_size(&self) -> usize {
        self.queue_size
    }

    // Schedules a single ScheduleMessage 100ms after start. `schedule` (not
    // `schedule_interval`) is used, so exactly one message is delivered.
    async fn on_start(&mut self, ctx: &ComponentContext<OneMessageComponent>) -> () {
        let duration = Duration::from_millis(100);
        ctx.scheduler
            .schedule(ScheduleMessage {}, duration, ctx, || None);
    }
}
|
||
#[async_trait]
impl Handler<ScheduleMessage> for OneMessageComponent {
    type Result = ();

    // Counts the delivery and snapshots whether the scheduler's handle map is
    // already empty — i.e. whether the task's HandleGuard has dropped.
    // NOTE(review): the guard drops when the spawned task finishes, which may
    // race with this handler running — TODO confirm the ordering guarantee.
    async fn handle(
        &mut self,
        _message: ScheduleMessage,
        ctx: &ComponentContext<OneMessageComponent>,
    ) {
        self.counter.fetch_add(1, Ordering::SeqCst);
        self.handles_empty_after
            .store(ctx.scheduler.handles.read().is_empty(), Ordering::SeqCst);
    }
}
|
||
// Regression test for #5590: a completed scheduled task must remove its
// handle from the scheduler's map instead of leaking it.
#[tokio::test]
async fn test_handle_cleaned_up() {
    let system = System::new();
    let counter = Arc::new(AtomicUsize::new(0));
    let handles_empty_after = Arc::new(AtomicBool::new(false));
    let component = OneMessageComponent::new(10, counter.clone(), handles_empty_after.clone());
    let _handle = system.start_component(component);
    // Wait for the 100ms schedule to trigger
    // NOTE(review): this fixed 500ms sleep is timing-dependent and could race
    // with the task's cleanup; polling the handles map until empty would be
    // more robust.
    tokio::time::sleep(Duration::from_millis(500)).await;
    // yield to allow the component to process the messages
    tokio::task::yield_now().await;
    assert!(handles_empty_after.load(Ordering::SeqCst));
    // We should have scheduled the message once
    system.join().await;
}
Comment on lines
+337
to
+350
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [CriticalError] There's a potential race condition in this test. The test verifies cleanup via a flag set inside the message handler. However, the fixed sleep and yield do not guarantee that the scheduled task's cleanup guard has run by the time the flag is written. A more robust approach would be to check the state of the scheduler's handles map directly from the test, after the task has had a chance to complete. With that change the test could look like: #[tokio::test]
async fn test_handle_cleaned_up() {
let system = System::new();
let counter = Arc::new(AtomicUsize::new(0));
// The OneMessageComponent can be simplified to not need `handles_empty_after`
let component = OneMessageComponent::new(10, counter.clone());
let handle = system.start_component(component);
// Allow on_start to run and schedule the task.
tokio::task::yield_now().await;
assert_eq!(handle.ctx.scheduler.handles.read().len(), 1, "Handle should be present after scheduling");
// Wait for the schedule to trigger and the task to be cleaned up.
tokio::time::sleep(Duration::from_millis(500)).await;
assert!(handle.ctx.scheduler.handles.read().is_empty(), "Handles map should be empty after task completion");
assert_eq!(counter.load(Ordering::SeqCst), 1, "Message should have been handled once");
system.join().await;
} This would also allow simplifying Context for Agents
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll leave it to you to decide if this is worth the lift. |
||
} |
Uh oh!
There was an error while loading. Please reload this page.