Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions grpc/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
mod extensions;
mod task_local_context;

pub use extensions::{FutureExt, StreamExt};
pub use task_local_context::current;

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

/// A task-local context for propagating metadata, deadlines, and other request-scoped values.
pub trait Context: Send + Sync + 'static {
/// Get the deadline for the current context.
fn deadline(&self) -> Option<Instant>;

/// Create a new context with the given deadline.
fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context>;

/// Get a value from the context extensions.
fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)>;

/// Create a new context with the given value.
fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context>;
}

#[derive(Clone, Default)]
struct ContextInner {
deadline: Option<Instant>,
extensions: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

#[derive(Clone, Default)]
pub(crate) struct ContextImpl {
inner: Arc<ContextInner>,
}

impl Context for ContextImpl {
fn deadline(&self) -> Option<Instant> {
self.inner.deadline
}

fn with_deadline(&self, deadline: Instant) -> Arc<dyn Context> {
let mut inner = (*self.inner).clone();
inner.deadline = Some(deadline);
Arc::new(Self {
inner: Arc::new(inner),
})
}

fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)> {
self.inner.extensions.get(&type_id).map(|v| &**v as _)
}

fn with_value(&self, type_id: TypeId, value: Arc<dyn Any + Send + Sync>) -> Arc<dyn Context> {
let mut inner = (*self.inner).clone();
inner.extensions.insert(type_id, value);
Arc::new(Self {
inner: Arc::new(inner),
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;

#[test]
fn default_context_has_no_deadline_or_extensions() {
let context = ContextImpl::default();
assert!(context.deadline().is_none());
assert!(context.get(TypeId::of::<i32>()).is_none());
}

#[test]
fn with_deadline_sets_deadline_and_preserves_original() {
let context = ContextImpl::default();
let deadline = Instant::now() + Duration::from_secs(5);
let context_with_deadline = context.with_deadline(deadline);

assert_eq!(context_with_deadline.deadline(), Some(deadline));
// Original context should remain unchanged
assert!(context.deadline().is_none());
}

#[test]
fn with_value_stores_extension_and_preserves_original() {
let context = ContextImpl::default();

#[derive(Debug, PartialEq)]
struct MyValue(i32);

let context_with_value = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(42)));

let value = context_with_value
.get(TypeId::of::<MyValue>())
.and_then(|v| v.downcast_ref::<MyValue>());
assert_eq!(value, Some(&MyValue(42)));

// Original context should not have the value
assert!(context.get(TypeId::of::<MyValue>()).is_none());
}

#[test]
fn with_value_overwrites_existing_extension_and_preserves_previous() {
let context = ContextImpl::default();

#[derive(Debug, PartialEq)]
struct MyValue(i32);

let ctx1 = context.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(10)));
let ctx2 = ctx1.with_value(TypeId::of::<MyValue>(), Arc::new(MyValue(20)));

let val1 = ctx1
.get(TypeId::of::<MyValue>())
.and_then(|v| v.downcast_ref::<MyValue>());
let val2 = ctx2
.get(TypeId::of::<MyValue>())
.and_then(|v| v.downcast_ref::<MyValue>());

assert_eq!(val1, Some(&MyValue(10)));
assert_eq!(val2, Some(&MyValue(20)));
}
}
122 changes: 122 additions & 0 deletions grpc/src/context/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
//! Extension traits for `Future` and `Stream` to provide context propagation.
//!
//! This module provides the [`FutureExt`] and [`StreamExt`] traits, which allow
//! attaching a [`Context`] to a [`Future`] or [`Stream`]. This ensures that the
//! context is set as the current task-local context whenever the future or stream
//! is polled.

use std::future::Future;
use std::sync::Arc;
use tokio_stream::Stream;

use super::task_local_context;
use super::Context;

/// Extension trait for `Future` to provide context propagation.
///
/// This trait allows attaching a [`Context`] to a [`Future`], ensuring that the context
/// is set as the current task-local context whenever the future is polled.
///
/// # Examples
///
/// ```rust
/// # use std::sync::Arc;
/// # use grpc::context::{Context, FutureExt};
/// # async fn example() {
/// let context = grpc::context::current();
/// let future = async {
/// // Context is available here
/// assert!(grpc::context::current().deadline().is_none());
/// };
///
/// future.with_context(context).await;
/// # }
/// ```
pub trait FutureExt: Future {
/// Attach a context to this future.
///
/// The context will be set as the current task-local context whenever the future is polled.
fn with_context(self, context: Arc<dyn Context>) -> impl Future<Output = Self::Output>
where
Self: Sized,
{
task_local_context::ContextScope::new(self, context)
}
}

impl<F: Future> FutureExt for F {}

/// Extension trait for `Stream` to provide context propagation.
///
/// This trait allows attaching a [`Context`] to a [`Stream`], ensuring that the context
/// is set as the current task-local context whenever the stream is polled.
///
/// # Examples
///
/// ```rust
/// # use std::sync::Arc;
/// # use grpc::context::{Context, StreamExt};
/// # use tokio_stream::StreamExt as _;
/// # async fn example() {
/// let context = grpc::context::current();
/// let stream = tokio_stream::iter(vec![1, 2, 3]);
///
/// let mut scoped_stream = stream.with_context(context);
///
/// while let Some(item) = scoped_stream.next().await {
/// // Context is available here
/// assert!(grpc::context::current().deadline().is_none());
/// }
/// # }
/// ```
pub trait StreamExt: Stream {
/// Attach a context to this stream.
///
/// The context will be set as the current task-local context whenever the stream is polled.
fn with_context(self, context: Arc<dyn Context>) -> impl Stream<Item = Self::Item>
where
Self: Sized,
{
task_local_context::ContextScope::new(self, context)
}
}

impl<S: Stream> StreamExt for S {}

#[cfg(test)]
mod tests {
use super::super::ContextImpl;
use super::*;
use tokio_stream::StreamExt as _;

#[tokio::test]
async fn test_future_ext_attaches_context_correctly() {
let ctx = ContextImpl::default();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
let ctx = ctx.with_deadline(deadline);

let future = async {
let current_ctx = super::task_local_context::current();
assert_eq!(current_ctx.deadline(), Some(deadline));
};

future.with_context(ctx).await;
}

#[tokio::test]
async fn test_stream_ext_attaches_context_correctly() {
let ctx = ContextImpl::default();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
let ctx = ctx.with_deadline(deadline);

let stream = async_stream::stream! {
let current_ctx = super::task_local_context::current();
assert_eq!(current_ctx.deadline(), Some(deadline));
yield 1;
};

let scoped_stream = stream.with_context(ctx);
tokio::pin!(scoped_stream);
scoped_stream.next().await;
}
}
Loading
Loading