Skip to content

Commit a885727

Browse files
scheduler: Fix bug in scheduling safety task
Fixed bugs related to handling and scheduling safety task. #18 #20 #39
1 parent 69fa91b commit a885727

File tree

15 files changed

+340
-67
lines changed

15 files changed

+340
-67
lines changed

src/kyron/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,15 @@ rust_binary(
100100
visibility = ["//visibility:public"],
101101
deps = _EXAMPLE_DEPS,
102102
)
103+
104+
rust_binary(
105+
name = "safety_task",
106+
srcs = [
107+
"examples/safety_task.rs",
108+
],
109+
proc_macro_deps = [
110+
"//src/kyron-macros:runtime_macros",
111+
],
112+
visibility = ["//visibility:public"],
113+
deps = _EXAMPLE_DEPS,
114+
)

src/kyron/examples/safety_task.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//
2+
// Copyright (c) 2025 Contributors to the Eclipse Foundation
3+
//
4+
// See the NOTICE file(s) distributed with this work for additional
5+
// information regarding copyright ownership.
6+
//
7+
// This program and the accompanying materials are made available under the
8+
// terms of the Apache License Version 2.0 which is available at
9+
// <https://www.apache.org/licenses/LICENSE-2.0>
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
14+
use kyron::prelude::*;
15+
use kyron::safety;
16+
use kyron::spawn_on_dedicated;
17+
use kyron_foundation::prelude::*;
18+
19+
async fn failing_safety_task() -> Result<(), String> {
20+
info!("Worker-N: failing_safety_task");
21+
Err("Intentional failure".to_string())
22+
}
23+
24+
async fn passing_safety_task() -> Result<(), String> {
25+
info!("Worker-N: passing_safety_task");
26+
Ok(())
27+
}
28+
29+
async fn passing_non_safety_task() -> Result<(), String> {
30+
info!("Dedicated worker (dw1): passing_non_safety_task");
31+
Ok(())
32+
}
33+
34+
fn main() {
35+
tracing_subscriber::fmt()
36+
.with_target(false) // Optional: Remove module path
37+
.with_max_level(Level::DEBUG)
38+
.with_thread_ids(true)
39+
.with_thread_names(true)
40+
.init();
41+
42+
// Create runtime
43+
let (builder, _engine_id) = kyron::runtime::RuntimeBuilder::new().with_engine(
44+
ExecutionEngineBuilder::new()
45+
.task_queue_size(256)
46+
.enable_safety_worker(ThreadParameters::default())
47+
.with_dedicated_worker("dw1".into(), ThreadParameters::default())
48+
.workers(2),
49+
);
50+
51+
let mut runtime = builder.build().unwrap();
52+
// Put programs into runtime and run them
53+
runtime.block_on(async move {
54+
let handle1 = safety::spawn(failing_safety_task());
55+
let handle2 = safety::spawn(passing_safety_task());
56+
let handle3 = spawn_on_dedicated(passing_non_safety_task(), "dw1".into());
57+
58+
info!("=============================== Spawned all tasks ===============================");
59+
60+
let _ = handle1.await;
61+
info!("Safety worker: Since safety task fails, safety worker executes parent task from this statement onwards.");
62+
let _ = handle2.await;
63+
let _ = handle3.await;
64+
65+
info!("Safety worker: Program finished running.");
66+
});
67+
68+
info!("Exit.");
69+
}

src/kyron/src/scheduler/context.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ pub(crate) struct WorkerContext {
318318
/// Helper flag to check if safety was enabled in runtime builder
319319
is_safety_enabled: bool,
320320

321+
/// Flag used to schedule the parent task of a failing safety task onto the safety worker
322+
schedule_safety: Cell<bool>,
323+
321324
wakeup_time: Cell<Option<u64>>,
322325
}
323326

@@ -399,6 +402,7 @@ impl ContextBuilder {
399402
worker_id: Cell::new(self.worker_id.expect("Worker type must be set in context builder!")),
400403
handler: RefCell::new(Some(Rc::new(self.handle.expect("Handler type must be set in context builder!")))),
401404
is_safety_enabled: self.is_with_safety,
405+
schedule_safety: Cell::new(false),
402406
wakeup_time: Cell::new(None),
403407
drivers: Some(self.drivers),
404408
}
@@ -444,6 +448,24 @@ pub(crate) fn ctx_get_worker_id() -> WorkerId {
444448
})
445449
}
446450

451+
///
452+
/// Set schedule safety flag
453+
///
454+
#[allow(dead_code)] // To avoid error when runtime mocking feature is enabled
455+
pub(crate) fn ctx_set_schedule_safety(val: bool) {
456+
CTX.try_with(|ctx| ctx.borrow().as_ref().expect("Called before CTX init?").schedule_safety.set(val))
457+
.unwrap_or_default();
458+
}
459+
460+
///
461+
/// Get the schedule safety flag and clear it
462+
///
463+
#[allow(dead_code)]
464+
pub(crate) fn ctx_get_schedule_safety() -> bool {
465+
CTX.try_with(|ctx| ctx.borrow().as_ref().expect("Called before CTX init?").schedule_safety.replace(false))
466+
.unwrap_or_default()
467+
}
468+
447469
///
448470
/// Check if safety was enabled
449471
///

src/kyron/src/scheduler/join_handle.rs

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
use kyron_foundation::prelude::*;
1414
use kyron_foundation::{not_recoverable_error, prelude::CommonErrors};
1515

16+
use crate::scheduler::task::task_context::TaskContext;
1617
use crate::{
1718
futures::{FutureInternalReturn, FutureState},
1819
TaskRef,
@@ -74,34 +75,52 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
7475
if was_set {
7576
FutureInternalReturn::default()
7677
} else {
78+
// Check whether the completed task has a safety error and this task is running on an async worker.
79+
// If this task is already running on the safety worker or a dedicated worker, do not set the flag to schedule it on the safety worker.
80+
if self.for_task.get_task_safety_error() && TaskContext::is_task_running_on_async_worker() {
81+
// Set the flag to wake this task into safety worker
82+
TaskContext::set_flag_to_wake_parent_task_into_safety();
83+
waker.wake_by_ref();
84+
FutureInternalReturn::polled()
85+
} else {
86+
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
87+
let ret_as_ptr = &mut ret as *mut _;
88+
self.for_task.get_return_val(ret_as_ptr as *mut u8);
89+
90+
match ret {
91+
Ok(v) => FutureInternalReturn::ready(Ok(v)),
92+
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
93+
Err(e) => {
94+
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
95+
}
96+
}
97+
}
98+
}
99+
}
100+
FutureState::Polled => {
101+
let waker = cx.waker();
102+
103+
// Set the waker; the return value tells what happened and takes care of correct synchronization
104+
let was_set = self.for_task.set_join_handle_waker(waker.clone());
105+
106+
if was_set {
107+
FutureInternalReturn::default()
108+
} else {
109+
// Safety: the call below forms an AcqRel ordering, so the waker is really written before we do the marking
77110
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
78111
let ret_as_ptr = &mut ret as *mut _;
79112
self.for_task.get_return_val(ret_as_ptr as *mut u8);
80113

81114
match ret {
82115
Ok(v) => FutureInternalReturn::ready(Ok(v)),
116+
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
83117
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
84118
Err(e) => {
85119
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
86120
}
87121
}
88122
}
89123
}
90-
FutureState::Polled => {
91-
// Safety belows forms AqrRel so waker is really written before we do marking
92-
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
93-
let ret_as_ptr = &mut ret as *mut _;
94-
self.for_task.get_return_val(ret_as_ptr as *mut u8);
95-
96-
match ret {
97-
Ok(v) => FutureInternalReturn::ready(Ok(v)),
98-
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
99-
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
100-
Err(e) => {
101-
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
102-
}
103-
}
104-
}
105124
FutureState::Finished => {
106125
not_recoverable_error!("Future polled after it finished!");
107126
}
@@ -256,6 +275,40 @@ mod tests {
256275
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
257276
}
258277
}
278+
279+
#[test]
280+
fn test_join_handle_waker_is_set_in_polled_state_also() {
281+
let scheduler = create_mock_scheduler();
282+
283+
{
284+
// Data is present before first poll of join handle
285+
let task = ArcInternal::new(AsyncTask::new(box_future(test_function::<u32>()), 1, scheduler.clone()));
286+
287+
let handle = JoinHandle::<u32>::new(TaskRef::new(task.clone()));
288+
289+
let mut poller = TestingFuturePoller::new(handle);
290+
291+
let waker_mock1 = TrackableWaker::new();
292+
let waker1 = waker_mock1.get_waker();
293+
294+
let waker_mock2 = TrackableWaker::new();
295+
let waker2 = waker_mock2.get_waker();
296+
297+
let _ = poller.poll_with_waker(&waker1);
298+
// Now in polled state, poll again with waker2
299+
let _ = poller.poll_with_waker(&waker2);
300+
{
301+
let waker = noop_waker();
302+
let mut cx = Context::from_waker(&waker);
303+
task.poll(&mut cx); // task done
304+
}
305+
306+
assert!(!waker_mock1.was_waked());
307+
// this should be TRUE
308+
assert!(waker_mock2.was_waked());
309+
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
310+
}
311+
}
259312
}
260313

261314
#[cfg(test)]
@@ -277,8 +330,9 @@ mod tests {
277330

278331
#[test]
279332
fn test_join_handler_mt_get_result() {
280-
let builder = Builder::new();
281-
333+
let mut builder = Builder::new();
334+
// Limit preemption to avoid loom error "Model exceeded maximum number of branches."
335+
builder.preemption_bound = Some(4);
282336
builder.check(|| {
283337
let scheduler = create_mock_scheduler();
284338

@@ -299,22 +353,17 @@ mod tests {
299353

300354
let waker_mock = TrackableWaker::new();
301355
let waker = waker_mock.get_waker();
302-
let mut was_pending = false;
303-
304356
loop {
305357
match poller.poll_with_waker(&waker) {
306358
Poll::Ready(v) => {
307359
assert_eq!(v, Ok(1234));
308-
309-
if was_pending {
310-
assert!(waker_mock.was_waked());
311-
}
360+
// Note:
361+
// We cannot check whether the waker was woken, because the join handle's poll sets the waker on every poll while the task is not yet done.
362+
// So depending on the interleaving, the task may finish before the waker is set.
312363

313364
break;
314365
}
315-
Poll::Pending => {
316-
was_pending = true;
317-
}
366+
Poll::Pending => {}
318367
}
319368
loom::hint::spin_loop();
320369
}

src/kyron/src/scheduler/safety_waker.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313

1414
use super::task::async_task::*;
15+
use crate::scheduler::task::task_context::TaskContext;
1516
use core::task::{RawWaker, RawWakerVTable, Waker};
1617

1718
fn clone_waker(data: *const ()) -> RawWaker {
@@ -30,13 +31,18 @@ fn wake(data: *const ()) {
3031
let task_header_ptr = data as *const TaskHeader;
3132
let task_ref = unsafe { TaskRef::from_raw(task_header_ptr) };
3233

34+
// Just clear the flag which might have been set by async worker before calling wake/wake_by_ref
35+
// for the scenario where the join handle poll is executed by safety worker and waker is set
36+
TaskContext::clear_schedule_safety_flag();
3337
task_ref.schedule_safety();
3438
}
3539

3640
fn wake_by_ref(data: *const ()) {
3741
let task_header_ptr = data as *const TaskHeader;
3842
let task_ref = unsafe { TaskRef::from_raw(task_header_ptr) };
3943

44+
// Just clear the flag which might have been set by async worker before calling wake/wake_by_ref
45+
TaskContext::clear_schedule_safety_flag();
4046
task_ref.schedule_safety_by_ref();
4147

4248
::core::mem::forget(task_ref); // don't touch refcount from our data since this is done by drop_waker
@@ -55,11 +61,9 @@ static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, wake, wake_by_r
5561
///
5662
/// Waker will store internally a pointer to the ref counted Task.
5763
///
58-
pub(crate) unsafe fn create_safety_waker(waker: Waker) -> Waker {
59-
let raw_waker = RawWaker::new(waker.data(), &VTABLE);
60-
61-
// Forget original as we took over the ownership, so ref count
62-
::core::mem::forget(waker);
64+
pub(crate) fn create_safety_waker(ptr: TaskRef) -> Waker {
65+
let ptr = TaskRef::into_raw(ptr); // Extracts the pointer from TaskRef not decreasing it's reference count. Since we have a clone here, ref cnt was already increased
66+
let raw_waker = RawWaker::new(ptr as *const (), &VTABLE);
6367

6468
// Convert RawWaker to Waker
6569
unsafe { Waker::from_raw(raw_waker) }

0 commit comments

Comments
 (0)