1- //! Connection hooks for customizing connection setup .
1+ //! Connection hooks for customizing connection establishment .
22//!
3- //! The [`ConnectionHook`] trait provides a way to intercept connections during setup,
4- //! enabling custom authentication, handshakes, or protocol negotiations.
3+ //! Connection hooks are attached when establishing connections and allow custom
4+ //! authentication, handshakes, or protocol negotiations. The [`ConnectionHook`] trait
5+ //! is called during connection setup, before the connection is used for messaging.
56//!
67//! # Built-in Hooks
78//!
1314//!
1415//! Implement [`ConnectionHook`] for custom authentication or protocol negotiation:
1516//!
16- //! ```rust,ignore
17- //! use msg_socket::ConnectionHook;
18- //! use std::io;
17+ //! ```no_run
18+ //! use msg_socket::hooks::{ConnectionHook, Error, HookResult};
1919//! use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
2020//!
2121//! struct MyAuth;
2222//!
23+ //! #[derive(Debug, thiserror::Error)]
24+ //! enum MyAuthError {
25+ //! #[error("invalid token")]
26+ //! InvalidToken,
27+ //! }
28+ //!
2329//! impl<Io> ConnectionHook<Io> for MyAuth
2430//! where
2531//! Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
2632//! {
27- //! async fn on_connection(&self, mut io: Io) -> Result<Io, HookError> {
33+ //! type Error = MyAuthError;
34+ //!
35+ //! async fn on_connection(&self, mut io: Io) -> HookResult<Io, Self::Error> {
2836//! let mut buf = [0u8; 32];
2937//! io.read_exact(&mut buf).await?;
3038//! if &buf == b"expected_token_value_32_bytes!!!" {
3139//! io.write_all(b"OK").await?;
3240//! Ok(io)
3341//! } else {
34- //! Err(HookError::custom("invalid token" ))
42+ //! Err(Error::hook(MyAuthError::InvalidToken ))
3543//! }
3644//! }
3745//! }
3846//! ```
47+ //!
48+ //! # Future Extensions
49+ //!
50+ //! TODO: Additional hooks may be added for different parts of the connection lifecycle
51+ //! (e.g., disconnection, reconnection, periodic health checks).
3952
40- use std:: { error:: Error as StdError , fmt , future:: Future , io, pin:: Pin , sync:: Arc } ;
53+ use std:: { error:: Error as StdError , future:: Future , io, pin:: Pin , sync:: Arc } ;
4154
4255use tokio:: io:: { AsyncRead , AsyncWrite } ;
4356
4457pub mod token;
4558
4659/// Error type for connection hooks.
4760///
48- /// This enum provides two variants:
49- /// - `Io` for standard I/O errors
50- /// - `Custom` for hook-specific errors (type-erased)
51- #[ derive( Debug ) ]
52- pub enum HookError {
61+ /// Distinguishes between I/O errors and hook-specific errors.
62+ #[ derive( Debug , thiserror:: Error ) ]
63+ pub enum Error < E > {
5364 /// An I/O error occurred.
54- Io ( io:: Error ) ,
55- /// A custom hook-specific error.
56- Custom ( Box < dyn StdError + Send + Sync + ' static > ) ,
65+ #[ error( "IO error: {0}" ) ]
66+ Io ( #[ from] io:: Error ) ,
67+ /// A hook-specific error.
68+ #[ error( "Hook error: {0}" ) ]
69+ Hook ( #[ source] E ) ,
5770}
5871
59- impl HookError {
60- /// Creates a custom error from any error type.
61- pub fn custom < E > ( error : E ) -> Self
62- where
63- E : StdError + Send + Sync + ' static ,
64- {
65- Self :: Custom ( Box :: new ( error) )
66- }
67-
68- /// Creates a custom error from a string message.
69- pub fn message ( msg : impl Into < String > ) -> Self {
70- Self :: Io ( io:: Error :: other ( msg. into ( ) ) )
72+ impl < E > Error < E > {
73+ /// Create a hook error from a hook-specific error.
74+ pub fn hook ( err : E ) -> Self {
75+ Error :: Hook ( err)
7176 }
7277}
7378
74- impl fmt:: Display for HookError {
75- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
76- match self {
77- Self :: Io ( e) => write ! ( f, "IO error: {e}" ) ,
78- Self :: Custom ( e) => write ! ( f, "Hook error: {e}" ) ,
79- }
80- }
81- }
82-
83- impl StdError for HookError {
84- fn source ( & self ) -> Option < & ( dyn StdError + ' static ) > {
85- match self {
86- Self :: Io ( e) => Some ( e) ,
87- Self :: Custom ( e) => Some ( e. as_ref ( ) ) ,
88- }
89- }
90- }
79+ /// Result type for connection hooks.
80+ ///
81+ /// This is intentionally named `HookResult` (not `Result`) to make it clear this is not
82+ /// `std::result::Result`. A `HookResult` can be:
83+ /// - `Ok(io)` - success, returns the IO stream
84+ /// - `Err(Error::Io(..))` - an I/O error occurred
85+ /// - `Err(Error::Hook(..))` - a hook-specific error occurred
86+ pub type HookResult < T , E > = std:: result:: Result < T , Error < E > > ;
9187
92- impl From < io:: Error > for HookError {
93- fn from ( error : io:: Error ) -> Self {
94- Self :: Io ( error)
95- }
96- }
88+ /// Type-erased hook result used internally by drivers.
89+ pub ( crate ) type ErasedHookResult < T > = HookResult < T , Box < dyn StdError + Send + Sync > > ;
9790
98- /// Hook executed during connection setup .
91+ /// Connection hook executed during connection establishment .
9992///
10093/// For server sockets (Rep, Pub): called when a connection is accepted.
10194/// For client sockets (Req, Sub): called after connecting.
10295///
103- /// The hook receives the raw IO stream and has full control over the handshake protocol.
96+ /// The connection hook receives the raw IO stream and has full control over the handshake protocol.
10497pub trait ConnectionHook < Io > : Send + Sync + ' static
10598where
10699 Io : AsyncRead + AsyncWrite + Send + Unpin + ' static ,
107100{
101+ /// The hook-specific error type.
102+ type Error : StdError + Send + Sync + ' static ;
103+
108104 /// Called when a connection is established.
109105 ///
110106 /// # Arguments
111107 /// * `io` - The raw IO stream for this connection
112108 ///
113109 /// # Returns
114- /// The IO stream on success (potentially wrapped/transformed), or an error to reject
115- /// the connection.
116- fn on_connection ( & self , io : Io ) -> impl Future < Output = Result < Io , HookError > > + Send ;
110+ /// - `Ok(io)` - The IO stream on success (potentially wrapped/transformed)
111+ /// - `Err(Error::Io(..))` - An I/O error occurred
112+ /// - `Err(Error::Hook(Self::Error))` - A hook-specific error to reject the connection
113+ fn on_connection ( & self , io : Io ) -> impl Future < Output = HookResult < Io , Self :: Error > > + Send ;
117114}
118115
119116// ============================================================================
120- // Type-erased hook for internal use
117+ // Type-erased connection hook for internal use
121118// ============================================================================
122119
123120/// Type-erased connection hook for internal use.
124121///
125- /// This trait allows storing hooks with different concrete types behind a single
126- /// `Arc<dyn ConnectionHookErased<Io>>`.
122+ /// This trait allows storing connection hooks with different concrete types behind a single
123+ /// `Arc<dyn ConnectionHookErased<Io>>`. The hook error type is erased to `Box<dyn Error>`.
127124pub ( crate ) trait ConnectionHookErased < Io > : Send + Sync + ' static
128125where
129126 Io : AsyncRead + AsyncWrite + Send + Unpin + ' static ,
130127{
131128 fn on_connection (
132129 self : Arc < Self > ,
133130 io : Io ,
134- ) -> Pin < Box < dyn Future < Output = Result < Io , HookError > > + Send + ' static > > ;
131+ ) -> Pin < Box < dyn Future < Output = ErasedHookResult < Io > > + Send + ' static > > ;
135132}
136133
137134impl < T , Io > ConnectionHookErased < Io > for T
@@ -142,19 +139,12 @@ where
142139 fn on_connection (
143140 self : Arc < Self > ,
144141 io : Io ,
145- ) -> Pin < Box < dyn Future < Output = Result < Io , HookError > > + Send + ' static > > {
146- Box :: pin ( async move { ConnectionHook :: on_connection ( & * self , io) . await } )
142+ ) -> Pin < Box < dyn Future < Output = ErasedHookResult < Io > > + Send + ' static > > {
143+ Box :: pin ( async move {
144+ ConnectionHook :: on_connection ( & * self , io) . await . map_err ( |e| match e {
145+ Error :: Io ( io_err) => Error :: Io ( io_err) ,
146+ Error :: Hook ( hook_err) => Error :: Hook ( Box :: new ( hook_err) as Box < dyn StdError + Send + Sync > ) ,
147+ } )
148+ } )
147149 }
148150}
149-
150- // ============================================================================
151- // Hook result type for driver tasks
152- // ============================================================================
153-
154- /// The result of running a connection hook.
155- ///
156- /// Contains the processed IO stream and associated address.
157- pub ( crate ) struct HookResult < Io , A > {
158- pub ( crate ) stream : Io ,
159- pub ( crate ) addr : A ,
160- }
0 commit comments