Skip to content

Commit 196777d

Browse files
scheduler: Fix bug in scheduling safety task
Fixed bugs related to handling and scheduling safety task. #18 #20 #39
1 parent d146401 commit 196777d

File tree

15 files changed

+246
-78
lines changed

15 files changed

+246
-78
lines changed

src/kyron/src/safety.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ pub fn ensure_safety_enabled() {
4646
///
4747
/// # Safety
4848
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
49-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
49+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
50+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
5051
///
5152
pub fn spawn<F, T, E>(future: F) -> JoinHandle<F::Output>
5253
where
@@ -64,7 +65,8 @@ where
6465
///
6566
/// # Safety
6667
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
67-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
68+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
69+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
6870
///
6971
pub fn spawn_from_boxed<T, E>(boxed: FutureBox<SafetyResult<T, E>>) -> JoinHandle<SafetyResult<T, E>>
7072
where
@@ -88,7 +90,8 @@ where
8890
///
8991
/// # Safety
9092
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
91-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
93+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
94+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
9295
///
9396
pub fn spawn_from_reusable<T, E>(reusable: ReusableBoxFuture<SafetyResult<T, E>>) -> JoinHandle<SafetyResult<T, E>>
9497
where
@@ -113,7 +116,8 @@ where
113116
///
114117
/// # Safety
115118
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
116-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
119+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
120+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
117121
///
118122
pub fn spawn_on_dedicated<F, T, E>(future: F, worker_id: UniqueWorkerId) -> JoinHandle<F::Output>
119123
where
@@ -131,7 +135,8 @@ where
131135
///
132136
/// # Safety
133137
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
134-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
138+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
139+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
135140
///
136141
pub fn spawn_from_boxed_on_dedicated<T, E>(
137142
boxed: FutureBox<SafetyResult<T, E>>,
@@ -158,7 +163,8 @@ where
158163
///
159164
/// # Safety
160165
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
161-
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
166+
/// This means that if the `task` (aka provided Future) returns Err(_), then the task that is awaiting on JoinHandle will be woken up in either the `SafetyWorker` or a regular worker.
167+
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
162168
///
163169
pub fn spawn_from_reusable_on_dedicated<T, E>(
164170
reusable: ReusableBoxFuture<SafetyResult<T, E>>,

src/kyron/src/scheduler/context.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ pub(crate) fn ctx_get_drivers() -> Drivers {
542542
.unwrap()
543543
}
544544

545+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
545546
///
546547
/// Sets currently running `task`
547548
///
@@ -559,6 +560,7 @@ pub(super) fn ctx_set_running_task(task: TaskRef) {
559560
});
560561
}
561562

563+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
562564
///
563565
/// Clears currently running `task`
564566
///
@@ -574,6 +576,7 @@ pub(super) fn ctx_unset_running_task() {
574576
.map_err(|_| {});
575577
}
576578

579+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
577580
///
578581
/// Gets currently running `task id`
579582
///
@@ -592,6 +595,27 @@ pub(crate) fn ctx_get_running_task_id() -> Option<TaskId> {
592595
})
593596
}
594597

598+
#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
599+
///
600+
/// Returns `true` if the running task resulted in safety error
601+
///
602+
pub(crate) fn ctx_get_task_safety_error() -> bool {
603+
CTX.try_with(|ctx| {
604+
// This function can be called from a thread outside of the Kyron runtime through wake()/wake_by_ref(), so we need to check for ctx presence
605+
if let Some(cx) = ctx.borrow().as_ref() {
606+
cx.running_task
607+
.borrow()
608+
.as_ref()
609+
.is_some_and(|task| task.get_task_safety_error())
610+
} else {
611+
false
612+
}
613+
})
614+
.unwrap_or_else(|e| {
615+
panic!("Something is really bad here, error {}!", e);
616+
})
617+
}
618+
595619
#[cfg(test)]
596620
mod tests {
597621
use super::*;

src/kyron/src/scheduler/execution_engine.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,8 @@ mod tests {
541541
#[test]
542542
#[cfg(not(miri))] // Provenance issues
543543
fn create_engine_with_worker_and_verify_ids() {
544-
use crate::scheduler::context::{ctx_get_running_task_id, ctx_get_worker_id};
544+
use crate::scheduler::context::ctx_get_worker_id;
545+
use crate::testing::mock_context::ctx_get_running_task_id;
545546
let mut engine = ExecutionEngineBuilder::new()
546547
.workers(1)
547548
.task_queue_size(16)

src/kyron/src/scheduler/join_handle.rs

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
6565
///
6666
fn poll(self: ::core::pin::Pin<&mut Self>, cx: &mut ::core::task::Context<'_>) -> Poll<Self::Output> {
6767
let res: FutureInternalReturn<JoinResult<T>> = match self.state {
68-
FutureState::New => {
68+
FutureState::New | FutureState::Polled => {
6969
let waker = cx.waker();
7070

7171
// Set the waker; the return value tells what has happened and takes care of correct synchronization
@@ -80,6 +80,7 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
8080

8181
match ret {
8282
Ok(v) => FutureInternalReturn::ready(Ok(v)),
83+
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
8384
Err(CommonErrors::OperationAborted) => {
8485
FutureInternalReturn::ready(Err(CommonErrors::OperationAborted))
8586
},
@@ -89,23 +90,6 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
8990
}
9091
}
9192
},
92-
FutureState::Polled => {
93-
// Safety belows forms AqrRel so waker is really written before we do marking
94-
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
95-
let ret_as_ptr = &mut ret as *mut _;
96-
self.for_task.get_return_val(ret_as_ptr as *mut u8);
97-
98-
match ret {
99-
Ok(v) => FutureInternalReturn::ready(Ok(v)),
100-
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
101-
Err(CommonErrors::OperationAborted) => {
102-
FutureInternalReturn::ready(Err(CommonErrors::OperationAborted))
103-
},
104-
Err(e) => {
105-
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
106-
},
107-
}
108-
},
10993
FutureState::Finished => {
11094
not_recoverable_error!("Future polled after it finished!");
11195
},
@@ -293,6 +277,45 @@ mod tests {
293277
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
294278
}
295279
}
280+
281+
#[test]
282+
fn test_join_handle_waker_is_set_in_polled_state_also() {
283+
let scheduler = create_mock_scheduler();
284+
285+
{
286+
// Data is present before first poll of join handle
287+
let worker_id = create_mock_worker_id(0, 1);
288+
let task = ArcInternal::new(AsyncTask::new(
289+
box_future(test_function::<u32>()),
290+
&worker_id,
291+
scheduler.clone(),
292+
));
293+
294+
let handle = JoinHandle::<u32>::new(TaskRef::new(task.clone()));
295+
296+
let mut poller = TestingFuturePoller::new(handle);
297+
298+
let waker_mock1 = TrackableWaker::new();
299+
let waker1 = waker_mock1.get_waker();
300+
301+
let waker_mock2 = TrackableWaker::new();
302+
let waker2 = waker_mock2.get_waker();
303+
304+
let _ = poller.poll_with_waker(&waker1);
305+
// Now in polled state, poll again with waker2
306+
let _ = poller.poll_with_waker(&waker2);
307+
{
308+
let waker = noop_waker();
309+
let mut cx = Context::from_waker(&waker);
310+
task.poll(&mut cx); // task done
311+
}
312+
313+
assert!(!waker_mock1.was_waked());
314+
// this should be TRUE
315+
assert!(waker_mock2.was_waked());
316+
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
317+
}
318+
}
296319
}
297320

298321
#[cfg(test)]
@@ -315,8 +338,9 @@ mod tests {
315338

316339
#[test]
317340
fn test_join_handler_mt_get_result() {
318-
let builder = Builder::new();
319-
341+
let mut builder = Builder::new();
342+
// Limit preemption to avoid loom error "Model exceeded maximum number of branches."
343+
builder.preemption_bound = Some(4);
320344
builder.check(|| {
321345
let scheduler = create_mock_scheduler();
322346

@@ -342,22 +366,17 @@ mod tests {
342366

343367
let waker_mock = TrackableWaker::new();
344368
let waker = waker_mock.get_waker();
345-
let mut was_pending = false;
346-
347369
loop {
348370
match poller.poll_with_waker(&waker) {
349371
Poll::Ready(v) => {
350372
assert_eq!(v, Ok(1234));
351-
352-
if was_pending {
353-
assert!(waker_mock.was_waked());
354-
}
373+
// Note:
374+
// Cannot check whether the waker was woken or not since the waker is set in the join handle poll every time if task is not yet done.
375+
// So depending on the interleaving, the task may finish before the waker is set.
355376

356377
break;
357378
},
358-
Poll::Pending => {
359-
was_pending = true;
360-
},
379+
Poll::Pending => {},
361380
}
362381
loom::hint::spin_loop();
363382
}

src/kyron/src/scheduler/safety_waker.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,9 @@ static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, wake, wake_by_r
5555
///
5656
/// Waker will store internally a pointer to the ref counted Task.
5757
///
58-
pub(crate) unsafe fn create_safety_waker(waker: Waker) -> Waker {
59-
let raw_waker = RawWaker::new(waker.data(), &VTABLE);
60-
61-
// Forget original as we took over the ownership, so ref count
62-
::core::mem::forget(waker);
58+
pub(crate) fn create_safety_waker(ptr: TaskRef) -> Waker {
59+
let ptr = TaskRef::into_raw(ptr); // Extracts the pointer from TaskRef without decreasing its reference count. Since we have a clone here, the ref count was already increased
60+
let raw_waker = RawWaker::new(ptr as *const (), &VTABLE);
6361

6462
// Convert RawWaker to Waker
6563
unsafe { Waker::from_raw(raw_waker) }

0 commit comments

Comments
 (0)