11//! A counting semaphore supporting both async and sync operations.
2- use crate :: current;
2+ use crate :: current:: { self , get_current_task } ;
33use crate :: runtime:: execution:: ExecutionState ;
44use crate :: runtime:: task:: { clock:: VectorClock , TaskId } ;
55use crate :: runtime:: thread;
66use crate :: sync:: { ResourceSignature , ResourceType } ;
7+ use ahash:: random_state:: RandomState ;
8+ use indexmap:: IndexMap ;
79use std:: cell:: RefCell ;
810use std:: collections:: VecDeque ;
911use std:: fmt;
@@ -13,7 +15,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
1315use std:: sync:: Arc ;
1416use std:: sync:: Mutex ;
1517use std:: task:: { Context , Poll , Waker } ;
16- use tracing:: trace;
18+ use tracing:: { error , trace} ;
1719
1820struct Waiter {
1921 task_id : TaskId ,
@@ -210,10 +212,17 @@ struct BatchSemaphoreState {
210212 permits_available : PermitsAvailable ,
211213 // TODO: should there be a clock for the close event?
212214 closed : bool ,
215+ // tracks which tasks hold how many permits.
216+ holders : IndexMap < TaskId , usize , RandomState > ,
213217}
214218
215219impl BatchSemaphoreState {
216- fn acquire_permits ( & mut self , num_permits : usize , fairness : Fairness ) -> Result < ( ) , TryAcquireError > {
220+ fn acquire_permits (
221+ & mut self ,
222+ num_permits : usize ,
223+ fairness : Fairness ,
224+ task_id : TaskId ,
225+ ) -> Result < ( ) , TryAcquireError > {
217226 assert ! ( num_permits > 0 ) ;
218227 if self . closed {
219228 Err ( TryAcquireError :: Closed )
@@ -234,6 +243,11 @@ impl BatchSemaphoreState {
234243 s. update_clock ( & clock) ;
235244 } ) ;
236245
246+ self . holders
247+ . entry ( task_id)
248+ . and_modify ( |permits| * permits += num_permits)
249+ . or_insert ( num_permits) ;
250+
237251 Ok ( ( ) )
238252 } else {
239253 Err ( TryAcquireError :: NoPermits )
@@ -341,6 +355,7 @@ impl BatchSemaphore {
341355 waiters : VecDeque :: new ( ) ,
342356 permits_available : PermitsAvailable :: new ( num_permits) ,
343357 closed : false ,
358+ holders : IndexMap :: with_hasher ( RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ) ,
344359 } ) ;
345360 Self {
346361 state,
@@ -369,6 +384,7 @@ impl BatchSemaphore {
369384 waiters : VecDeque :: new ( ) ,
370385 permits_available : PermitsAvailable :: const_new ( num_permits) ,
371386 closed : false ,
387+ holders : IndexMap :: with_hasher ( RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ) ,
372388 } ) ;
373389 Self {
374390 state,
@@ -446,24 +462,30 @@ impl BatchSemaphore {
446462 self . init_object_id ( ) ;
447463 let mut state = self . state . borrow_mut ( ) ;
448464 let id = state. id . unwrap ( ) ;
449- let res = state. acquire_permits ( num_permits, self . fairness ) . inspect_err ( |_err| {
450- // Conservatively, the requester causally depends on the
451- // last successful acquire.
452- // TODO: This is not precise, but `try_acquire` causal dependency
453- // TODO: is both hard to define, and is most likely not worth the
454- // TODO: effort. The cases where causality would be tracked
455- // TODO: "imprecisely" do not correspond to commonly used sync.
456- // TODO: primitives, such as mutexes, mutexes, or condvars.
457- // TODO: An example would be a counting semaphore used to guard
458- // TODO: access to N homogenous resources (as opposed to FIFO,
459- // TODO: heterogenous resources).
460- // TODO: More precision could be gained by tracking clocks for all
461- // TODO: current permit holders, with a data structure similar to
462- // TODO: `permits_available`.
463- ExecutionState :: with ( |s| {
464- s. update_clock ( & state. permits_available . last_acquire ) ;
465- } ) ;
465+ let task_id = get_current_task ( ) . unwrap_or_else ( || {
466+ error ! ( "Tried to acquire a semaphore while there is no current task. Panicking" ) ;
467+ panic ! ( "Tried to acquire a semaphore while there is no current task." ) ;
466468 } ) ;
469+ let res = state
470+ . acquire_permits ( num_permits, self . fairness , task_id)
471+ . inspect_err ( |_err| {
472+ // Conservatively, the requester causally depends on the
473+ // last successful acquire.
474+ // TODO: This is not precise, but `try_acquire` causal dependency
475+ // TODO: is both hard to define, and is most likely not worth the
476+ // TODO: effort. The cases where causality would be tracked
477+ // TODO: "imprecisely" do not correspond to commonly used sync.
478+ // TODO: primitives, such as mutexes, mutexes, or condvars.
479+ // TODO: An example would be a counting semaphore used to guard
480+ // TODO: access to N homogenous resources (as opposed to FIFO,
481+ // TODO: heterogenous resources).
482+ // TODO: More precision could be gained by tracking clocks for all
483+ // TODO: current permit holders, with a data structure similar to
484+ // TODO: `permits_available`.
485+ ExecutionState :: with ( |s| {
486+ s. update_clock ( & state. permits_available . last_acquire ) ;
487+ } ) ;
488+ } ) ;
467489 drop ( state) ;
468490
469491 // If we won the race for permits of an unfair semaphore, re-block
@@ -622,6 +644,23 @@ impl BatchSemaphore {
622644 }
623645 drop ( state) ;
624646 }
647+
648+ /// Returns the number of permits currently held by the given `task_id`
649+ pub fn held_permits ( & self , task_id : & TaskId ) -> usize {
650+ let state = self . state . borrow_mut ( ) ;
651+ * state. holders . get ( task_id) . unwrap_or ( & 0 )
652+ }
653+
654+ /// Returns `true` iff the given `task_id` is currently waiting to acquire permits from the `BatchSemaphore`
655+ pub fn is_queued ( & self , task_id : & TaskId ) -> bool {
656+ let state = self . state . borrow_mut ( ) ;
657+ for waiter in & state. waiters {
658+ if waiter. task_id == * task_id {
659+ return true ;
660+ }
661+ }
662+ false
663+ }
625664}
626665
627666// Safety: Semaphore is never actually passed across true threads, only across continuations. The
@@ -748,7 +787,11 @@ impl Future for Acquire<'_> {
748787 // clock, as this thread will be blocked below.
749788 let mut state = self . semaphore . state . borrow_mut ( ) ;
750789 let id = state. id . unwrap ( ) ;
751- let acquire_result = state. acquire_permits ( self . waiter . num_permits , self . semaphore . fairness ) ;
790+ let task_id = get_current_task ( ) . unwrap_or_else ( || {
791+ error ! ( "Tried to acquire a semaphore while there is no current task. Panicking" ) ;
792+ panic ! ( "Tried to acquire a semaphore while there is no current task." ) ;
793+ } ) ;
794+ let acquire_result = state. acquire_permits ( self . waiter . num_permits , self . semaphore . fairness , task_id) ;
752795 drop ( state) ;
753796
754797 match acquire_result {
0 commit comments