Skip to content

Commit 6e74279

Browse files
committed
refactor(auth): connection hooks
feat(req,rep,pub,sub): support custom connection hooks
1 parent 4827ec2 commit 6e74279

File tree

20 files changed

+558
-406
lines changed

20 files changed

+558
-406
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

msg-socket/src/hooks/mod.rs

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//! Connection hooks for customizing connection setup.
2+
//!
3+
//! The [`ConnectionHook`] trait provides a way to intercept connections during setup,
4+
//! enabling custom authentication, handshakes, or protocol negotiations.
5+
//!
6+
//! # Built-in Hooks
7+
//!
8+
//! The [`token`] module provides ready-to-use token-based authentication hooks:
9+
//! - [`token::ServerHook`] - Server-side hook that validates client tokens
10+
//! - [`token::ClientHook`] - Client-side hook that sends a token to the server
11+
//!
12+
//! # Custom Hooks
13+
//!
14+
//! Implement [`ConnectionHook`] for custom authentication or protocol negotiation:
15+
//!
16+
//! ```rust,ignore
17+
//! use msg_socket::ConnectionHook;
18+
//! use std::io;
19+
//! use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
20+
//!
21+
//! struct MyAuth;
22+
//!
23+
//! impl<Io> ConnectionHook<Io> for MyAuth
24+
//! where
25+
//! Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
26+
//! {
27+
//! async fn on_connection(&self, mut io: Io) -> Result<Io, HookError> {
28+
//! let mut buf = [0u8; 32];
29+
//! io.read_exact(&mut buf).await?;
30+
//! if &buf == b"expected_token_value_32_bytes!!!" {
31+
//! io.write_all(b"OK").await?;
32+
//! Ok(io)
33+
//! } else {
34+
//! Err(HookError::custom("invalid token"))
35+
//! }
36+
//! }
37+
//! }
38+
//! ```
39+
40+
use std::{error::Error as StdError, fmt, future::Future, io, pin::Pin, sync::Arc};
41+
42+
use tokio::io::{AsyncRead, AsyncWrite};
43+
44+
pub mod token;
45+
46+
/// Error type for connection hooks.
47+
///
48+
/// This enum provides two variants:
49+
/// - `Io` for standard I/O errors
50+
/// - `Custom` for hook-specific errors (type-erased)
51+
#[derive(Debug)]
52+
pub enum HookError {
53+
/// An I/O error occurred.
54+
Io(io::Error),
55+
/// A custom hook-specific error.
56+
Custom(Box<dyn StdError + Send + Sync + 'static>),
57+
}
58+
59+
impl HookError {
60+
/// Creates a custom error from any error type.
61+
pub fn custom<E>(error: E) -> Self
62+
where
63+
E: StdError + Send + Sync + 'static,
64+
{
65+
Self::Custom(Box::new(error))
66+
}
67+
68+
/// Creates a custom error from a string message.
69+
pub fn message(msg: impl Into<String>) -> Self {
70+
Self::Io(io::Error::other(msg.into()))
71+
}
72+
}
73+
74+
impl fmt::Display for HookError {
75+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76+
match self {
77+
Self::Io(e) => write!(f, "IO error: {e}"),
78+
Self::Custom(e) => write!(f, "Hook error: {e}"),
79+
}
80+
}
81+
}
82+
83+
impl StdError for HookError {
84+
fn source(&self) -> Option<&(dyn StdError + 'static)> {
85+
match self {
86+
Self::Io(e) => Some(e),
87+
Self::Custom(e) => Some(e.as_ref()),
88+
}
89+
}
90+
}
91+
92+
impl From<io::Error> for HookError {
93+
fn from(error: io::Error) -> Self {
94+
Self::Io(error)
95+
}
96+
}
97+
98+
/// Hook executed during connection setup.
99+
///
100+
/// For server sockets (Rep, Pub): called when a connection is accepted.
101+
/// For client sockets (Req, Sub): called after connecting.
102+
///
103+
/// The hook receives the raw IO stream and has full control over the handshake protocol.
104+
pub trait ConnectionHook<Io>: Send + Sync + 'static
105+
where
106+
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
107+
{
108+
/// Called when a connection is established.
109+
///
110+
/// # Arguments
111+
/// * `io` - The raw IO stream for this connection
112+
///
113+
/// # Returns
114+
/// The IO stream on success (potentially wrapped/transformed), or an error to reject
115+
/// the connection.
116+
fn on_connection(&self, io: Io) -> impl Future<Output = Result<Io, HookError>> + Send;
117+
}
118+
119+
// ============================================================================
120+
// Type-erased hook for internal use
121+
// ============================================================================
122+
123+
/// Type-erased connection hook for internal use.
124+
///
125+
/// This trait allows storing hooks with different concrete types behind a single
126+
/// `Arc<dyn ConnectionHookErased<Io>>`.
127+
pub(crate) trait ConnectionHookErased<Io>: Send + Sync + 'static
128+
where
129+
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
130+
{
131+
fn on_connection(
132+
self: Arc<Self>,
133+
io: Io,
134+
) -> Pin<Box<dyn Future<Output = Result<Io, HookError>> + Send + 'static>>;
135+
}
136+
137+
impl<T, Io> ConnectionHookErased<Io> for T
138+
where
139+
T: ConnectionHook<Io>,
140+
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
141+
{
142+
fn on_connection(
143+
self: Arc<Self>,
144+
io: Io,
145+
) -> Pin<Box<dyn Future<Output = Result<Io, HookError>> + Send + 'static>> {
146+
Box::pin(async move { ConnectionHook::on_connection(&*self, io).await })
147+
}
148+
}
149+
150+
// ============================================================================
151+
// Hook result type for driver tasks
152+
// ============================================================================
153+
154+
/// The result of running a connection hook.
155+
///
156+
/// Contains the processed IO stream and associated address.
157+
pub(crate) struct HookResult<Io, A> {
158+
pub(crate) stream: Io,
159+
pub(crate) addr: A,
160+
}

msg-socket/src/hooks/token.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//! Token-based authentication hooks.
2+
//!
3+
//! This module provides ready-to-use connection hooks for simple token-based authentication:
4+
//!
5+
//! - [`ServerHook`] - Server-side hook that validates client tokens
6+
//! - [`ClientHook`] - Client-side hook that sends a token to the server
7+
//!
8+
//! # Example
9+
//!
10+
//! ```rust,ignore
11+
//! use msg_socket::{RepSocket, ReqSocket, tcp::Tcp};
12+
//! use msg_socket::hooks::token::{ServerHook, ClientHook};
13+
//! use bytes::Bytes;
14+
//!
15+
//! // Server side - validates incoming tokens
16+
//! let rep = RepSocket::new(Tcp::default())
17+
//! .with_connection_hook(ServerHook::new(|token| {
18+
//! // Custom validation logic
19+
//! token == b"secret"
20+
//! }));
21+
//!
22+
//! // Client side - sends token on connect
23+
//! let req = ReqSocket::new(Tcp::default())
24+
//! .with_connection_hook(ClientHook::new(Bytes::from("secret")));
25+
//! ```
26+
27+
use bytes::Bytes;
28+
use futures::SinkExt;
29+
use tokio::io::{AsyncRead, AsyncWrite};
30+
use tokio_stream::StreamExt;
31+
use tokio_util::codec::Framed;
32+
33+
use crate::hooks::{ConnectionHook, HookError};
34+
use msg_wire::auth;
35+
36+
/// Server-side authentication hook that validates incoming client tokens.
37+
///
38+
/// When a client connects, this hook:
39+
/// 1. Waits for the client to send an auth token
40+
/// 2. Validates the token using the provided validator function
41+
/// 3. Sends an ACK on success, or rejects the connection on failure
42+
///
43+
/// # Example
44+
///
45+
/// ```rust,ignore
46+
/// use msg_socket::hooks::token::ServerHook;
47+
///
48+
/// // Accept all tokens
49+
/// let hook = ServerHook::accept_all();
50+
///
51+
/// // Custom validation
52+
/// let hook = ServerHook::new(|token| token == b"my_secret_token");
53+
/// ```
54+
pub struct ServerHook<F> {
55+
validator: F,
56+
}
57+
58+
impl ServerHook<fn(&Bytes) -> bool> {
59+
/// Creates a server hook that accepts all tokens.
60+
pub fn accept_all() -> Self {
61+
Self { validator: |_| true }
62+
}
63+
}
64+
65+
impl<F> ServerHook<F>
66+
where
67+
F: Fn(&Bytes) -> bool + Send + Sync + 'static,
68+
{
69+
/// Creates a new server hook with the given validator function.
70+
///
71+
/// The validator receives the client's token and returns `true` to accept
72+
/// the connection or `false` to reject it.
73+
pub fn new(validator: F) -> Self {
74+
Self { validator }
75+
}
76+
}
77+
78+
impl<Io, F> ConnectionHook<Io> for ServerHook<F>
79+
where
80+
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
81+
F: Fn(&Bytes) -> bool + Send + Sync + 'static,
82+
{
83+
async fn on_connection(&self, io: Io) -> Result<Io, HookError> {
84+
let mut conn = Framed::new(io, auth::Codec::new_server());
85+
86+
// Wait for client authentication message
87+
let msg = conn
88+
.next()
89+
.await
90+
.ok_or_else(|| HookError::message("connection closed"))?
91+
.map_err(HookError::custom)?;
92+
93+
let auth::Message::Auth(token) = msg else {
94+
return Err(HookError::message("expected auth message"));
95+
};
96+
97+
// Validate the token
98+
if !(self.validator)(&token) {
99+
conn.send(auth::Message::Reject).await?;
100+
return Err(HookError::message("authentication rejected"));
101+
}
102+
103+
// Send acknowledgment
104+
conn.send(auth::Message::Ack).await?;
105+
106+
Ok(conn.into_inner())
107+
}
108+
}
109+
110+
/// Client-side authentication hook that sends a token to the server.
111+
///
112+
/// When connecting to a server, this hook:
113+
/// 1. Sends the configured token to the server
114+
/// 2. Waits for the server's ACK response
115+
/// 3. Returns an error if the server rejects the token
116+
///
117+
/// # Example
118+
///
119+
/// ```rust,ignore
120+
/// use msg_socket::hooks::token::ClientHook;
121+
/// use bytes::Bytes;
122+
///
123+
/// let hook = ClientHook::new(Bytes::from("my_secret_token"));
124+
/// ```
125+
pub struct ClientHook {
126+
token: Bytes,
127+
}
128+
129+
impl ClientHook {
130+
/// Creates a new client hook with the given authentication token.
131+
pub fn new(token: Bytes) -> Self {
132+
Self { token }
133+
}
134+
}
135+
136+
impl<Io> ConnectionHook<Io> for ClientHook
137+
where
138+
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
139+
{
140+
async fn on_connection(&self, io: Io) -> Result<Io, HookError> {
141+
let mut conn = Framed::new(io, auth::Codec::new_client());
142+
143+
// Send authentication token
144+
conn.send(auth::Message::Auth(self.token.clone())).await?;
145+
146+
conn.flush().await?;
147+
148+
// Wait for server acknowledgment
149+
let ack = conn
150+
.next()
151+
.await
152+
.ok_or_else(|| HookError::message("connection closed"))?
153+
.map_err(HookError::custom)?;
154+
155+
if !matches!(ack, auth::Message::Ack) {
156+
return Err(HookError::message("authentication denied"));
157+
}
158+
159+
Ok(conn.into_inner())
160+
}
161+
}

msg-socket/src/lib.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")]
1111
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
1212
#![cfg_attr(not(test), warn(unused_crate_dependencies))]
13-
use bytes::Bytes;
14-
use tokio::io::{AsyncRead, AsyncWrite};
15-
16-
use msg_transport::Address;
1713

1814
pub mod stats;
1915

16+
pub mod hooks;
17+
pub use hooks::{ConnectionHook, HookError};
18+
19+
// Re-export type-erased hook and HookResult for internal use
20+
pub(crate) use hooks::{ConnectionHookErased, HookResult};
21+
2022
#[path = "pub/mod.rs"]
2123
mod pubs;
2224
pub use pubs::{PubError, PubOptions, PubSocket};
@@ -53,18 +55,6 @@ impl RequestId {
5355
}
5456
}
5557

56-
/// An interface for authenticating clients, given their ID.
57-
pub trait Authenticator: Send + Sync + Unpin + 'static {
58-
fn authenticate(&self, id: &Bytes) -> bool;
59-
}
60-
61-
/// The result of an authentication attempt.
62-
pub(crate) struct AuthResult<S: AsyncRead + AsyncWrite, A: Address> {
63-
id: Bytes,
64-
addr: A,
65-
stream: S,
66-
}
67-
6858
/// The performance profile to tune socket options for.
6959
#[derive(Debug, Clone, Default, Copy, PartialEq, Eq)]
7060
pub enum Profile {

0 commit comments

Comments
 (0)