@@ -9,7 +9,7 @@ use std::marker::PhantomData;
99use std:: sync:: Arc ;
1010
1111thread_local ! {
12- static CURRENT_CONTEXT : RefCell <Context > = RefCell :: new( Context :: default ( ) ) ;
12+ static CURRENT_CONTEXT : RefCell <ContextStack > = RefCell :: new( ContextStack :: default ( ) ) ;
1313}
1414
1515/// An execution-scoped collection of values.
@@ -122,7 +122,7 @@ impl Context {
122122 /// Note: This function will panic if you attempt to attach another context
123123 /// while the current one is still borrowed.
124124 pub fn map_current < T > ( f : impl FnOnce ( & Context ) -> T ) -> T {
125- CURRENT_CONTEXT . with ( |cx| f ( & cx. borrow ( ) ) )
125+ CURRENT_CONTEXT . with ( |cx| cx. borrow ( ) . map_current_cx ( f ) )
126126 }
127127
128128 /// Returns a clone of the current thread's context with the given value.
@@ -298,12 +298,10 @@ impl Context {
298298 /// assert_eq!(Context::current().get::<ValueA>(), None);
299299 /// ```
300300 pub fn attach ( self ) -> ContextGuard {
301- let previous_cx = CURRENT_CONTEXT
302- . try_with ( |current| current. replace ( self ) )
303- . ok ( ) ;
301+ let cx_id = CURRENT_CONTEXT . with ( |cx| cx. borrow_mut ( ) . push ( self ) ) ;
304302
305303 ContextGuard {
306- previous_cx ,
304+ cx_pos : cx_id ,
307305 _marker : PhantomData ,
308306 }
309307 }
@@ -344,17 +342,19 @@ impl fmt::Debug for Context {
344342}
345343
346344/// A guard that resets the current context to the prior context when dropped.
347- #[ allow ( missing_debug_implementations ) ]
345+ #[ derive ( Debug ) ]
348346pub struct ContextGuard {
349- previous_cx : Option < Context > ,
350- // ensure this type is !Send as it relies on thread locals
347+ // The position of the context in the stack. This is used to pop the context.
348+ cx_pos : usize ,
349+ // Ensure this type is !Send as it relies on thread locals
351350 _marker : PhantomData < * const ( ) > ,
352351}
353352
354353impl Drop for ContextGuard {
355354 fn drop ( & mut self ) {
356- if let Some ( previous_cx) = self . previous_cx . take ( ) {
357- let _ = CURRENT_CONTEXT . try_with ( |current| current. replace ( previous_cx) ) ;
355+ let id = self . cx_pos ;
356+ if id > 0 {
357+ CURRENT_CONTEXT . with ( |context_stack| context_stack. borrow_mut ( ) . pop_id ( id) ) ;
358358 }
359359 }
360360}
@@ -381,6 +381,87 @@ impl Hasher for IdHasher {
381381 }
382382}
383383
384+ /// A stack for keeping track of the [`Context`] instances that have been attached
385+ /// to a thread.
386+ ///
387+ /// The stack allows for popping of contexts by position, which is used to do out
388+ /// of order dropping of [`ContextGuard`] instances. Only when the top of the
389+ /// stack is popped, the topmost [`Context`] is actually restored.
390+ ///
391+ /// The stack relies on the fact that it is thread local and that the
392+ /// [`ContextGuard`] instances that are constructed using it can't be shared with
393+ /// other threads.
394+ struct ContextStack {
395+ /// This is the current [`Context`] that is active on this thread, and the top
396+ /// of the [`ContextStack`]. It is always present, and if the `stack` is empty
397+ /// it's an empty [`Context`].
398+ ///
399+ /// Having this here allows for fast access to the current [`Context`].
400+ current_cx : Context ,
401+ /// A `stack` of the other contexts that have been attached to the thread.
402+ stack : Vec < Option < Context > > ,
403+ /// Ensure this type is !Send as it relies on thread locals
404+ _marker : PhantomData < * const ( ) > ,
405+ }
406+
407+ impl ContextStack {
408+ #[ inline( always) ]
409+ fn push ( & mut self , cx : Context ) -> usize {
410+ // The next id is the length of the `stack`, plus one since we have the
411+ // top of the [`ContextStack`] as the `current_cx`.
412+ let next_id = self . stack . len ( ) + 1 ;
413+ let current_cx = std:: mem:: replace ( & mut self . current_cx , cx) ;
414+ self . stack . push ( Some ( current_cx) ) ;
415+ next_id
416+ }
417+
418+ #[ inline( always) ]
419+ fn pop_id ( & mut self , pos : usize ) {
420+ if pos == 0 {
421+ // The empty context is always at the bottom of the [`ContextStack`]
422+ // and cannot be popped, so do nothing.
423+ return ;
424+ }
425+ let len = self . stack . len ( ) ;
426+ // Are we at the top of the [`ContextStack`]?
427+ if pos == len {
428+ // Shrink the stack if possible to clear out any out of order pops.
429+ while let Some ( None ) = self . stack . last ( ) {
430+ _ = self . stack . pop ( ) ;
431+ }
432+ // Restore the previous context. This will always happen since the
433+ // empty context is always at the bottom of the stack if the
434+ // [`ContextStack`] is not empty.
435+ if let Some ( Some ( next_cx) ) = self . stack . pop ( ) {
436+ self . current_cx = next_cx;
437+ }
438+ } else {
439+ // This is an out of order pop.
440+ if pos >= len {
441+ // This is an invalid id, ignore it.
442+ return ;
443+ }
444+ // Clear out the entry at the given id.
445+ _ = self . stack [ pos] . take ( ) ;
446+ }
447+ }
448+
449+ #[ inline( always) ]
450+ fn map_current_cx < T > ( & self , f : impl FnOnce ( & Context ) -> T ) -> T {
451+ f ( & self . current_cx )
452+ }
453+ }
454+
455+ impl Default for ContextStack {
456+ fn default ( ) -> Self {
457+ ContextStack {
458+ current_cx : Context :: default ( ) ,
459+ stack : Vec :: with_capacity ( 64 ) ,
460+ _marker : PhantomData ,
461+ }
462+ }
463+ }
464+
384465#[ cfg( test) ]
385466mod tests {
386467 use super :: * ;
@@ -425,7 +506,6 @@ mod tests {
425506 }
426507
427508 #[ test]
428- #[ ignore = "overlapping contexts are not supported yet" ]
429509 fn overlapping_contexts ( ) {
430510 #[ derive( Debug , PartialEq ) ]
431511 struct ValueA ( & ' static str ) ;
0 commit comments