Skip to content

Commit c57c8f1

Browse files
committed
refactor: review
fix: better hook task spans docs: consistent "connection hook" phrasing refactor(sub): split connection and (old auth) hooks task sets fix: ensure connection hook not set when driver is running
1 parent 0941d6a commit c57c8f1

File tree

15 files changed

+307
-196
lines changed

15 files changed

+307
-196
lines changed

CLAUDE.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,42 @@ cargo flamegraph --bin example_name
9090
- **Book** - Comprehensive guide in `book/` directory, deployed to GitHub Pages
9191
- **Examples** - Rich example set covering all major usage patterns in `msg/examples/`
9292

93+
### Doc Code Blocks
94+
- Prefer `no_run` over `rust,ignore` in documentation code blocks whenever possible
95+
- Use `rust,ignore` only when the code cannot compile (e.g., pseudo-code or incomplete examples)
96+
9397
## Important Implementation Notes
9498

9599
### Connection Hooks
96100
- Implement `ConnectionHook` trait for custom authentication, handshakes, or protocol negotiation
97101
- Built-in token-based auth via `hooks::token::ServerHook` and `hooks::token::ClientHook`
98102

103+
### Tracing Pattern for Async Tasks
104+
When spawning async tasks that need tracing context, use the `WithSpan` pattern from `msg_common::span`:
105+
106+
```rust
107+
use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan};
108+
109+
// 1. Create a span with context
110+
let span = tracing::info_span!("connection_hook", ?addr);
111+
112+
// 2. Wrap the future with the span
113+
let fut = async move { /* ... */ };
114+
self.tasks.spawn(fut.with_span(span));
115+
116+
// 3. Poll with .enter() to re-enter the span, access result via .inner
117+
if let Poll::Ready(Some(Ok(result))) = self.tasks.poll_join_next(cx).enter() {
118+
match result.inner {
119+
Ok(value) => { /* ... */ }
120+
Err(e) => { /* ... */ }
121+
}
122+
}
123+
```
124+
125+
- The `JoinSet` type becomes `JoinSet<WithSpan<T>>` instead of `JoinSet<T>`
126+
- Logs inside the span don't need to repeat span fields (e.g., no `?addr` needed)
127+
- Add `#[allow(clippy::type_complexity)]` to structs with complex `WithSpan` types
128+
99129
### Statistics Collection
100130
- All socket types collect latency, throughput, and packet drop metrics
101131
- Access via socket statistics methods for monitoring

msg-socket/src/hooks/mod.rs

Lines changed: 64 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
//! Connection hooks for customizing connection setup.
1+
//! Connection hooks for customizing connection establishment.
22
//!
3-
//! The [`ConnectionHook`] trait provides a way to intercept connections during setup,
4-
//! enabling custom authentication, handshakes, or protocol negotiations.
3+
//! Connection hooks are attached when establishing connections and allow custom
4+
//! authentication, handshakes, or protocol negotiations. The [`ConnectionHook`] trait
5+
//! is called during connection setup, before the connection is used for messaging.
56
//!
67
//! # Built-in Hooks
78
//!
@@ -13,125 +14,121 @@
1314
//!
1415
//! Implement [`ConnectionHook`] for custom authentication or protocol negotiation:
1516
//!
16-
//! ```rust,ignore
17-
//! use msg_socket::ConnectionHook;
18-
//! use std::io;
17+
//! ```no_run
18+
//! use msg_socket::hooks::{ConnectionHook, Error, HookResult};
1919
//! use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
2020
//!
2121
//! struct MyAuth;
2222
//!
23+
//! #[derive(Debug, thiserror::Error)]
24+
//! enum MyAuthError {
25+
//! #[error("invalid token")]
26+
//! InvalidToken,
27+
//! }
28+
//!
2329
//! impl<Io> ConnectionHook<Io> for MyAuth
2430
//! where
2531
//! Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
2632
//! {
27-
//! async fn on_connection(&self, mut io: Io) -> Result<Io, HookError> {
33+
//! type Error = MyAuthError;
34+
//!
35+
//! async fn on_connection(&self, mut io: Io) -> HookResult<Io, Self::Error> {
2836
//! let mut buf = [0u8; 32];
2937
//! io.read_exact(&mut buf).await?;
3038
//! if &buf == b"expected_token_value_32_bytes!!!" {
3139
//! io.write_all(b"OK").await?;
3240
//! Ok(io)
3341
//! } else {
34-
//! Err(HookError::custom("invalid token"))
42+
//! Err(Error::hook(MyAuthError::InvalidToken))
3543
//! }
3644
//! }
3745
//! }
3846
//! ```
47+
//!
48+
//! # Future Extensions
49+
//!
50+
//! TODO: Additional hooks may be added for different parts of the connection lifecycle
51+
//! (e.g., disconnection, reconnection, periodic health checks).
3952
40-
use std::{error::Error as StdError, fmt, future::Future, io, pin::Pin, sync::Arc};
53+
use std::{error::Error as StdError, future::Future, io, pin::Pin, sync::Arc};
4154

4255
use tokio::io::{AsyncRead, AsyncWrite};
4356

4457
pub mod token;
4558

4659
/// Error type for connection hooks.
4760
///
48-
/// This enum provides two variants:
49-
/// - `Io` for standard I/O errors
50-
/// - `Custom` for hook-specific errors (type-erased)
51-
#[derive(Debug)]
52-
pub enum HookError {
61+
/// Distinguishes between I/O errors and hook-specific errors.
62+
#[derive(Debug, thiserror::Error)]
63+
pub enum Error<E> {
5364
/// An I/O error occurred.
54-
Io(io::Error),
55-
/// A custom hook-specific error.
56-
Custom(Box<dyn StdError + Send + Sync + 'static>),
65+
#[error("IO error: {0}")]
66+
Io(#[from] io::Error),
67+
/// A hook-specific error.
68+
#[error("Hook error: {0}")]
69+
Hook(#[source] E),
5770
}
5871

59-
impl HookError {
60-
/// Creates a custom error from any error type.
61-
pub fn custom<E>(error: E) -> Self
62-
where
63-
E: StdError + Send + Sync + 'static,
64-
{
65-
Self::Custom(Box::new(error))
66-
}
67-
68-
/// Creates a custom error from a string message.
69-
pub fn message(msg: impl Into<String>) -> Self {
70-
Self::Io(io::Error::other(msg.into()))
72+
impl<E> Error<E> {
73+
/// Create a hook error from a hook-specific error.
74+
pub fn hook(err: E) -> Self {
75+
Error::Hook(err)
7176
}
7277
}
7378

74-
impl fmt::Display for HookError {
75-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76-
match self {
77-
Self::Io(e) => write!(f, "IO error: {e}"),
78-
Self::Custom(e) => write!(f, "Hook error: {e}"),
79-
}
80-
}
81-
}
82-
83-
impl StdError for HookError {
84-
fn source(&self) -> Option<&(dyn StdError + 'static)> {
85-
match self {
86-
Self::Io(e) => Some(e),
87-
Self::Custom(e) => Some(e.as_ref()),
88-
}
89-
}
90-
}
79+
/// Result type for connection hooks.
80+
///
81+
/// This is intentionally named `HookResult` (not `Result`) to make it clear this is not
82+
/// `std::result::Result`. A `HookResult` can be:
83+
/// - `Ok(io)` - success, returns the IO stream
84+
/// - `Err(Error::Io(..))` - an I/O error occurred
85+
/// - `Err(Error::Hook(..))` - a hook-specific error occurred
86+
pub type HookResult<T, E> = std::result::Result<T, Error<E>>;
9187

92-
impl From<io::Error> for HookError {
93-
fn from(error: io::Error) -> Self {
94-
Self::Io(error)
95-
}
96-
}
88+
/// Type-erased hook result used internally by drivers.
89+
pub(crate) type ErasedHookResult<T> = HookResult<T, Box<dyn StdError + Send + Sync>>;
9790

98-
/// Hook executed during connection setup.
91+
/// Connection hook executed during connection establishment.
9992
///
10093
/// For server sockets (Rep, Pub): called when a connection is accepted.
10194
/// For client sockets (Req, Sub): called after connecting.
10295
///
103-
/// The hook receives the raw IO stream and has full control over the handshake protocol.
96+
/// The connection hook receives the raw IO stream and has full control over the handshake protocol.
10497
pub trait ConnectionHook<Io>: Send + Sync + 'static
10598
where
10699
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
107100
{
101+
/// The hook-specific error type.
102+
type Error: StdError + Send + Sync + 'static;
103+
108104
/// Called when a connection is established.
109105
///
110106
/// # Arguments
111107
/// * `io` - The raw IO stream for this connection
112108
///
113109
/// # Returns
114-
/// The IO stream on success (potentially wrapped/transformed), or an error to reject
115-
/// the connection.
116-
fn on_connection(&self, io: Io) -> impl Future<Output = Result<Io, HookError>> + Send;
110+
/// - `Ok(io)` - The IO stream on success (potentially wrapped/transformed)
111+
/// - `Err(Error::Io(..))` - An I/O error occurred
112+
/// - `Err(Error::Hook(Self::Error))` - A hook-specific error to reject the connection
113+
fn on_connection(&self, io: Io) -> impl Future<Output = HookResult<Io, Self::Error>> + Send;
117114
}
118115

119116
// ============================================================================
120-
// Type-erased hook for internal use
117+
// Type-erased connection hook for internal use
121118
// ============================================================================
122119

123120
/// Type-erased connection hook for internal use.
124121
///
125-
/// This trait allows storing hooks with different concrete types behind a single
126-
/// `Arc<dyn ConnectionHookErased<Io>>`.
122+
/// This trait allows storing connection hooks with different concrete types behind a single
123+
/// `Arc<dyn ConnectionHookErased<Io>>`. The hook error type is erased to `Box<dyn Error>`.
127124
pub(crate) trait ConnectionHookErased<Io>: Send + Sync + 'static
128125
where
129126
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
130127
{
131128
fn on_connection(
132129
self: Arc<Self>,
133130
io: Io,
134-
) -> Pin<Box<dyn Future<Output = Result<Io, HookError>> + Send + 'static>>;
131+
) -> Pin<Box<dyn Future<Output = ErasedHookResult<Io>> + Send + 'static>>;
135132
}
136133

137134
impl<T, Io> ConnectionHookErased<Io> for T
@@ -142,19 +139,12 @@ where
142139
fn on_connection(
143140
self: Arc<Self>,
144141
io: Io,
145-
) -> Pin<Box<dyn Future<Output = Result<Io, HookError>> + Send + 'static>> {
146-
Box::pin(async move { ConnectionHook::on_connection(&*self, io).await })
142+
) -> Pin<Box<dyn Future<Output = ErasedHookResult<Io>> + Send + 'static>> {
143+
Box::pin(async move {
144+
ConnectionHook::on_connection(&*self, io).await.map_err(|e| match e {
145+
Error::Io(io_err) => Error::Io(io_err),
146+
Error::Hook(hook_err) => Error::Hook(Box::new(hook_err) as Box<dyn StdError + Send + Sync>),
147+
})
148+
})
147149
}
148150
}
149-
150-
// ============================================================================
151-
// Hook result type for driver tasks
152-
// ============================================================================
153-
154-
/// The result of running a connection hook.
155-
///
156-
/// Contains the processed IO stream and associated address.
157-
pub(crate) struct HookResult<Io, A> {
158-
pub(crate) stream: Io,
159-
pub(crate) addr: A,
160-
}

0 commit comments

Comments
 (0)