11//! A counting semaphore supporting both async and sync operations.
2- use crate :: current;
2+ use crate :: current:: { self , get_current_task } ;
33use crate :: runtime:: execution:: ExecutionState ;
44use crate :: runtime:: task:: { clock:: VectorClock , TaskId } ;
55use crate :: runtime:: thread;
6+ use ahash:: random_state:: RandomState ;
7+ use indexmap:: IndexMap ;
68use std:: cell:: RefCell ;
79use std:: collections:: VecDeque ;
810use std:: fmt;
@@ -12,7 +14,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
1214use std:: sync:: Arc ;
1315use std:: sync:: Mutex ;
1416use std:: task:: { Context , Poll , Waker } ;
15- use tracing:: trace;
17+ use tracing:: { error , trace} ;
1618
1719struct Waiter {
1820 task_id : TaskId ,
@@ -209,10 +211,17 @@ struct BatchSemaphoreState {
209211 permits_available : PermitsAvailable ,
210212 // TODO: should there be a clock for the close event?
211213 closed : bool ,
214+ // tracks which tasks hold how many permits.
215+ holders : IndexMap < TaskId , usize , RandomState > ,
212216}
213217
214218impl BatchSemaphoreState {
215- fn acquire_permits ( & mut self , num_permits : usize , fairness : Fairness ) -> Result < ( ) , TryAcquireError > {
219+ fn acquire_permits (
220+ & mut self ,
221+ num_permits : usize ,
222+ fairness : Fairness ,
223+ task_id : TaskId ,
224+ ) -> Result < ( ) , TryAcquireError > {
216225 assert ! ( num_permits > 0 ) ;
217226 if self . closed {
218227 Err ( TryAcquireError :: Closed )
@@ -233,6 +242,11 @@ impl BatchSemaphoreState {
233242 s. update_clock ( & clock) ;
234243 } ) ;
235244
245+ self . holders
246+ . entry ( task_id)
247+ . and_modify ( |permits| * permits += num_permits)
248+ . or_insert ( num_permits) ;
249+
236250 Ok ( ( ) )
237251 } else {
238252 Err ( TryAcquireError :: NoPermits )
@@ -329,6 +343,7 @@ impl BatchSemaphore {
329343 waiters : VecDeque :: new ( ) ,
330344 permits_available : PermitsAvailable :: new ( num_permits) ,
331345 closed : false ,
346+ holders : IndexMap :: with_hasher ( RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ) ,
332347 } ) ;
333348 Self { state, fairness }
334349 }
@@ -340,6 +355,7 @@ impl BatchSemaphore {
340355 waiters : VecDeque :: new ( ) ,
341356 permits_available : PermitsAvailable :: const_new ( num_permits) ,
342357 closed : false ,
358+ holders : IndexMap :: with_hasher ( RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ) ,
343359 } ) ;
344360 Self { state, fairness }
345361 }
@@ -405,24 +421,30 @@ impl BatchSemaphore {
405421 self . init_object_id ( ) ;
406422 let mut state = self . state . borrow_mut ( ) ;
407423 let id = state. id . unwrap ( ) ;
408- let res = state. acquire_permits ( num_permits, self . fairness ) . inspect_err ( |_err| {
409- // Conservatively, the requester causally depends on the
410- // last successful acquire.
411- // TODO: This is not precise, but `try_acquire` causal dependency
412- // TODO: is both hard to define, and is most likely not worth the
413- // TODO: effort. The cases where causality would be tracked
414- // TODO: "imprecisely" do not correspond to commonly used sync.
415- // TODO: primitives, such as mutexes, mutexes, or condvars.
416- // TODO: An example would be a counting semaphore used to guard
417- // TODO: access to N homogenous resources (as opposed to FIFO,
418- // TODO: heterogenous resources).
419- // TODO: More precision could be gained by tracking clocks for all
420- // TODO: current permit holders, with a data structure similar to
421- // TODO: `permits_available`.
422- ExecutionState :: with ( |s| {
423- s. update_clock ( & state. permits_available . last_acquire ) ;
424- } ) ;
424+ let task_id = get_current_task ( ) . unwrap_or_else ( || {
425+ error ! ( "Tried to acquire a semaphore while there is no current task. Panicking" ) ;
426+ panic ! ( "Tried to acquire a semaphore while there is no current task." ) ;
425427 } ) ;
428+ let res = state
429+ . acquire_permits ( num_permits, self . fairness , task_id)
430+ . inspect_err ( |_err| {
431+ // Conservatively, the requester causally depends on the
432+ // last successful acquire.
433+ // TODO: This is not precise, but `try_acquire` causal dependency
434+ // TODO: is both hard to define, and is most likely not worth the
435+ // TODO: effort. The cases where causality would be tracked
436+ // TODO: "imprecisely" do not correspond to commonly used sync.
437+ // TODO: primitives, such as mutexes, mutexes, or condvars.
438+ // TODO: An example would be a counting semaphore used to guard
439+ // TODO: access to N homogenous resources (as opposed to FIFO,
440+ // TODO: heterogenous resources).
441+ // TODO: More precision could be gained by tracking clocks for all
442+ // TODO: current permit holders, with a data structure similar to
443+ // TODO: `permits_available`.
444+ ExecutionState :: with ( |s| {
445+ s. update_clock ( & state. permits_available . last_acquire ) ;
446+ } ) ;
447+ } ) ;
426448 drop ( state) ;
427449
428450 // If we won the race for permits of an unfair semaphore, re-block
@@ -586,6 +608,23 @@ impl BatchSemaphore {
586608 // Releasing a semaphore is a yield point
587609 thread:: switch ( ) ;
588610 }
611+
612+ /// Returns the number of permits currently held by the given `task_id`
613+ pub fn held_permits ( & self , task_id : & TaskId ) -> usize {
614+ let state = self . state . borrow_mut ( ) ;
615+ * state. holders . get ( task_id) . unwrap_or ( & 0 )
616+ }
617+
618+ /// Returns `true` iff the given `task_id` is currently waiting to acquire permits from the `BatchSemaphore`
619+ pub fn is_queued ( & self , task_id : & TaskId ) -> bool {
620+ let state = self . state . borrow_mut ( ) ;
621+ for waiter in & state. waiters {
622+ if waiter. task_id == * task_id {
623+ return true ;
624+ }
625+ }
626+ false
627+ }
589628}
590629
591630// Safety: Semaphore is never actually passed across true threads, only across continuations. The
@@ -677,7 +716,11 @@ impl Future for Acquire<'_> {
677716 // clock, as this thread will be blocked below.
678717 let mut state = self . semaphore . state . borrow_mut ( ) ;
679718 let id = state. id . unwrap ( ) ;
680- let acquire_result = state. acquire_permits ( self . waiter . num_permits , self . semaphore . fairness ) ;
719+ let task_id = get_current_task ( ) . unwrap_or_else ( || {
720+ error ! ( "Tried to acquire a semaphore while there is no current task. Panicking" ) ;
721+ panic ! ( "Tried to acquire a semaphore while there is no current task." ) ;
722+ } ) ;
723+ let acquire_result = state. acquire_permits ( self . waiter . num_permits , self . semaphore . fairness , task_id) ;
681724 drop ( state) ;
682725
683726 match acquire_result {
0 commit comments