Skip to content

Commit 825fc4c

Browse files
scheduler: Fix bug in scheduling safety task
Fixed bugs related to handling and scheduling safety task. #18 #20 #39
1 parent c348284 commit 825fc4c

File tree

15 files changed

+233
-70
lines changed

15 files changed

+233
-70
lines changed

src/kyron/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,15 @@ rust_binary(
100100
visibility = ["//visibility:public"],
101101
deps = _EXAMPLE_DEPS,
102102
)
103+
104+
rust_binary(
105+
name = "safety_task",
106+
srcs = [
107+
"examples/safety_task.rs",
108+
],
109+
proc_macro_deps = [
110+
"//src/kyron-macros:runtime_macros",
111+
],
112+
visibility = ["//visibility:public"],
113+
deps = _EXAMPLE_DEPS,
114+
)

src/kyron/src/scheduler/context.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ pub(crate) fn ctx_get_drivers() -> Drivers {
481481
.unwrap()
482482
}
483483

484+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
484485
///
485486
/// Sets currently running `task`
486487
///
@@ -494,6 +495,7 @@ pub(super) fn ctx_set_running_task(task: TaskRef) {
494495
});
495496
}
496497

498+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
497499
///
498500
/// Clears currently running `task`
499501
///
@@ -505,6 +507,7 @@ pub(super) fn ctx_unset_running_task() {
505507
.map_err(|_| {});
506508
}
507509

510+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
508511
///
509512
/// Gets currently running `task id`
510513
///
@@ -523,6 +526,24 @@ pub(crate) fn ctx_get_running_task_id() -> Option<TaskId> {
523526
})
524527
}
525528

529+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
530+
///
531+
/// Returns `true` if the running task resulted in safety error
532+
///
533+
pub(crate) fn ctx_get_task_safety_error() -> bool {
534+
CTX.try_with(|ctx| {
535+
// This funcation can be called from a thread outside of Kyron runtime through wake()/wake_by_ref(), so we need to check for ctx presence
536+
if let Some(cx) = ctx.borrow().as_ref() {
537+
cx.running_task.borrow().as_ref().is_some_and(|task| task.get_task_safety_error())
538+
} else {
539+
false
540+
}
541+
})
542+
.unwrap_or_else(|e| {
543+
panic!("Something is really bad here, error {}!", e);
544+
})
545+
}
546+
526547
#[cfg(test)]
527548
mod tests {
528549
use super::*;

src/kyron/src/scheduler/execution_engine.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,8 @@ mod tests {
506506
#[test]
507507
#[cfg(not(miri))] // Provenance issues
508508
fn create_engine_with_worker_and_verify_ids() {
509-
use crate::scheduler::context::{ctx_get_running_task_id, ctx_get_worker_id};
509+
use crate::scheduler::context::ctx_get_worker_id;
510+
use crate::testing::mock_context::ctx_get_running_task_id;
510511
let mut engine = ExecutionEngineBuilder::new().workers(1).task_queue_size(16).set_engine_id(1).build();
511512
let result: Result<bool, ()> = engine
512513
.run_in_engine(async move {

src/kyron/src/scheduler/join_handle.rs

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
6565
///
6666
fn poll(self: ::core::pin::Pin<&mut Self>, cx: &mut ::core::task::Context<'_>) -> Poll<Self::Output> {
6767
let res: FutureInternalReturn<JoinResult<T>> = match self.state {
68-
FutureState::New => {
68+
FutureState::New | FutureState::Polled => {
6969
let waker = cx.waker();
7070

7171
// Set the waker, return values tells what have happen and took care about correct synchronization
@@ -80,28 +80,14 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
8080

8181
match ret {
8282
Ok(v) => FutureInternalReturn::ready(Ok(v)),
83+
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
8384
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
8485
Err(e) => {
8586
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
8687
}
8788
}
8889
}
8990
}
90-
FutureState::Polled => {
91-
// Safety belows forms AqrRel so waker is really written before we do marking
92-
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
93-
let ret_as_ptr = &mut ret as *mut _;
94-
self.for_task.get_return_val(ret_as_ptr as *mut u8);
95-
96-
match ret {
97-
Ok(v) => FutureInternalReturn::ready(Ok(v)),
98-
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
99-
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
100-
Err(e) => {
101-
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
102-
}
103-
}
104-
}
10591
FutureState::Finished => {
10692
not_recoverable_error!("Future polled after it finished!");
10793
}
@@ -262,6 +248,41 @@ mod tests {
262248
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
263249
}
264250
}
251+
252+
#[test]
253+
fn test_join_handle_waker_is_set_in_polled_state_also() {
254+
let scheduler = create_mock_scheduler();
255+
256+
{
257+
// Data is present before first poll of join handle
258+
let worker_id = create_mock_worker_id(0, 1);
259+
let task = ArcInternal::new(AsyncTask::new(box_future(test_function::<u32>()), &worker_id, scheduler.clone()));
260+
261+
let handle = JoinHandle::<u32>::new(TaskRef::new(task.clone()));
262+
263+
let mut poller = TestingFuturePoller::new(handle);
264+
265+
let waker_mock1 = TrackableWaker::new();
266+
let waker1 = waker_mock1.get_waker();
267+
268+
let waker_mock2 = TrackableWaker::new();
269+
let waker2 = waker_mock2.get_waker();
270+
271+
let _ = poller.poll_with_waker(&waker1);
272+
// Now in polled state, poll again with waker2
273+
let _ = poller.poll_with_waker(&waker2);
274+
{
275+
let waker = noop_waker();
276+
let mut cx = Context::from_waker(&waker);
277+
task.poll(&mut cx); // task done
278+
}
279+
280+
assert!(!waker_mock1.was_waked());
281+
// this should be TRUE
282+
assert!(waker_mock2.was_waked());
283+
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
284+
}
285+
}
265286
}
266287

267288
#[cfg(test)]
@@ -284,8 +305,9 @@ mod tests {
284305

285306
#[test]
286307
fn test_join_handler_mt_get_result() {
287-
let builder = Builder::new();
288-
308+
let mut builder = Builder::new();
309+
// Limit preemption to avoid loom error "Model exceeded maximum number of branches."
310+
builder.preemption_bound = Some(4);
289311
builder.check(|| {
290312
let scheduler = create_mock_scheduler();
291313

@@ -307,22 +329,17 @@ mod tests {
307329

308330
let waker_mock = TrackableWaker::new();
309331
let waker = waker_mock.get_waker();
310-
let mut was_pending = false;
311-
312332
loop {
313333
match poller.poll_with_waker(&waker) {
314334
Poll::Ready(v) => {
315335
assert_eq!(v, Ok(1234));
316-
317-
if was_pending {
318-
assert!(waker_mock.was_waked());
319-
}
336+
// Note:
337+
// Cannot check whether the waker was woken or not since the waker is set in the join handle poll every time if task is not yet done.
338+
// So depending on the interleaving, the task may finish before the waker is set.
320339

321340
break;
322341
}
323-
Poll::Pending => {
324-
was_pending = true;
325-
}
342+
Poll::Pending => {}
326343
}
327344
loom::hint::spin_loop();
328345
}

src/kyron/src/scheduler/safety_waker.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,9 @@ static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, wake, wake_by_r
5555
///
5656
/// Waker will store internally a pointer to the ref counted Task.
5757
///
58-
pub(crate) unsafe fn create_safety_waker(waker: Waker) -> Waker {
59-
let raw_waker = RawWaker::new(waker.data(), &VTABLE);
60-
61-
// Forget original as we took over the ownership, so ref count
62-
::core::mem::forget(waker);
58+
pub(crate) fn create_safety_waker(ptr: TaskRef) -> Waker {
59+
let ptr = TaskRef::into_raw(ptr); // Extracts the pointer from TaskRef not decreasing it's reference count. Since we have a clone here, ref cnt was already increased
60+
let raw_waker = RawWaker::new(ptr as *const (), &VTABLE);
6361

6462
// Convert RawWaker to Waker
6563
unsafe { Waker::from_raw(raw_waker) }

src/kyron/src/scheduler/task/async_task.rs

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
use super::task_state::*;
1515
use crate::core::types::*;
16-
use crate::scheduler::safety_waker::create_safety_waker;
1716
use crate::scheduler::scheduler_mt::SchedulerTrait;
1817
use crate::scheduler::workers::worker_types::WorkerId;
18+
use ::core::cell::Cell;
1919
use ::core::future::Future;
2020
use ::core::mem;
2121
use ::core::ops::{Deref, DerefMut};
@@ -82,7 +82,7 @@ pub(crate) enum TaskStage<T, ResultType> {
8282
pub(crate) struct TaskHeader {
8383
pub(in crate::scheduler) state: TaskState,
8484
id: TaskId,
85-
85+
is_safety_error: Cell<bool>, // Flag to indicate whether task resulted in safety error
8686
vtable: &'static TaskVTable, // API entrypoint to typed task
8787
}
8888

@@ -97,6 +97,7 @@ impl TaskHeader {
9797
Self {
9898
state: TaskState::new(),
9999
id: TaskId::new(worker_id),
100+
is_safety_error: Cell::new(false),
100101
vtable: create_task_vtable::<T, AllocatedFuture, SchedulerType>(),
101102
}
102103
}
@@ -111,9 +112,18 @@ impl TaskHeader {
111112
Self {
112113
state: TaskState::new(),
113114
id: TaskId::new(worker_id),
115+
is_safety_error: Cell::new(false),
114116
vtable: create_task_s_vtable::<T, E, AllocatedFuture, SchedulerType>(),
115117
}
116118
}
119+
120+
pub(crate) fn set_safety_error(&self) {
121+
self.is_safety_error.set(true);
122+
}
123+
124+
pub(crate) fn get_safety_error(&self) -> bool {
125+
self.is_safety_error.get()
126+
}
117127
}
118128

119129
#[derive(PartialEq, Debug)]
@@ -203,13 +213,20 @@ where
203213
}
204214

205215
pub(crate) fn set_waker(&self, waker: Waker) -> bool {
206-
unsafe {
207-
self.handle_waker.with_mut(|ptr| {
208-
*ptr = Some(waker);
209-
})
216+
// Safety: Unset join handle flag before setting waker, the flag would have been set previously in the first poll of join handle.
217+
// If flag is not cleared, another worker finishing the task will see the flag set and
218+
// read the waker to call wake() while it is written here
219+
if self.header.state.unset_join_handle() {
220+
unsafe {
221+
self.handle_waker.with_mut(|ptr| {
222+
*ptr = Some(waker);
223+
})
224+
};
225+
226+
return self.header.state.set_join_handle();
210227
}
211228

212-
self.header.state.set_waker() // Safety: makes sure storing waker is not reordered behind this operation
229+
false
213230
}
214231

215232
///
@@ -308,12 +325,10 @@ where
308325
self.handle_waker.with_mut(|ptr: *mut Option<Waker>| match unsafe { (*ptr).take() } {
309326
Some(v) => {
310327
if is_safety_err && self.is_with_safety {
311-
unsafe {
312-
create_safety_waker(v).wake();
313-
}
314-
} else {
315-
v.wake();
328+
// Set saftey error flag which will be checked in wake()/wkae_by_ref() to schedule parent task into safety worker
329+
self.header.set_safety_error();
316330
}
331+
v.wake();
317332
}
318333
None => not_recoverable_error!("We shall never be here if we have HadConnectedJoinHandle set!"),
319334
})
@@ -614,6 +629,10 @@ impl TaskRef {
614629
snapshot.is_completed() || snapshot.is_canceled()
615630
}
616631

632+
pub(crate) fn get_task_safety_error(&self) -> bool {
633+
unsafe { self.header.as_ref().get_safety_error() }
634+
}
635+
617636
pub(crate) fn id(&self) -> TaskId {
618637
unsafe { self.header.as_ref().id }
619638
}
@@ -641,7 +660,10 @@ mod tests {
641660
safety::SafetyResult,
642661
scheduler::{
643662
scheduler_mt::SchedulerTrait,
644-
task::async_task::{TaskId, TaskRef},
663+
task::{
664+
async_task::{TaskId, TaskRef},
665+
task_context::TaskContextGuard,
666+
},
645667
},
646668
testing::*,
647669
};
@@ -778,6 +800,7 @@ mod tests {
778800
safety_task_ref.set_join_handle_waker(waker.clone()); // Mimic that JoinHandler is set
779801

780802
let mut ctx = Context::from_waker(&waker);
803+
let _guard = TaskContextGuard::new(safety_task_ref.clone());
781804
assert_eq!(safety_task_ref.poll(&mut ctx), TaskPollResult::Done);
782805

783806
let mut result: SafetyResult<bool, bool> = Ok(true);
@@ -801,6 +824,7 @@ mod tests {
801824
safety_task_ref.set_join_handle_waker(waker.clone()); // Mimic that JoinHandler is set
802825

803826
let mut ctx = Context::from_waker(&waker);
827+
let _guard = TaskContextGuard::new(safety_task_ref.clone());
804828
assert_eq!(safety_task_ref.poll(&mut ctx), TaskPollResult::Done);
805829

806830
let mut result: SafetyResult<bool, bool> = Ok(true);
@@ -829,6 +853,7 @@ mod tests {
829853
safety_task_ref.set_join_handle_waker(waker.clone()); // Mimic that JoinHandler is set
830854

831855
let mut ctx = Context::from_waker(&waker);
856+
let _guard = TaskContextGuard::new(safety_task_ref.clone());
832857
assert_eq!(safety_task_ref.poll(&mut ctx), TaskPollResult::Done);
833858

834859
let mut result: SafetyResult<bool, bool> = Ok(true);
@@ -857,6 +882,7 @@ mod tests {
857882
safety_task_ref.set_join_handle_waker(waker.clone()); // Mimic that JoinHandler is set
858883

859884
let mut ctx = Context::from_waker(&waker);
885+
let _guard = TaskContextGuard::new(safety_task_ref.clone());
860886
assert_eq!(safety_task_ref.poll(&mut ctx), TaskPollResult::Done);
861887

862888
let mut result: SafetyResult<bool, bool> = Ok(true);

src/kyron/src/scheduler/task/task_context.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
// SPDX-License-Identifier: Apache-2.0
1212
//
1313

14+
#[cfg(not(any(test, feature = "runtime-api-mock")))]
15+
use crate::scheduler::context::{ctx_get_running_task_id, ctx_get_task_safety_error, ctx_set_running_task, ctx_unset_running_task};
16+
#[cfg(any(test, feature = "runtime-api-mock"))]
17+
use crate::testing::mock_context::{ctx_get_running_task_id, ctx_get_task_safety_error, ctx_set_running_task, ctx_unset_running_task};
1418
use crate::{
1519
core::types::TaskId,
16-
scheduler::{
17-
context::{ctx_get_running_task_id, ctx_get_worker_id, ctx_set_running_task, ctx_unset_running_task},
18-
workers::worker_types::WorkerId,
19-
},
20+
scheduler::{context::ctx_get_worker_id, workers::worker_types::WorkerId},
2021
TaskRef,
2122
};
2223

@@ -33,6 +34,11 @@ impl TaskContext {
3334
pub fn task_id() -> Option<TaskId> {
3435
ctx_get_running_task_id()
3536
}
37+
38+
/// Check whether the running task resulted in safety error to schedule parent into safety worker
39+
pub(crate) fn should_wake_task_into_safety() -> bool {
40+
ctx_get_task_safety_error()
41+
}
3642
}
3743

3844
/// A guard that sets the task on creation and unsets it on drop.

0 commit comments

Comments
 (0)