Skip to content

Commit 2f78b43

Browse files
committed
Add task-local context propagation
This commit introduces a new task-local context propagation mechanism to gRPC, enabling context to be preserved across await points in asynchronous tasks. The implementation is based on a combination of: * **Otel FutureExt**: Adopts the extension trait pattern from OpenTelemetry's [[FutureExt](cci:2://file:///usr/local/google/home/sauravz/repos/tonic/grpc/src/context/extensions.rs:34:0-44:1)](https://docs.rs/opentelemetry/latest/opentelemetry/future/trait.FutureExt.html) to allow fluent context attachment via `.with_context(ctx)`. * **Tokio Task Local**: Implements a runtime-agnostic task-local storage mechanism similar to [`tokio::task_local!`](https://docs.rs/tokio/latest/tokio/macro.task_local.html), ensuring context is correctly scoped and propagated. Changes: * Add `task_local_context` module for managing context scope. * Add FutureExt and StreamExt traits in `extensions.rs`. * Update `Context` to use `Arc<dyn Context>` for efficient sharing.
1 parent 72f734f commit 2f78b43

File tree

4 files changed

+504
-0
lines changed

4 files changed

+504
-0
lines changed

grpc/src/context.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
mod extensions;
2+
mod task_local_context;
3+
4+
pub use extensions::{FutureExt, StreamExt};
5+
pub use task_local_context::current;
6+
7+
use std::any::{Any, TypeId};
8+
use std::collections::HashMap;
9+
use std::sync::Arc;
10+
use std::time::Instant;
11+
12+
/// A task-local context for propagating metadata, deadlines, and other request-scoped values.
13+
pub trait Context: Send + Sync + 'static {
14+
/// Get the deadline for the current context.
15+
fn deadline(&self) -> Option<Instant>;
16+
17+
/// Create a new context with the given deadline.
18+
fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context>;
19+
20+
/// Get a value from the context extensions.
21+
fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)>;
22+
23+
/// Create a new context with the given value.
24+
fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context>;
25+
}
26+
27+
#[derive(Clone, Default)]
28+
struct ContextInner {
29+
deadline: Option<Instant>,
30+
extensions: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
31+
}
32+
33+
#[derive(Clone, Default)]
34+
pub(crate) struct ContextImpl {
35+
inner: Arc<ContextInner>,
36+
}
37+
38+
impl Context for ContextImpl {
39+
fn deadline(&self) -> Option<Instant> {
40+
self.inner.deadline
41+
}
42+
43+
fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context> {
44+
let mut inner = (*self.inner).clone();
45+
inner.deadline = Some(deadline);
46+
Arc::new(Self {
47+
inner: Arc::new(inner),
48+
})
49+
}
50+
51+
fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)> {
52+
self.inner.extensions.get(&type_id).map(|v| &**v as _)
53+
}
54+
55+
fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context> {
56+
let mut inner = (*self.inner).clone();
57+
inner.extensions.insert(type_id, value);
58+
Arc::new(Self {
59+
inner: Arc::new(inner),
60+
})
61+
}
62+
}
63+
64+
#[cfg(test)]
65+
mod tests {
66+
use super::*;
67+
use std::time::Duration;
68+
69+
#[test]
70+
fn default_context_has_no_deadline_or_extensions() {
71+
let context = ContextImpl::default();
72+
assert!(context.deadline().is_none());
73+
assert!(context.get(TypeId::of::<i32>()).is_none());
74+
}
75+
76+
#[test]
77+
fn with_deadline_sets_deadline_and_preserves_original() {
78+
let context = ContextImpl::default();
79+
let deadline = Instant::now() + Duration::from_secs(5);
80+
let context_with_deadline = context.with_deadline(deadline);
81+
82+
assert_eq!(context_with_deadline.deadline(), Some(deadline));
83+
// Original context should remain unchanged
84+
assert!(context.deadline().is_none());
85+
}
86+
87+
#[test]
88+
fn with_value_stores_extension_and_preserves_original() {
89+
let context = ContextImpl::default();
90+
91+
#[derive(Debug, PartialEq)]
92+
struct MyValue(i32);
93+
94+
let context_with_value = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(42)));
95+
96+
let value = context_with_value
97+
.get(TypeId::of::<MyValue>())
98+
.and_then(|v| v.downcast_ref::<MyValue>());
99+
assert_eq!(value, Some(&MyValue(42)));
100+
101+
// Original context should not have the value
102+
assert!(context.get(TypeId::of::<MyValue>()).is_none());
103+
}
104+
105+
#[test]
106+
fn with_value_overwrites_existing_extension_and_preserves_previous() {
107+
let context = ContextImpl::default();
108+
109+
#[derive(Debug, PartialEq)]
110+
struct MyValue(i32);
111+
112+
let ctx1 = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(10)));
113+
let ctx2 = ctx1.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(20)));
114+
115+
let val1 = ctx1
116+
.get(TypeId::of::<MyValue>())
117+
.and_then(|v| v.downcast_ref::<MyValue>());
118+
let val2 = ctx2
119+
.get(TypeId::of::<MyValue>())
120+
.and_then(|v| v.downcast_ref::<MyValue>());
121+
122+
assert_eq!(val1, Some(&MyValue(10)));
123+
assert_eq!(val2, Some(&MyValue(20)));
124+
}
125+
}

grpc/src/context/extensions.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! Extension traits for `Future` and `Stream` to provide context propagation.
2+
//!
3+
//! This module provides the [`FutureExt`] and [`StreamExt`] traits, which allow
4+
//! attaching a [`Context`] to a [`Future`] or [`Stream`]. This ensures that the
5+
//! context is set as the current task-local context whenever the future or stream
6+
//! is polled.
7+
8+
use std::future::Future;
9+
use std::sync::Arc;
10+
use tokio_stream::Stream;
11+
12+
use super::task_local_context;
13+
use super::Context;
14+
15+
/// Extension trait for `Future` to provide context propagation.
16+
///
17+
/// This trait allows attaching a [`Context`] to a [`Future`], ensuring that the context
18+
/// is set as the current task-local context whenever the future is polled.
19+
///
20+
/// # Examples
21+
///
22+
/// ```rust
23+
/// # use std::sync::Arc;
24+
/// # use grpc::context::{Context, FutureExt};
25+
/// # async fn example() {
26+
/// let context = grpc::context::current();
27+
/// let future = async {
28+
/// // Context is available here
29+
/// assert!(grpc::context::current().deadline().is_none());
30+
/// };
31+
///
32+
/// future.with_context(context).await;
33+
/// # }
34+
/// ```
35+
pub trait FutureExt: Future {
36+
/// Attach a context to this future.
37+
///
38+
/// The context will be set as the current task-local context whenever the future is polled.
39+
fn with_context(self, context: Arc<dyn Context>) -> impl Future<Output = Self::Output>
40+
where
41+
Self: Sized,
42+
{
43+
task_local_context::ContextScope::new(self, context)
44+
}
45+
}
46+
47+
impl<F: Future> FutureExt for F {}
48+
49+
/// Extension trait for `Stream` to provide context propagation.
50+
///
51+
/// This trait allows attaching a [`Context`] to a [`Stream`], ensuring that the context
52+
/// is set as the current task-local context whenever the stream is polled.
53+
///
54+
/// # Examples
55+
///
56+
/// ```rust
57+
/// # use std::sync::Arc;
58+
/// # use grpc::context::{Context, StreamExt};
59+
/// # use tokio_stream::StreamExt as _;
60+
/// # async fn example() {
61+
/// let context = grpc::context::current();
62+
/// let stream = tokio_stream::iter(vec![1, 2, 3]);
63+
///
64+
/// let mut scoped_stream = stream.with_context(context);
65+
///
66+
/// while let Some(item) = scoped_stream.next().await {
67+
/// // Context is available here
68+
/// assert!(grpc::context::current().deadline().is_none());
69+
/// }
70+
/// # }
71+
/// ```
72+
pub trait StreamExt: Stream {
73+
/// Attach a context to this stream.
74+
///
75+
/// The context will be set as the current task-local context whenever the stream is polled.
76+
fn with_context(self, context: Arc<dyn Context>) -> impl Stream<Item = Self::Item>
77+
where
78+
Self: Sized,
79+
{
80+
task_local_context::ContextScope::new(self, context)
81+
}
82+
}
83+
84+
impl<S: Stream> StreamExt for S {}
85+
86+
#[cfg(test)]
87+
mod tests {
88+
use super::super::ContextImpl;
89+
use super::*;
90+
use tokio_stream::StreamExt as _;
91+
92+
#[tokio::test]
93+
async fn test_future_ext_attaches_context_correctly() {
94+
let ctx = ContextImpl::default();
95+
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
96+
let ctx = ctx.with_deadline(deadline);
97+
98+
let future = async {
99+
let current_ctx = super::task_local_context::current();
100+
assert_eq!(current_ctx.deadline(), Some(deadline));
101+
};
102+
103+
future.with_context(ctx).await;
104+
}
105+
106+
#[tokio::test]
107+
async fn test_stream_ext_attaches_context_correctly() {
108+
let ctx = ContextImpl::default();
109+
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
110+
let ctx = ctx.with_deadline(deadline);
111+
112+
let stream = async_stream::stream! {
113+
let current_ctx = super::task_local_context::current();
114+
assert_eq!(current_ctx.deadline(), Some(deadline));
115+
yield 1;
116+
};
117+
118+
let scoped_stream = stream.with_context(ctx);
119+
tokio::pin!(scoped_stream);
120+
scoped_stream.next().await;
121+
}
122+
}

0 commit comments

Comments
 (0)