Skip to content

Commit 9f34552

Browse files
committed
Add task-local context propagation
This commit introduces a new task-local context propagation mechanism to gRPC, enabling context to be preserved across await points in asynchronous tasks. The implementation is based on a combination of: * **Otel FutureExt**: Adopts the extension trait pattern from OpenTelemetry's [[FutureExt](cci:2://file:///usr/local/google/home/sauravz/repos/tonic/grpc/src/context/extensions.rs:34:0-44:1)](https://docs.rs/opentelemetry/latest/opentelemetry/future/trait.FutureExt.html) to allow fluent context attachment via `.with_context(ctx)`. * **Tokio Task Local**: Implements a runtime-agnostic task-local storage mechanism similar to [`tokio::task_local!`](https://docs.rs/tokio/latest/tokio/macro.task_local.html), ensuring context is correctly scoped and propagated. Changes: * Add `task_local_context` module for managing context scope. * Add FutureExt and StreamExt traits in `extensions.rs`. * Update `Context` to use `Arc<dyn Context>` for efficient sharing.
1 parent 72f734f commit 9f34552

File tree

5 files changed

+505
-0
lines changed

5 files changed

+505
-0
lines changed

grpc/src/context.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod context;
2+
mod extensions;
3+
mod task_local_context;
4+
5+
pub use context::Context;
6+
pub use extensions::{FutureExt, StreamExt};
7+
pub use task_local_context::current;

grpc/src/context/context.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use std::any::{Any, TypeId};
2+
use std::collections::HashMap;
3+
use std::sync::Arc;
4+
use std::time::Instant;
5+
6+
/// A task-local context for propagating metadata, deadlines, and other request-scoped values.
7+
pub trait Context: Send + Sync + 'static {
8+
/// Get the deadline for the current context.
9+
fn deadline(&self) -> Option<Instant>;
10+
11+
/// Create a new context with the given deadline.
12+
fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context>;
13+
14+
/// Get a value from the context extensions.
15+
fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)>;
16+
17+
/// Create a new context with the given value.
18+
fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context>;
19+
}
20+
21+
#[derive(Clone, Default)]
22+
struct ContextInner {
23+
deadline: Option<Instant>,
24+
extensions: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
25+
}
26+
27+
#[derive(Clone, Default)]
28+
pub(crate) struct ContextImpl {
29+
inner: Arc<ContextInner>,
30+
}
31+
32+
impl Context for ContextImpl {
33+
fn deadline(&self) -> Option<Instant> {
34+
self.inner.deadline
35+
}
36+
37+
fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context> {
38+
let mut inner = (*self.inner).clone();
39+
inner.deadline = Some(deadline);
40+
Arc::new(Self {
41+
inner: Arc::new(inner),
42+
})
43+
}
44+
45+
fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)> {
46+
self.inner.extensions.get(&type_id).map(|v| &**v as _)
47+
}
48+
49+
fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context> {
50+
let mut inner = (*self.inner).clone();
51+
inner.extensions.insert(type_id, value);
52+
Arc::new(Self {
53+
inner: Arc::new(inner),
54+
})
55+
}
56+
}
57+
58+
#[cfg(test)]
59+
mod tests {
60+
use super::*;
61+
use std::time::Duration;
62+
63+
#[test]
64+
fn default_context_has_no_deadline_or_extensions() {
65+
let context = ContextImpl::default();
66+
assert!(context.deadline().is_none());
67+
assert!(context.get(TypeId::of::<i32>()).is_none());
68+
}
69+
70+
#[test]
71+
fn with_deadline_sets_deadline_and_preserves_original() {
72+
let context = ContextImpl::default();
73+
let deadline = Instant::now() + Duration::from_secs(5);
74+
let context_with_deadline = context.with_deadline(deadline);
75+
76+
assert_eq!(context_with_deadline.deadline(), Some(deadline));
77+
// Original context should remain unchanged
78+
assert!(context.deadline().is_none());
79+
}
80+
81+
#[test]
82+
fn with_value_stores_extension_and_preserves_original() {
83+
let context = ContextImpl::default();
84+
85+
#[derive(Debug, PartialEq)]
86+
struct MyValue(i32);
87+
88+
let context_with_value = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(42)));
89+
90+
let value = context_with_value
91+
.get(TypeId::of::<MyValue>())
92+
.and_then(|v| v.downcast_ref::<MyValue>());
93+
assert_eq!(value, Some(&MyValue(42)));
94+
95+
// Original context should not have the value
96+
assert!(context.get(TypeId::of::<MyValue>()).is_none());
97+
}
98+
99+
#[test]
100+
fn with_value_overwrites_existing_extension_and_preserves_previous() {
101+
let context = ContextImpl::default();
102+
103+
#[derive(Debug, PartialEq)]
104+
struct MyValue(i32);
105+
106+
let ctx1 = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(10)));
107+
let ctx2 = ctx1.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(20)));
108+
109+
let val1 = ctx1
110+
.get(TypeId::of::<MyValue>())
111+
.and_then(|v| v.downcast_ref::<MyValue>());
112+
let val2 = ctx2
113+
.get(TypeId::of::<MyValue>())
114+
.and_then(|v| v.downcast_ref::<MyValue>());
115+
116+
assert_eq!(val1, Some(&MyValue(10)));
117+
assert_eq!(val2, Some(&MyValue(20)));
118+
}
119+
}

grpc/src/context/extensions.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! Extension traits for `Future` and `Stream` to provide context propagation.
2+
//!
3+
//! This module provides the [`FutureExt`] and [`StreamExt`] traits, which allow
4+
//! attaching a [`Context`] to a [`Future`] or [`Stream`]. This ensures that the
5+
//! context is set as the current task-local context whenever the future or stream
6+
//! is polled.
7+
8+
use std::future::Future;
9+
use std::sync::Arc;
10+
use tokio_stream::Stream;
11+
12+
use super::context::Context;
13+
use super::task_local_context;
14+
15+
/// Extension trait for `Future` to provide context propagation.
16+
///
17+
/// This trait allows attaching a [`Context`] to a [`Future`], ensuring that the context
18+
/// is set as the current task-local context whenever the future is polled.
19+
///
20+
/// # Examples
21+
///
22+
/// ```rust
23+
/// # use std::sync::Arc;
24+
/// # use grpc::context::{Context, FutureExt};
25+
/// # async fn example() {
26+
/// let context = grpc::context::current();
27+
/// let future = async {
28+
/// // Context is available here
29+
/// assert!(grpc::context::current().deadline().is_none());
30+
/// };
31+
///
32+
/// future.with_context(context).await;
33+
/// # }
34+
/// ```
35+
pub trait FutureExt: Future {
36+
/// Attach a context to this future.
37+
///
38+
/// The context will be set as the current task-local context whenever the future is polled.
39+
fn with_context(self, context: Arc<dyn Context>) -> impl Future<Output = Self::Output>
40+
where
41+
Self: Sized,
42+
{
43+
task_local_context::ContextScope::new(self, context)
44+
}
45+
}
46+
47+
impl<F: Future> FutureExt for F {}
48+
49+
/// Extension trait for `Stream` to provide context propagation.
50+
///
51+
/// This trait allows attaching a [`Context`] to a [`Stream`], ensuring that the context
52+
/// is set as the current task-local context whenever the stream is polled.
53+
///
54+
/// # Examples
55+
///
56+
/// ```rust
57+
/// # use std::sync::Arc;
58+
/// # use grpc::context::{Context, StreamExt};
59+
/// # use tokio_stream::StreamExt as _;
60+
/// # async fn example() {
61+
/// let context = grpc::context::current();
62+
/// let stream = tokio_stream::iter(vec![1, 2, 3]);
63+
///
64+
/// let mut scoped_stream = stream.with_context(context);
65+
///
66+
/// while let Some(item) = scoped_stream.next().await {
67+
/// // Context is available here
68+
/// assert!(grpc::context::current().deadline().is_none());
69+
/// }
70+
/// # }
71+
/// ```
72+
pub trait StreamExt: Stream {
73+
/// Attach a context to this stream.
74+
///
75+
/// The context will be set as the current task-local context whenever the stream is polled.
76+
fn with_context(self, context: Arc<dyn Context>) -> impl Stream<Item = Self::Item>
77+
where
78+
Self: Sized,
79+
{
80+
task_local_context::ContextScope::new(self, context)
81+
}
82+
}
83+
84+
impl<S: Stream> StreamExt for S {}
85+
86+
#[cfg(test)]
87+
mod tests {
88+
use super::super::context::ContextImpl;
89+
use super::*;
90+
use tokio_stream::StreamExt as _;
91+
92+
#[tokio::test]
93+
async fn test_future_ext_attaches_context_correctly() {
94+
let ctx = ContextImpl::default();
95+
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
96+
let ctx = ctx.with_deadline(deadline);
97+
98+
let future = async {
99+
let current_ctx = super::task_local_context::current();
100+
assert_eq!(current_ctx.deadline(), Some(deadline));
101+
};
102+
103+
future.with_context(ctx).await;
104+
}
105+
106+
#[tokio::test]
107+
async fn test_stream_ext_attaches_context_correctly() {
108+
let ctx = ContextImpl::default();
109+
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
110+
let ctx = ctx.with_deadline(deadline);
111+
112+
let stream = async_stream::stream! {
113+
let current_ctx = super::task_local_context::current();
114+
assert_eq!(current_ctx.deadline(), Some(deadline));
115+
yield 1;
116+
};
117+
118+
let scoped_stream = stream.with_context(ctx);
119+
tokio::pin!(scoped_stream);
120+
scoped_stream.next().await;
121+
}
122+
}

0 commit comments

Comments
 (0)