1+ use crate :: otel_warn;
12#[ cfg( feature = "trace" ) ]
23use crate :: trace:: context:: SynchronizedSpan ;
34use std:: any:: { Any , TypeId } ;
@@ -9,7 +10,7 @@ use std::marker::PhantomData;
910use std:: sync:: Arc ;
1011
1112thread_local ! {
12- static CURRENT_CONTEXT : RefCell <Context > = RefCell :: new( Context :: default ( ) ) ;
13+ static CURRENT_CONTEXT : RefCell <ContextStack > = RefCell :: new( ContextStack :: default ( ) ) ;
1314}
1415
1516/// An execution-scoped collection of values.
@@ -122,7 +123,7 @@ impl Context {
122123 /// Note: This function will panic if you attempt to attach another context
123124 /// while the current one is still borrowed.
124125 pub fn map_current < T > ( f : impl FnOnce ( & Context ) -> T ) -> T {
125- CURRENT_CONTEXT . with ( |cx| f ( & cx. borrow ( ) ) )
126+ CURRENT_CONTEXT . with ( |cx| cx. borrow ( ) . map_current_cx ( f ) )
126127 }
127128
128129 /// Returns a clone of the current thread's context with the given value.
@@ -298,12 +299,10 @@ impl Context {
298299 /// assert_eq!(Context::current().get::<ValueA>(), None);
299300 /// ```
300301 pub fn attach ( self ) -> ContextGuard {
301- let previous_cx = CURRENT_CONTEXT
302- . try_with ( |current| current. replace ( self ) )
303- . ok ( ) ;
302+ let cx_id = CURRENT_CONTEXT . with ( |cx| cx. borrow_mut ( ) . push ( self ) ) ;
304303
305304 ContextGuard {
306- previous_cx ,
305+ cx_pos : cx_id ,
307306 _marker : PhantomData ,
308307 }
309308 }
@@ -344,17 +343,19 @@ impl fmt::Debug for Context {
344343}
345344
346345/// A guard that resets the current context to the prior context when dropped.
347- #[ allow ( missing_debug_implementations ) ]
346+ #[ derive ( Debug ) ]
348347pub struct ContextGuard {
349- previous_cx : Option < Context > ,
350- // ensure this type is !Send as it relies on thread locals
348+ // The position of the context in the stack. This is used to pop the context.
349+ cx_pos : u16 ,
350+ // Ensure this type is !Send as it relies on thread locals
351351 _marker : PhantomData < * const ( ) > ,
352352}
353353
354354impl Drop for ContextGuard {
355355 fn drop ( & mut self ) {
356- if let Some ( previous_cx) = self . previous_cx . take ( ) {
357- let _ = CURRENT_CONTEXT . try_with ( |current| current. replace ( previous_cx) ) ;
356+ let id = self . cx_pos ;
357+ if id > ContextStack :: BASE_POS && id < ContextStack :: MAX_POS {
358+ CURRENT_CONTEXT . with ( |context_stack| context_stack. borrow_mut ( ) . pop_id ( id) ) ;
358359 }
359360 }
360361}
@@ -381,10 +382,107 @@ impl Hasher for IdHasher {
381382 }
382383}
383384
385+ /// A stack for keeping track of the [`Context`] instances that have been attached
386+ /// to a thread.
387+ ///
388+ /// The stack allows for popping of contexts by position, which is used to do out
389+ /// of order dropping of [`ContextGuard`] instances. Only when the top of the
390+ /// stack is popped, the topmost [`Context`] is actually restored.
391+ ///
392+ /// The stack relies on the fact that it is thread local and that the
393+ /// [`ContextGuard`] instances that are constructed using it can't be shared with
394+ /// other threads.
395+ struct ContextStack {
396+ /// This is the current [`Context`] that is active on this thread, and the top
397+ /// of the [`ContextStack`]. It is always present, and if the `stack` is empty
398+ /// it's an empty [`Context`].
399+ ///
400+ /// Having this here allows for fast access to the current [`Context`].
401+ current_cx : Context ,
402+ /// A `stack` of the other contexts that have been attached to the thread.
403+ stack : Vec < Option < Context > > ,
404+ /// Ensure this type is !Send as it relies on thread locals
405+ _marker : PhantomData < * const ( ) > ,
406+ }
407+
408+ impl ContextStack {
409+ const BASE_POS : u16 = 0 ;
410+ const MAX_POS : u16 = u16:: MAX ;
411+ const INITIAL_CAPACITY : usize = 8 ;
412+
413+ #[ inline( always) ]
414+ fn push ( & mut self , cx : Context ) -> u16 {
415+ // The next id is the length of the `stack`, plus one since we have the
416+ // top of the [`ContextStack`] as the `current_cx`.
417+ let next_id = self . stack . len ( ) + 1 ;
418+ if next_id < ContextStack :: MAX_POS . into ( ) {
419+ let current_cx = std:: mem:: replace ( & mut self . current_cx , cx) ;
420+ self . stack . push ( Some ( current_cx) ) ;
421+ next_id as u16
422+ } else {
423+ // This is an overflow, log it and ignore it.
424+ otel_warn ! ( name: "ContextStack.push" , message = "Context stack overflow, context not pushed." ) ;
425+ ContextStack :: MAX_POS
426+ }
427+ }
428+
429+ #[ inline( always) ]
430+ fn pop_id ( & mut self , pos : u16 ) {
431+ if pos == ContextStack :: BASE_POS || pos == ContextStack :: MAX_POS {
432+ // The empty context is always at the bottom of the [`ContextStack`]
433+ // and cannot be popped, and the overflow position is invalid, so do
434+ // nothing.
435+ return ;
436+ }
437+ let len: u16 = self . stack . len ( ) as u16 ;
438+ // Are we at the top of the [`ContextStack`]?
439+ if pos == len {
440+ // Shrink the stack if possible to clear out any out of order pops.
441+ while let Some ( None ) = self . stack . last ( ) {
442+ _ = self . stack . pop ( ) ;
443+ }
444+ // Restore the previous context. This will always happen since the
445+ // empty context is always at the bottom of the stack if the
446+ // [`ContextStack`] is not empty.
447+ if let Some ( Some ( next_cx) ) = self . stack . pop ( ) {
448+ self . current_cx = next_cx;
449+ }
450+ } else {
451+ // This is an out of order pop.
452+ if pos >= len {
453+ // This is an invalid id, ignore it.
454+ return ;
455+ }
456+ // Clear out the entry at the given id.
457+ _ = self . stack [ pos as usize ] . take ( ) ;
458+ }
459+ }
460+
461+ #[ inline( always) ]
462+ fn map_current_cx < T > ( & self , f : impl FnOnce ( & Context ) -> T ) -> T {
463+ f ( & self . current_cx )
464+ }
465+ }
466+
467+ impl Default for ContextStack {
468+ fn default ( ) -> Self {
469+ ContextStack {
470+ current_cx : Context :: default ( ) ,
471+ stack : Vec :: with_capacity ( ContextStack :: INITIAL_CAPACITY ) ,
472+ _marker : PhantomData ,
473+ }
474+ }
475+ }
476+
384477#[ cfg( test) ]
385478mod tests {
386479 use super :: * ;
387480
481+ #[ derive( Debug , PartialEq ) ]
482+ struct ValueA ( & ' static str ) ;
483+ #[ derive( Debug , PartialEq ) ]
484+ struct ValueB ( u64 ) ;
485+
388486 #[ test]
389487 fn context_immutable ( ) {
390488 #[ derive( Debug , PartialEq ) ]
@@ -424,10 +522,6 @@ mod tests {
424522
425523 #[ test]
426524 fn nested_contexts ( ) {
427- #[ derive( Debug , PartialEq ) ]
428- struct ValueA ( & ' static str ) ;
429- #[ derive( Debug , PartialEq ) ]
430- struct ValueB ( u64 ) ;
431525 let _outer_guard = Context :: new ( ) . with_value ( ValueA ( "a" ) ) . attach ( ) ;
432526
433527 // Only value `a` is set
@@ -462,13 +556,7 @@ mod tests {
462556 }
463557
464558 #[ test]
465- #[ ignore = "overlapping contexts are not supported yet" ]
466559 fn overlapping_contexts ( ) {
467- #[ derive( Debug , PartialEq ) ]
468- struct ValueA ( & ' static str ) ;
469- #[ derive( Debug , PartialEq ) ]
470- struct ValueB ( u64 ) ;
471-
472560 let outer_guard = Context :: new ( ) . with_value ( ValueA ( "a" ) ) . attach ( ) ;
473561
474562 // Only value `a` is set
@@ -502,4 +590,60 @@ mod tests {
502590 assert_eq ! ( current. get:: <ValueA >( ) , None ) ;
503591 assert_eq ! ( current. get:: <ValueB >( ) , None ) ;
504592 }
593+
594+ #[ test]
595+ fn too_many_contexts ( ) {
596+ let mut guards: Vec < ContextGuard > = Vec :: with_capacity ( ContextStack :: MAX_POS as usize ) ;
597+ let stack_max_pos = ContextStack :: MAX_POS as u64 ;
598+ // Fill the stack up until the last position
599+ for i in 1 ..stack_max_pos {
600+ let cx_guard = Context :: current ( ) . with_value ( ValueB ( i) ) . attach ( ) ;
601+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( i) ) ) ;
602+ assert_eq ! ( cx_guard. cx_pos, i as u16 ) ;
603+ guards. push ( cx_guard) ;
604+ }
605+ // Let's overflow the stack a couple of times
606+ for _ in 0 ..16 {
607+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "overflow" ) ) . attach ( ) ;
608+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS ) ;
609+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
610+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 1 ) ) ) ;
611+ guards. push ( cx_guard) ;
612+ }
613+ // Drop the overflow contexts
614+ for _ in 0 ..16 {
615+ guards. pop ( ) ;
616+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
617+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 1 ) ) ) ;
618+ }
619+ // Drop one more so we can add a new one
620+ guards. pop ( ) ;
621+ assert_eq ! ( Context :: current( ) . get:: <ValueA >( ) , None ) ;
622+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
623+ // Push a new context and see that it works
624+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "last" ) ) . attach ( ) ;
625+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS - 1 ) ;
626+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueA ( "last" ) ) ) ;
627+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
628+ guards. push ( cx_guard) ;
629+ // Let's overflow the stack a couple of times again
630+ for _ in 0 ..16 {
631+ let cx_guard = Context :: current ( ) . with_value ( ValueA ( "overflow" ) ) . attach ( ) ;
632+ assert_eq ! ( cx_guard. cx_pos, ContextStack :: MAX_POS ) ;
633+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueA ( "last" ) ) ) ;
634+ assert_eq ! ( Context :: current( ) . get( ) , Some ( & ValueB ( stack_max_pos - 2 ) ) ) ;
635+ guards. push ( cx_guard) ;
636+ }
637+ }
638+
639+ #[ test]
640+ fn context_stack_pop_id ( ) {
641+ // This is to get full line coverage of the `pop_id` function.
642+ // In real life the `Drop`` implementation of `ContextGuard` ensures that
643+ // the ids are valid and inside the bounds.
644+ let mut stack = ContextStack :: default ( ) ;
645+ stack. pop_id ( ContextStack :: BASE_POS ) ;
646+ stack. pop_id ( ContextStack :: MAX_POS ) ;
647+ stack. pop_id ( 4711 ) ;
648+ }
505649}
0 commit comments