Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 137 additions & 21 deletions rust/system/src/scheduler.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,63 @@
use parking_lot::RwLock;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Weak};
use std::time::Duration;
use std::{collections::HashMap, fmt::Debug};
use tokio::select;
use tracing::Span;

use super::{Component, ComponentContext, Handler, Message};

#[derive(Debug)]
pub(crate) struct SchedulerTaskHandle {
join_handle: Option<tokio::task::JoinHandle<()>>,
cancel: tokio_util::sync::CancellationToken,
}

impl Debug for SchedulerTaskHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SchedulerTaskHandle").finish()
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(u64);

pub struct HandleGuard {
weak_handles: Weak<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
task_id: TaskId,
}

impl Drop for HandleGuard {
fn drop(&mut self) {
if let Some(handles) = self.weak_handles.upgrade() {
let mut handles = handles.write();
handles.remove(&self.task_id);
}
}
}

#[derive(Clone, Debug)]
pub struct Scheduler {
handles: Arc<RwLock<Vec<SchedulerTaskHandle>>>,
handles: Arc<RwLock<HashMap<TaskId, SchedulerTaskHandle>>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend not using RwLock here unless we have a heavy read path. An RwLock is typically more expensive per access if it's 100% write, and I only see read calls from tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah thats fine, can do, was just leaving it as it was before to minimize churn.

next_id: Arc<AtomicU64>,
}

impl Scheduler {
pub(crate) fn new() -> Scheduler {
Scheduler {
handles: Arc::new(RwLock::new(Vec::new())),
handles: Arc::new(RwLock::new(HashMap::new())),
next_id: Arc::new(AtomicU64::new(1)),
}
}

/// Allocate the next task ID.
fn allocate_id(&self) -> TaskId {
let id = self
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
TaskId(id)
}

/// Schedule a message to be sent to the component after the specified duration.
///
/// `span_factory` is called immediately before sending the scheduled message to the component.
Expand All @@ -40,9 +73,17 @@
M: Message,
S: (Fn() -> Option<Span>) + Send + Sync + 'static,
{
let id = self.allocate_id();
let handles_weak = Arc::downgrade(&self.handles);

let cancel = ctx.cancellation_token.clone();
let sender = ctx.receiver().clone();
let handle = tokio::spawn(async move {
let _guard = HandleGuard {
weak_handles: handles_weak,
task_id: id,
};

select! {
_ = cancel.cancelled() => {}
_ = tokio::time::sleep(duration) => {
Expand All @@ -61,7 +102,7 @@
join_handle: Some(handle),
cancel: ctx.cancellation_token.clone(),
};
self.handles.write().push(handle);
self.handles.write().insert(id, handle);
}

/// Schedule a message to be sent to the component at a regular interval.
Expand All @@ -80,11 +121,16 @@
M: Message + Clone,
S: (Fn() -> Option<Span>) + Send + Sync + 'static,
{
let id = self.allocate_id();
let handles_weak = Arc::downgrade(&self.handles);
let cancel = ctx.cancellation_token.clone();

let sender = ctx.receiver().clone();

let handle = tokio::spawn(async move {
let _guard = HandleGuard {
weak_handles: handles_weak,
task_id: id,
};
let mut counter = 0;
while Self::should_continue(num_times, counter) {
select! {
Expand All @@ -109,7 +155,7 @@
join_handle: Some(handle),
cancel: ctx.cancellation_token.clone(),
};
self.handles.write().push(handle);
self.handles.write().insert(id, handle);
}

#[cfg(test)]
Expand All @@ -132,7 +178,7 @@
let mut handles = self.handles.write();
handles
.iter_mut()
.flat_map(|h| h.join_handle.take())
.flat_map(|(_, h)| h.join_handle.take())
.collect::<Vec<_>>()
};
for join_handle in handles.iter_mut() {
Expand All @@ -148,7 +194,7 @@
pub(crate) fn stop(&self) {
let handles = self.handles.read();
for handle in handles.iter() {
handle.cancel.cancel();
handle.1.cancel.cancel();
}
}
}
Expand All @@ -157,45 +203,43 @@
mod tests {
use super::*;
use crate::system::System;

use async_trait::async_trait;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

use std::sync::atomic::{AtomicUsize, Ordering};

#[derive(Debug)]
struct TestComponent {
struct SimpleScheduleIntervalComponent {
queue_size: usize,
counter: Arc<AtomicUsize>,
}

#[derive(Clone, Debug)]
struct ScheduleMessage {}

impl TestComponent {
impl SimpleScheduleIntervalComponent {
fn new(queue_size: usize, counter: Arc<AtomicUsize>) -> Self {
TestComponent {
SimpleScheduleIntervalComponent {
queue_size,
counter,
}
}
}
#[async_trait]
impl Handler<ScheduleMessage> for TestComponent {
impl Handler<ScheduleMessage> for SimpleScheduleIntervalComponent {
type Result = ();

async fn handle(
&mut self,
_message: ScheduleMessage,
_ctx: &ComponentContext<TestComponent>,
_ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}

#[async_trait]
impl Component for TestComponent {
impl Component for SimpleScheduleIntervalComponent {
fn get_name() -> &'static str {
"Test component"
}
Expand All @@ -204,7 +248,10 @@
self.queue_size
}

async fn on_start(&mut self, ctx: &ComponentContext<TestComponent>) -> () {
async fn on_start(
&mut self,
ctx: &ComponentContext<SimpleScheduleIntervalComponent>,
) -> () {
let duration = Duration::from_millis(100);
ctx.scheduler
.schedule(ScheduleMessage {}, duration, ctx, || None);
Expand All @@ -224,12 +271,81 @@
async fn test_schedule() {
let system = System::new();
let counter = Arc::new(AtomicUsize::new(0));
let component = TestComponent::new(10, counter.clone());
let component = SimpleScheduleIntervalComponent::new(10, counter.clone());
let _handle = system.start_component(component);
// yield to allow the component to process the messages
tokio::task::yield_now().await;
// We should have scheduled the message once
system.join().await;
assert_eq!(counter.load(Ordering::SeqCst), 5);
}

#[derive(Debug)]
struct OneMessageComponent {
queue_size: usize,
counter: Arc<AtomicUsize>,
handles_empty_after: Arc<AtomicBool>,
}

impl OneMessageComponent {
fn new(
queue_size: usize,
counter: Arc<AtomicUsize>,
handles_empty_after: Arc<AtomicBool>,
) -> Self {
OneMessageComponent {
queue_size,
counter,
handles_empty_after: handles_empty_after,

Check failure on line 299 in rust/system/src/scheduler.rs

View workflow job for this annotation

GitHub Actions / Lint

redundant field names in struct initialization
}
}
}

#[async_trait]
impl Component for OneMessageComponent {
fn get_name() -> &'static str {
"OneMessageComponent"
}

fn queue_size(&self) -> usize {
self.queue_size
}

async fn on_start(&mut self, ctx: &ComponentContext<OneMessageComponent>) -> () {
let duration = Duration::from_millis(100);
ctx.scheduler
.schedule(ScheduleMessage {}, duration, ctx, || None);
}
}

#[async_trait]
impl Handler<ScheduleMessage> for OneMessageComponent {
type Result = ();

async fn handle(
&mut self,
_message: ScheduleMessage,
ctx: &ComponentContext<OneMessageComponent>,
) {
self.counter.fetch_add(1, Ordering::SeqCst);
self.handles_empty_after
.store(ctx.scheduler.handles.read().is_empty(), Ordering::SeqCst);
}
}

#[tokio::test]
async fn test_handle_cleaned_up() {
let system = System::new();
let counter = Arc::new(AtomicUsize::new(0));
let handles_empty_after = Arc::new(AtomicBool::new(false));
let component = OneMessageComponent::new(10, counter.clone(), handles_empty_after.clone());
let _handle = system.start_component(component);
// Wait for the 100ms schedule to trigger
tokio::time::sleep(Duration::from_millis(500)).await;
// yield to allow the component to process the messages
tokio::task::yield_now().await;
assert!(handles_empty_after.load(Ordering::SeqCst));
// We should have scheduled the message once
system.join().await;
}
Comment on lines +337 to +350
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[CriticalError]

There's a potential race condition in this test. The test verifies that the handles map is empty by checking a flag set from within the OneMessageComponent::handle method.

However, the handle method runs in the component's task, while the HandleGuard's drop implementation (which cleans up the map) runs at the end of the scheduler's spawned task. These are two different tasks, and their execution order isn't guaranteed. It's possible for handle to be called and check is_empty() before the scheduler task has finished and removed its handle from the map. This could lead to a flaky test.

A more robust approach would be to check the state of the scheduler's handles map directly from the test function. This provides a deterministic way to verify the state before and after the scheduled task has executed.

With the handles field made pub(crate) (as suggested in another comment), you could rewrite the test like this:

    #[tokio::test]
    async fn test_handle_cleaned_up() {
        let system = System::new();
        let counter = Arc::new(AtomicUsize::new(0));
        // The OneMessageComponent can be simplified to not need `handles_empty_after`
        let component = OneMessageComponent::new(10, counter.clone());
        let handle = system.start_component(component);

        // Allow on_start to run and schedule the task.
        tokio::task::yield_now().await;
        assert_eq!(handle.ctx.scheduler.handles.read().len(), 1, "Handle should be present after scheduling");

        // Wait for the schedule to trigger and the task to be cleaned up.
        tokio::time::sleep(Duration::from_millis(500)).await;
        
        assert!(handle.ctx.scheduler.handles.read().is_empty(), "Handles map should be empty after task completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1, "Message should have been handled once");
        
        system.join().await;
    }

This would also allow simplifying OneMessageComponent by removing the handles_empty_after field and its related logic.

Context for Agents
[**CriticalError**]

There's a potential race condition in this test. The test verifies that the `handles` map is empty by checking a flag set from within the `OneMessageComponent::handle` method.

However, the `handle` method runs in the component's task, while the `HandleGuard`'s `drop` implementation (which cleans up the map) runs at the end of the scheduler's spawned task. These are two different tasks, and their execution order isn't guaranteed. It's possible for `handle` to be called and check `is_empty()` *before* the scheduler task has finished and removed its handle from the map. This could lead to a flaky test.

A more robust approach would be to check the state of the scheduler's `handles` map directly from the test function. This provides a deterministic way to verify the state before and after the scheduled task has executed.

With the `handles` field made `pub(crate)` (as suggested in another comment), you could rewrite the test like this:

```rust
    #[tokio::test]
    async fn test_handle_cleaned_up() {
        let system = System::new();
        let counter = Arc::new(AtomicUsize::new(0));
        // The OneMessageComponent can be simplified to not need `handles_empty_after`
        let component = OneMessageComponent::new(10, counter.clone());
        let handle = system.start_component(component);

        // Allow on_start to run and schedule the task.
        tokio::task::yield_now().await;
        assert_eq!(handle.ctx.scheduler.handles.read().len(), 1, "Handle should be present after scheduling");

        // Wait for the schedule to trigger and the task to be cleaned up.
        tokio::time::sleep(Duration::from_millis(500)).await;
        
        assert!(handle.ctx.scheduler.handles.read().is_empty(), "Handles map should be empty after task completion");
        assert_eq!(counter.load(Ordering::SeqCst), 1, "Message should have been handled once");
        
        system.join().await;
    }
```

This would also allow simplifying `OneMessageComponent` by removing the `handles_empty_after` field and its related logic.

File: rust/system/src/scheduler.rs
Line: 350

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it to you to decide if this is worth the lift.

}
Loading