@@ -9,7 +9,7 @@ use std::marker::PhantomData;
99use std:: sync:: Arc ;
1010
1111thread_local ! {
12- static CURRENT_CONTEXT : RefCell <Context > = RefCell :: new( Context :: default ( ) ) ;
12+ static CURRENT_CONTEXT : RefCell <ContextStack > = RefCell :: new( ContextStack :: default ( ) ) ;
1313}
1414
1515/// An execution-scoped collection of values.
@@ -122,7 +122,7 @@ impl Context {
122122 /// Note: This function will panic if you attempt to attach another context
123123 /// while the current one is still borrowed.
124124 pub fn map_current < T > ( f : impl FnOnce ( & Context ) -> T ) -> T {
125- CURRENT_CONTEXT . with ( |cx| f ( & cx. borrow ( ) ) )
125+ CURRENT_CONTEXT . with ( |cx| cx. borrow ( ) . map_current_cx ( f ) )
126126 }
127127
128128 /// Returns a clone of the current thread's context with the given value.
@@ -298,12 +298,10 @@ impl Context {
298298 /// assert_eq!(Context::current().get::<ValueA>(), None);
299299 /// ```
300300 pub fn attach ( self ) -> ContextGuard {
301- let previous_cx = CURRENT_CONTEXT
302- . try_with ( |current| current. replace ( self ) )
303- . ok ( ) ;
301+ let cx_id = CURRENT_CONTEXT . with ( |cx| cx. borrow_mut ( ) . push ( self ) ) ;
304302
305303 ContextGuard {
306- previous_cx ,
304+ cx_id ,
307305 _marker : PhantomData ,
308306 }
309307 }
@@ -346,15 +344,16 @@ impl fmt::Debug for Context {
346344/// A guard that resets the current context to the prior context when dropped.
347345#[ allow( missing_debug_implementations) ]
348346pub struct ContextGuard {
349- previous_cx : Option < Context > ,
347+ cx_id : usize ,
350348 // ensure this type is !Send as it relies on thread locals
351349 _marker : PhantomData < * const ( ) > ,
352350}
353351
354352impl Drop for ContextGuard {
355353 fn drop ( & mut self ) {
356- if let Some ( previous_cx) = self . previous_cx . take ( ) {
357- let _ = CURRENT_CONTEXT . try_with ( |current| current. replace ( previous_cx) ) ;
354+ let id = self . cx_id ;
355+ if id > 0 {
356+ CURRENT_CONTEXT . with ( |context_stack| context_stack. borrow_mut ( ) . pop_id ( id) ) ;
358357 }
359358 }
360359}
@@ -381,6 +380,75 @@ impl Hasher for IdHasher {
381380 }
382381}
383382
383+ struct ContextStack {
384+ current_cx : Context ,
385+ current_id : usize ,
386+ // TODO:ban wrap the whole id thing in its own type
387+ id_count : usize ,
388+ // TODO:ban wrap the the tuple in its own type
389+ stack : Vec < Option < ( usize , Context ) > > ,
390+ }
391+
392+ impl ContextStack {
393+ #[ inline( always) ]
394+ fn push ( & mut self , cx : Context ) -> usize {
395+ self . id_count += 512 ; // TODO:ban clean up this
396+ let next_id = self . stack . len ( ) + 1 + self . id_count ;
397+ let current_cx = std:: mem:: replace ( & mut self . current_cx , cx) ;
398+ self . stack . push ( Some ( ( self . current_id , current_cx) ) ) ;
399+ self . current_id = next_id;
400+ next_id
401+ }
402+
403+ #[ inline( always) ]
404+ fn pop_id ( & mut self , id : usize ) {
405+ if id == 0 {
406+ return ;
407+ }
408+ // Are we at the top of the stack?
409+ if id == self . current_id {
410+ // Shrink the stack if possible
411+ while let Some ( None ) = self . stack . last ( ) {
412+ self . stack . pop ( ) ;
413+ }
414+ // There is always the initial context at the bottom of the stack
415+ if let Some ( Some ( ( next_id, next_cx) ) ) = self . stack . pop ( ) {
416+ self . current_cx = next_cx;
417+ self . current_id = next_id;
418+ }
419+ } else {
420+ let pos = id & 511 ; // TODO:ban clean up this
421+ if pos >= self . stack . len ( ) {
422+ // This is an invalid id, ignore it
423+ return ;
424+ }
425+ if let Some ( ( pos_id, _) ) = self . stack [ pos] {
426+ // Is the correct id at this position?
427+ if pos_id == id {
428+ // Clear out this entry
429+ self . stack [ pos] = None ;
430+ }
431+ }
432+ }
433+ }
434+
435+ #[ inline( always) ]
436+ fn map_current_cx < T > ( & self , f : impl FnOnce ( & Context ) -> T ) -> T {
437+ f ( & self . current_cx )
438+ }
439+ }
440+
441+ impl Default for ContextStack {
442+ fn default ( ) -> Self {
443+ ContextStack {
444+ current_id : 0 ,
445+ current_cx : Context :: default ( ) ,
446+ id_count : 0 ,
447+ stack : Vec :: with_capacity ( 64 ) ,
448+ }
449+ }
450+ }
451+
384452#[ cfg( test) ]
385453mod tests {
386454 use super :: * ;
@@ -425,7 +493,6 @@ mod tests {
425493 }
426494
427495 #[ test]
428- #[ ignore = "overlapping contexts are not supported yet" ]
429496 fn overlapping_contexts ( ) {
430497 #[ derive( Debug , PartialEq ) ]
431498 struct ValueA ( & ' static str ) ;
0 commit comments