Skip to content

Commit 628a3f3

Browse files
committed
Add BatchSemaphore::{held_permits, is_queued}
1 parent 5ccf306 commit 628a3f3

File tree

2 files changed

+67
-21
lines changed

2 files changed

+67
-21
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ regex = { version = "1.10.6", optional = true }
2727
serde = { version = "1.0", features = ["derive"], optional = true }
2828
serde_json = { version = "1.0", optional = true }
2929

30+
indexmap = { version = "2.9", features = ["std"] }
31+
ahash = "0.8"
32+
3033
[dev-dependencies]
3134
criterion = { version = "0.4.0", features = ["html_reports"] }
3235
futures = "0.3.15"

src/future/batch_semaphore.rs

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
//! A counting semaphore supporting both async and sync operations.
2-
use crate::current;
2+
use crate::current::{self, get_current_task};
33
use crate::runtime::execution::ExecutionState;
44
use crate::runtime::task::{clock::VectorClock, TaskId};
55
use crate::runtime::thread;
6+
use ahash::random_state::RandomState;
7+
use indexmap::IndexMap;
68
use std::cell::RefCell;
79
use std::collections::VecDeque;
810
use std::fmt;
@@ -12,7 +14,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
1214
use std::sync::Arc;
1315
use std::sync::Mutex;
1416
use std::task::{Context, Poll, Waker};
15-
use tracing::trace;
17+
use tracing::{error, trace};
1618

1719
struct Waiter {
1820
task_id: TaskId,
@@ -209,10 +211,17 @@ struct BatchSemaphoreState {
209211
permits_available: PermitsAvailable,
210212
// TODO: should there be a clock for the close event?
211213
closed: bool,
214+
// tracks which tasks hold how many permits.
215+
holders: IndexMap<TaskId, usize, RandomState>,
212216
}
213217

214218
impl BatchSemaphoreState {
215-
fn acquire_permits(&mut self, num_permits: usize, fairness: Fairness) -> Result<(), TryAcquireError> {
219+
fn acquire_permits(
220+
&mut self,
221+
num_permits: usize,
222+
fairness: Fairness,
223+
task_id: TaskId,
224+
) -> Result<(), TryAcquireError> {
216225
assert!(num_permits > 0);
217226
if self.closed {
218227
Err(TryAcquireError::Closed)
@@ -233,6 +242,11 @@ impl BatchSemaphoreState {
233242
s.update_clock(&clock);
234243
});
235244

245+
self.holders
246+
.entry(task_id)
247+
.and_modify(|permits| *permits += num_permits)
248+
.or_insert(num_permits);
249+
236250
Ok(())
237251
} else {
238252
Err(TryAcquireError::NoPermits)
@@ -329,6 +343,7 @@ impl BatchSemaphore {
329343
waiters: VecDeque::new(),
330344
permits_available: PermitsAvailable::new(num_permits),
331345
closed: false,
346+
holders: IndexMap::with_hasher(RandomState::with_seeds(0, 0, 0, 0)),
332347
});
333348
Self { state, fairness }
334349
}
@@ -340,6 +355,7 @@ impl BatchSemaphore {
340355
waiters: VecDeque::new(),
341356
permits_available: PermitsAvailable::const_new(num_permits),
342357
closed: false,
358+
holders: IndexMap::with_hasher(RandomState::with_seeds(0, 0, 0, 0)),
343359
});
344360
Self { state, fairness }
345361
}
@@ -405,24 +421,30 @@ impl BatchSemaphore {
405421
self.init_object_id();
406422
let mut state = self.state.borrow_mut();
407423
let id = state.id.unwrap();
408-
let res = state.acquire_permits(num_permits, self.fairness).inspect_err(|_err| {
409-
// Conservatively, the requester causally depends on the
410-
// last successful acquire.
411-
// TODO: This is not precise, but `try_acquire` causal dependency
412-
// TODO: is both hard to define, and is most likely not worth the
413-
// TODO: effort. The cases where causality would be tracked
414-
// TODO: "imprecisely" do not correspond to commonly used sync.
415-
// TODO: primitives, such as mutexes, mutexes, or condvars.
416-
// TODO: An example would be a counting semaphore used to guard
417-
// TODO: access to N homogenous resources (as opposed to FIFO,
418-
// TODO: heterogenous resources).
419-
// TODO: More precision could be gained by tracking clocks for all
420-
// TODO: current permit holders, with a data structure similar to
421-
// TODO: `permits_available`.
422-
ExecutionState::with(|s| {
423-
s.update_clock(&state.permits_available.last_acquire);
424-
});
424+
let task_id = get_current_task().unwrap_or_else(|| {
425+
error!("Tried to acquire a semaphore while there is no current task. Panicking");
426+
panic!("Tried to acquire a semaphore while there is no current task.");
425427
});
428+
let res = state
429+
.acquire_permits(num_permits, self.fairness, task_id)
430+
.inspect_err(|_err| {
431+
// Conservatively, the requester causally depends on the
432+
// last successful acquire.
433+
// TODO: This is not precise, but `try_acquire` causal dependency
434+
// TODO: is both hard to define, and is most likely not worth the
435+
// TODO: effort. The cases where causality would be tracked
436+
// TODO: "imprecisely" do not correspond to commonly used sync.
437+
// TODO: primitives, such as mutexes, mutexes, or condvars.
438+
// TODO: An example would be a counting semaphore used to guard
439+
// TODO: access to N homogenous resources (as opposed to FIFO,
440+
// TODO: heterogenous resources).
441+
// TODO: More precision could be gained by tracking clocks for all
442+
// TODO: current permit holders, with a data structure similar to
443+
// TODO: `permits_available`.
444+
ExecutionState::with(|s| {
445+
s.update_clock(&state.permits_available.last_acquire);
446+
});
447+
});
426448
drop(state);
427449

428450
// If we won the race for permits of an unfair semaphore, re-block
@@ -586,6 +608,23 @@ impl BatchSemaphore {
586608
// Releasing a semaphore is a yield point
587609
thread::switch();
588610
}
611+
612+
/// Returns the number of permits currently held by the given `task_id`
613+
pub fn held_permits(&self, task_id: &TaskId) -> usize {
614+
let state = self.state.borrow_mut();
615+
*state.holders.get(task_id).unwrap_or(&0)
616+
}
617+
618+
/// Returns `true` iff the given `task_id` is currently waiting to acquire permits from the `BatchSemaphore`
619+
pub fn is_queued(&self, task_id: &TaskId) -> bool {
620+
let state = self.state.borrow_mut();
621+
for waiter in &state.waiters {
622+
if waiter.task_id == *task_id {
623+
return true;
624+
}
625+
}
626+
false
627+
}
589628
}
590629

591630
// Safety: Semaphore is never actually passed across true threads, only across continuations. The
@@ -677,7 +716,11 @@ impl Future for Acquire<'_> {
677716
// clock, as this thread will be blocked below.
678717
let mut state = self.semaphore.state.borrow_mut();
679718
let id = state.id.unwrap();
680-
let acquire_result = state.acquire_permits(self.waiter.num_permits, self.semaphore.fairness);
719+
let task_id = get_current_task().unwrap_or_else(|| {
720+
error!("Tried to acquire a semaphore while there is no current task. Panicking");
721+
panic!("Tried to acquire a semaphore while there is no current task.");
722+
});
723+
let acquire_result = state.acquire_permits(self.waiter.num_permits, self.semaphore.fairness, task_id);
681724
drop(state);
682725

683726
match acquire_result {

0 commit comments

Comments
 (0)