Skip to content

Commit f0d6424

Browse files
committed
fix: Change thread local context to allow overlapped scopes
1 parent 5d4e15f commit f0d6424

File tree

1 file changed

+165
-21
lines changed

1 file changed

+165
-21
lines changed

opentelemetry/src/context.rs

Lines changed: 165 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::otel_warn;
12
#[cfg(feature = "trace")]
23
use crate::trace::context::SynchronizedSpan;
34
use std::any::{Any, TypeId};
@@ -9,7 +10,7 @@ use std::marker::PhantomData;
910
use std::sync::Arc;
1011

1112
thread_local! {
12-
static CURRENT_CONTEXT: RefCell<Context> = RefCell::new(Context::default());
13+
static CURRENT_CONTEXT: RefCell<ContextStack> = RefCell::new(ContextStack::default());
1314
}
1415

1516
/// An execution-scoped collection of values.
@@ -122,7 +123,7 @@ impl Context {
122123
/// Note: This function will panic if you attempt to attach another context
123124
/// while the current one is still borrowed.
124125
pub fn map_current<T>(f: impl FnOnce(&Context) -> T) -> T {
125-
CURRENT_CONTEXT.with(|cx| f(&cx.borrow()))
126+
CURRENT_CONTEXT.with(|cx| cx.borrow().map_current_cx(f))
126127
}
127128

128129
/// Returns a clone of the current thread's context with the given value.
@@ -298,12 +299,10 @@ impl Context {
298299
/// assert_eq!(Context::current().get::<ValueA>(), None);
299300
/// ```
300301
pub fn attach(self) -> ContextGuard {
301-
let previous_cx = CURRENT_CONTEXT
302-
.try_with(|current| current.replace(self))
303-
.ok();
302+
let cx_id = CURRENT_CONTEXT.with(|cx| cx.borrow_mut().push(self));
304303

305304
ContextGuard {
306-
previous_cx,
305+
cx_pos: cx_id,
307306
_marker: PhantomData,
308307
}
309308
}
@@ -344,17 +343,19 @@ impl fmt::Debug for Context {
344343
}
345344

346345
/// A guard that resets the current context to the prior context when dropped.
347-
#[allow(missing_debug_implementations)]
346+
#[derive(Debug)]
348347
pub struct ContextGuard {
349-
previous_cx: Option<Context>,
350-
// ensure this type is !Send as it relies on thread locals
348+
// The position of the context in the stack. This is used to pop the context.
349+
cx_pos: u16,
350+
// Ensure this type is !Send as it relies on thread locals
351351
_marker: PhantomData<*const ()>,
352352
}
353353

354354
impl Drop for ContextGuard {
355355
fn drop(&mut self) {
356-
if let Some(previous_cx) = self.previous_cx.take() {
357-
let _ = CURRENT_CONTEXT.try_with(|current| current.replace(previous_cx));
356+
let id = self.cx_pos;
357+
if id > ContextStack::BASE_POS && id < ContextStack::MAX_POS {
358+
CURRENT_CONTEXT.with(|context_stack| context_stack.borrow_mut().pop_id(id));
358359
}
359360
}
360361
}
@@ -381,10 +382,107 @@ impl Hasher for IdHasher {
381382
}
382383
}
383384

385+
/// A stack for keeping track of the [`Context`] instances that have been attached
386+
/// to a thread.
387+
///
388+
/// The stack allows for popping of contexts by position, which is used to do out
389+
/// of order dropping of [`ContextGuard`] instances. Only when the top of the
390+
/// stack is popped, the topmost [`Context`] is actually restored.
391+
///
392+
/// The stack relies on the fact that it is thread local and that the
393+
/// [`ContextGuard`] instances that are constructed using it can't be shared with
394+
/// other threads.
395+
struct ContextStack {
396+
/// This is the current [`Context`] that is active on this thread, and the top
397+
/// of the [`ContextStack`]. It is always present, and if the `stack` is empty
398+
/// it's an empty [`Context`].
399+
///
400+
/// Having this here allows for fast access to the current [`Context`].
401+
current_cx: Context,
402+
/// A `stack` of the other contexts that have been attached to the thread.
403+
stack: Vec<Option<Context>>,
404+
/// Ensure this type is !Send as it relies on thread locals
405+
_marker: PhantomData<*const ()>,
406+
}
407+
408+
impl ContextStack {
409+
const BASE_POS: u16 = 0;
410+
const MAX_POS: u16 = u16::MAX;
411+
const INITIAL_CAPACITY: usize = 8;
412+
413+
#[inline(always)]
414+
fn push(&mut self, cx: Context) -> u16 {
415+
// The next id is the length of the `stack`, plus one since we have the
416+
// top of the [`ContextStack`] as the `current_cx`.
417+
let next_id = self.stack.len() + 1;
418+
if next_id < ContextStack::MAX_POS.into() {
419+
let current_cx = std::mem::replace(&mut self.current_cx, cx);
420+
self.stack.push(Some(current_cx));
421+
next_id as u16
422+
} else {
423+
// This is an overflow, log it and ignore it.
424+
otel_warn!(name: "ContextStack.push", message = "Context stack overflow, context not pushed.");
425+
ContextStack::MAX_POS
426+
}
427+
}
428+
429+
#[inline(always)]
430+
fn pop_id(&mut self, pos: u16) {
431+
if pos == ContextStack::BASE_POS || pos == ContextStack::MAX_POS {
432+
// The empty context is always at the bottom of the [`ContextStack`]
433+
// and cannot be popped, and the overflow position is invalid, so do
434+
// nothing.
435+
return;
436+
}
437+
let len: u16 = self.stack.len() as u16;
438+
// Are we at the top of the [`ContextStack`]?
439+
if pos == len {
440+
// Shrink the stack if possible to clear out any out of order pops.
441+
while let Some(None) = self.stack.last() {
442+
_ = self.stack.pop();
443+
}
444+
// Restore the previous context. This will always happen since the
445+
// empty context is always at the bottom of the stack if the
446+
// [`ContextStack`] is not empty.
447+
if let Some(Some(next_cx)) = self.stack.pop() {
448+
self.current_cx = next_cx;
449+
}
450+
} else {
451+
// This is an out of order pop.
452+
if pos >= len {
453+
// This is an invalid id, ignore it.
454+
return;
455+
}
456+
// Clear out the entry at the given id.
457+
_ = self.stack[pos as usize].take();
458+
}
459+
}
460+
461+
#[inline(always)]
462+
fn map_current_cx<T>(&self, f: impl FnOnce(&Context) -> T) -> T {
463+
f(&self.current_cx)
464+
}
465+
}
466+
467+
impl Default for ContextStack {
468+
fn default() -> Self {
469+
ContextStack {
470+
current_cx: Context::default(),
471+
stack: Vec::with_capacity(ContextStack::INITIAL_CAPACITY),
472+
_marker: PhantomData,
473+
}
474+
}
475+
}
476+
384477
#[cfg(test)]
385478
mod tests {
386479
use super::*;
387480

481+
#[derive(Debug, PartialEq)]
482+
struct ValueA(&'static str);
483+
#[derive(Debug, PartialEq)]
484+
struct ValueB(u64);
485+
388486
#[test]
389487
fn context_immutable() {
390488
#[derive(Debug, PartialEq)]
@@ -424,10 +522,6 @@ mod tests {
424522

425523
#[test]
426524
fn nested_contexts() {
427-
#[derive(Debug, PartialEq)]
428-
struct ValueA(&'static str);
429-
#[derive(Debug, PartialEq)]
430-
struct ValueB(u64);
431525
let _outer_guard = Context::new().with_value(ValueA("a")).attach();
432526

433527
// Only value `a` is set
@@ -462,13 +556,7 @@ mod tests {
462556
}
463557

464558
#[test]
465-
#[ignore = "overlapping contexts are not supported yet"]
466559
fn overlapping_contexts() {
467-
#[derive(Debug, PartialEq)]
468-
struct ValueA(&'static str);
469-
#[derive(Debug, PartialEq)]
470-
struct ValueB(u64);
471-
472560
let outer_guard = Context::new().with_value(ValueA("a")).attach();
473561

474562
// Only value `a` is set
@@ -502,4 +590,60 @@ mod tests {
502590
assert_eq!(current.get::<ValueA>(), None);
503591
assert_eq!(current.get::<ValueB>(), None);
504592
}
593+
594+
#[test]
595+
fn too_many_contexts() {
596+
let mut guards: Vec<ContextGuard> = Vec::with_capacity(ContextStack::MAX_POS as usize);
597+
let stack_max_pos = ContextStack::MAX_POS as u64;
598+
// Fill the stack up until the last position
599+
for i in 1..stack_max_pos {
600+
let cx_guard = Context::current().with_value(ValueB(i)).attach();
601+
assert_eq!(Context::current().get(), Some(&ValueB(i)));
602+
assert_eq!(cx_guard.cx_pos, i as u16);
603+
guards.push(cx_guard);
604+
}
605+
// Let's overflow the stack a couple of times
606+
for _ in 0..16 {
607+
let cx_guard = Context::current().with_value(ValueA("overflow")).attach();
608+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
609+
assert_eq!(Context::current().get::<ValueA>(), None);
610+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
611+
guards.push(cx_guard);
612+
}
613+
// Drop the overflow contexts
614+
for _ in 0..16 {
615+
guards.pop();
616+
assert_eq!(Context::current().get::<ValueA>(), None);
617+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
618+
}
619+
// Drop one more so we can add a new one
620+
guards.pop();
621+
assert_eq!(Context::current().get::<ValueA>(), None);
622+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
623+
// Push a new context and see that it works
624+
let cx_guard = Context::current().with_value(ValueA("last")).attach();
625+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS - 1);
626+
assert_eq!(Context::current().get(), Some(&ValueA("last")));
627+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
628+
guards.push(cx_guard);
629+
// Let's overflow the stack a couple of times again
630+
for _ in 0..16 {
631+
let cx_guard = Context::current().with_value(ValueA("overflow")).attach();
632+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
633+
assert_eq!(Context::current().get(), Some(&ValueA("last")));
634+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
635+
guards.push(cx_guard);
636+
}
637+
}
638+
639+
#[test]
640+
fn context_stack_pop_id() {
641+
// This is to get full line coverage of the `pop_id` function.
642+
// In real life the `Drop`` implementation of `ContextGuard` ensures that
643+
// the ids are valid and inside the bounds.
644+
let mut stack = ContextStack::default();
645+
stack.pop_id(ContextStack::BASE_POS);
646+
stack.pop_id(ContextStack::MAX_POS);
647+
stack.pop_id(4711);
648+
}
505649
}

0 commit comments

Comments
 (0)