Skip to content

Commit 4158a64

Browse files
committed
Add BatchSemaphore::{held_permits, is_queued}
1 parent d00bba1 commit 4158a64

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ members = [
55
"wrappers/shuttle_sync",
66
]
77

8-
resolver = "2"
8+
resolver = "2"

shuttle/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ scoped-tls = "1.0.0"
2121
smallvec = { version = "1.11.2", features = ["const_new"] }
2222
tracing = { version = "0.1.36", default-features = false, features = ["std"] }
2323
corosensei = "0.3.1"
24+
indexmap = { version = "2.9", features = ["std"] }
25+
ahash = "0.8"
2426

2527
# for annotation only
2628
regex = { version = "1.10.6", optional = true }

shuttle/src/future/batch_semaphore.rs

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! A counting semaphore supporting both async and sync operations.
2-
use crate::current;
2+
use crate::current::{self, get_current_task};
33
use crate::runtime::execution::ExecutionState;
44
use crate::runtime::task::{clock::VectorClock, TaskId};
55
use crate::runtime::thread;
66
use crate::sync::{ResourceSignature, ResourceType};
7+
use ahash::random_state::RandomState;
8+
use indexmap::IndexMap;
79
use std::cell::RefCell;
810
use std::collections::VecDeque;
911
use std::fmt;
@@ -13,7 +15,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
1315
use std::sync::Arc;
1416
use std::sync::Mutex;
1517
use std::task::{Context, Poll, Waker};
16-
use tracing::trace;
18+
use tracing::{error, trace};
1719

1820
struct Waiter {
1921
task_id: TaskId,
@@ -210,10 +212,17 @@ struct BatchSemaphoreState {
210212
permits_available: PermitsAvailable,
211213
// TODO: should there be a clock for the close event?
212214
closed: bool,
215+
// tracks which tasks hold how many permits.
216+
holders: IndexMap<TaskId, usize, RandomState>,
213217
}
214218

215219
impl BatchSemaphoreState {
216-
fn acquire_permits(&mut self, num_permits: usize, fairness: Fairness) -> Result<(), TryAcquireError> {
220+
fn acquire_permits(
221+
&mut self,
222+
num_permits: usize,
223+
fairness: Fairness,
224+
task_id: TaskId,
225+
) -> Result<(), TryAcquireError> {
217226
assert!(num_permits > 0);
218227
if self.closed {
219228
Err(TryAcquireError::Closed)
@@ -234,6 +243,11 @@ impl BatchSemaphoreState {
234243
s.update_clock(&clock);
235244
});
236245

246+
self.holders
247+
.entry(task_id)
248+
.and_modify(|permits| *permits += num_permits)
249+
.or_insert(num_permits);
250+
237251
Ok(())
238252
} else {
239253
Err(TryAcquireError::NoPermits)
@@ -341,6 +355,7 @@ impl BatchSemaphore {
341355
waiters: VecDeque::new(),
342356
permits_available: PermitsAvailable::new(num_permits),
343357
closed: false,
358+
holders: IndexMap::with_hasher(RandomState::with_seeds(0, 0, 0, 0)),
344359
});
345360
Self {
346361
state,
@@ -369,6 +384,7 @@ impl BatchSemaphore {
369384
waiters: VecDeque::new(),
370385
permits_available: PermitsAvailable::const_new(num_permits),
371386
closed: false,
387+
holders: IndexMap::with_hasher(RandomState::with_seeds(0, 0, 0, 0)),
372388
});
373389
Self {
374390
state,
@@ -446,24 +462,30 @@ impl BatchSemaphore {
446462
self.init_object_id();
447463
let mut state = self.state.borrow_mut();
448464
let id = state.id.unwrap();
449-
let res = state.acquire_permits(num_permits, self.fairness).inspect_err(|_err| {
450-
// Conservatively, the requester causally depends on the
451-
// last successful acquire.
452-
// TODO: This is not precise, but `try_acquire` causal dependency
453-
// TODO: is both hard to define, and is most likely not worth the
454-
// TODO: effort. The cases where causality would be tracked
455-
// TODO: "imprecisely" do not correspond to commonly used sync.
456-
// TODO: primitives, such as mutexes, mutexes, or condvars.
457-
// TODO: An example would be a counting semaphore used to guard
458-
// TODO: access to N homogenous resources (as opposed to FIFO,
459-
// TODO: heterogenous resources).
460-
// TODO: More precision could be gained by tracking clocks for all
461-
// TODO: current permit holders, with a data structure similar to
462-
// TODO: `permits_available`.
463-
ExecutionState::with(|s| {
464-
s.update_clock(&state.permits_available.last_acquire);
465-
});
465+
let task_id = get_current_task().unwrap_or_else(|| {
466+
error!("Tried to acquire a semaphore while there is no current task. Panicking");
467+
panic!("Tried to acquire a semaphore while there is no current task.");
466468
});
469+
let res = state
470+
.acquire_permits(num_permits, self.fairness, task_id)
471+
.inspect_err(|_err| {
472+
// Conservatively, the requester causally depends on the
473+
// last successful acquire.
474+
// TODO: This is not precise, but `try_acquire` causal dependency
475+
// TODO: is both hard to define, and is most likely not worth the
476+
// TODO: effort. The cases where causality would be tracked
477+
// TODO: "imprecisely" do not correspond to commonly used sync.
478+
// TODO: primitives, such as mutexes, mutexes, or condvars.
479+
// TODO: An example would be a counting semaphore used to guard
480+
// TODO: access to N homogenous resources (as opposed to FIFO,
481+
// TODO: heterogenous resources).
482+
// TODO: More precision could be gained by tracking clocks for all
483+
// TODO: current permit holders, with a data structure similar to
484+
// TODO: `permits_available`.
485+
ExecutionState::with(|s| {
486+
s.update_clock(&state.permits_available.last_acquire);
487+
});
488+
});
467489
drop(state);
468490

469491
// If we won the race for permits of an unfair semaphore, re-block
@@ -622,6 +644,23 @@ impl BatchSemaphore {
622644
}
623645
drop(state);
624646
}
647+
648+
/// Returns the number of permits currently held by the given `task_id`
649+
pub fn held_permits(&self, task_id: &TaskId) -> usize {
650+
let state = self.state.borrow_mut();
651+
*state.holders.get(task_id).unwrap_or(&0)
652+
}
653+
654+
/// Returns `true` iff the given `task_id` is currently waiting to acquire permits from the `BatchSemaphore`
655+
pub fn is_queued(&self, task_id: &TaskId) -> bool {
656+
let state = self.state.borrow_mut();
657+
for waiter in &state.waiters {
658+
if waiter.task_id == *task_id {
659+
return true;
660+
}
661+
}
662+
false
663+
}
625664
}
626665

627666
// Safety: Semaphore is never actually passed across true threads, only across continuations. The
@@ -748,7 +787,11 @@ impl Future for Acquire<'_> {
748787
// clock, as this thread will be blocked below.
749788
let mut state = self.semaphore.state.borrow_mut();
750789
let id = state.id.unwrap();
751-
let acquire_result = state.acquire_permits(self.waiter.num_permits, self.semaphore.fairness);
790+
let task_id = get_current_task().unwrap_or_else(|| {
791+
error!("Tried to acquire a semaphore while there is no current task. Panicking");
792+
panic!("Tried to acquire a semaphore while there is no current task.");
793+
});
794+
let acquire_result = state.acquire_permits(self.waiter.num_permits, self.semaphore.fairness, task_id);
752795
drop(state);
753796

754797
match acquire_result {

0 commit comments

Comments
 (0)