Skip to content

Commit d195f49

Browse files
committed
feat(socket): initial scaffold for client/server ConnManager
1 parent 4827ec2 commit d195f49

File tree

5 files changed

+243
-11
lines changed

5 files changed

+243
-11
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ unused_rounding = "warn"
173173
use_self = "warn"
174174
useless_let_if_seq = "warn"
175175
zero_sized_map_values = "warn"
176+
default_trait_access = "warn"
176177

177178
# These are nursery lints which have findings. Allow them for now. Some are not
178179
# quite mature enough for use in our codebase and some we don't really want.

msg-socket/src/connection/backoff.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
};
77
use tokio::time::sleep;
88

9-
use crate::ConnOptions;
9+
use crate::ClientOptions;
1010

1111
/// Helper trait alias for backoff streams.
1212
/// We define any stream that yields `Duration`s as a backoff
@@ -41,8 +41,8 @@ impl ExponentialBackoff {
4141
}
4242
}
4343

44-
impl From<&ConnOptions> for ExponentialBackoff {
45-
fn from(options: &ConnOptions) -> Self {
44+
impl From<&ClientOptions> for ExponentialBackoff {
45+
fn from(options: &ClientOptions) -> Self {
4646
Self::new(options.backoff_duration, options.retry_attempts)
4747
}
4848
}

msg-socket/src/req/conn_manager.rs

Lines changed: 221 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ use std::{
55
task::{Context, Poll},
66
};
77

8+
use arc_swap::ArcSwap;
89
use bytes::Bytes;
910
use futures::{Future, FutureExt, SinkExt, StreamExt};
1011
use msg_common::span::{EnterSpan as _, WithSpan};
1112
use tokio_util::codec::Framed;
1213
use tracing::Instrument;
1314

14-
use crate::{ConnOptions, ConnectionState, ExponentialBackoff};
15+
use crate::{ClientOptions, ConnectionState, ExponentialBackoff};
1516

1617
use msg_transport::{Address, MeteredIo, Transport};
1718
use msg_wire::{auth, reqrep};
@@ -33,10 +34,227 @@ pub(crate) type Conn<Io, S, A> = Framed<MeteredIo<Io, S, A>, reqrep::Codec>;
3334
/// A connection controller that manages the connection to a server with an exponential backoff.
3435
pub(crate) type ConnCtl<Io, S, A> = ConnectionState<Conn<Io, S, A>, ExponentialBackoff, A>;
3536

37+
/// A connection manager for managing client OR server connections.
38+
/// The type parameter `S` contains the connection state, including its "side" (client / server).
39+
pub(crate) struct ConnectionManager<T, A, S>
40+
where
41+
T: Transport<A>,
42+
A: Address,
43+
{
44+
/// The connection state, including its "side" (client / server).
45+
state: S,
46+
/// The transport used for the connection.
47+
transport: T,
48+
/// Transport stats for metering IO.
49+
transport_stats: Arc<ArcSwap<T::Stats>>,
50+
51+
/// Connection manager tracing span.
52+
span: tracing::Span,
53+
}
54+
55+
/// A client connection to a remote server.
56+
pub(crate) struct ClientConnection<T, A>
57+
where
58+
T: Transport<A>,
59+
A: Address,
60+
{
61+
/// Options for the connection manager.
62+
options: ClientOptions,
63+
/// The address of the remote.
64+
addr: A,
65+
/// The connection task which handles the connection to the server.
66+
conn_task: Option<WithSpan<ConnTask<T::Io, T::Error>>>,
67+
/// The transport controller, wrapped in a [`ConnectionState`] for backoff.
68+
/// The [`Framed`] object can send and receive messages from the socket.
69+
conn_ctl: ConnCtl<T::Io, T::Stats, A>,
70+
}
71+
72+
impl<T, A> ClientConnection<T, A>
73+
where
74+
T: Transport<A>,
75+
A: Address,
76+
{
77+
/// Reset the connection state to inactive, so that it will be re-tried.
78+
///
79+
/// This is done when the connection is closed or an error occurs.
80+
#[inline]
81+
pub(crate) fn reset_connection(&mut self) {
82+
self.conn_ctl = ConnectionState::Inactive {
83+
addr: self.addr.clone(),
84+
backoff: ExponentialBackoff::from(&self.options),
85+
};
86+
}
87+
88+
/// Returns a mutable reference to the connection channel if it is active.
89+
#[inline]
90+
pub fn active_connection(&mut self) -> Option<&mut Conn<T::Io, T::Stats, A>> {
91+
if let ConnectionState::Active { ref mut channel } = self.conn_ctl {
92+
Some(channel)
93+
} else {
94+
None
95+
}
96+
}
97+
}
98+
99+
/// A local server connection. Manages the connection lifecycle:
100+
/// - Accepting incoming connections.
101+
/// - Handling established connections.
102+
pub(crate) struct ServerConnection<T, A>
103+
where
104+
T: Transport<A>,
105+
A: Address,
106+
{
107+
/// The local address.
108+
addr: A,
109+
/// The accept task which handles accepting an incoming connection.
110+
accept_task: Option<WithSpan<T::Accept>>,
111+
/// The inbound connection.
112+
conn: Conn<T::Io, T::Stats, A>,
113+
}
114+
115+
impl<T, A> ConnectionManager<T, A, ClientConnection<T, A>>
116+
where
117+
T: Transport<A>,
118+
A: Address,
119+
{
120+
pub(crate) fn new(
121+
options: ClientOptions,
122+
transport: T,
123+
addr: A,
124+
conn_ctl: ConnCtl<T::Io, T::Stats, A>,
125+
transport_stats: Arc<ArcSwap<T::Stats>>,
126+
span: tracing::Span,
127+
) -> Self {
128+
let conn = ClientConnection { options, addr, conn_task: None, conn_ctl };
129+
130+
Self { state: conn, transport, transport_stats, span }
131+
}
132+
133+
/// Start the connection task to the server, handling authentication if necessary.
134+
/// The result will be polled by the driver and re-tried according to the backoff policy.
135+
fn try_connect(&mut self) {
136+
let connect = self.transport.connect(self.state.addr.clone());
137+
let token = self.state.options.auth_token.clone();
138+
139+
let task = async move {
140+
let io = connect.await?;
141+
142+
let Some(token) = token else {
143+
return Ok(io);
144+
};
145+
146+
authentication_handshake::<T, A>(io, token).await
147+
}
148+
.in_current_span();
149+
150+
// FIX: coercion to BoxFuture for [`SpanExt::with_current_span`]
151+
self.state.conn_task = Some(WithSpan::current(Box::pin(task)));
152+
}
153+
154+
/// Reset the connection state to inactive, so that it will be re-tried.
155+
///
156+
/// This is done when the connection is closed or an error occurs.
157+
#[inline]
158+
pub(crate) fn reset_connection(&mut self) {
159+
self.state.reset_connection();
160+
}
161+
162+
/// Poll connection management logic: connection task, backoff, and retry logic.
163+
/// Loops until the connection is active, then returns a mutable reference to the channel.
164+
///
165+
/// Note: this is not a `Future` impl because we want to return a reference; doing it in
166+
/// a `Future` would require lifetime headaches or unsafe code.
167+
///
168+
/// Returns:
169+
/// * `Poll::Ready(Some(&mut channel))` if the connection is active
170+
/// * `Poll::Ready(None)` if we should terminate (max retries exceeded)
171+
/// * `Poll::Pending` if we need to wait for backoff
172+
#[allow(clippy::type_complexity)]
173+
pub(crate) fn poll(
174+
&mut self,
175+
cx: &mut Context<'_>,
176+
) -> Poll<Option<&mut Conn<T::Io, T::Stats, A>>> {
177+
loop {
178+
// Poll the active connection task, if any
179+
if let Some(ref mut conn_task) = self.state.conn_task {
180+
if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() {
181+
// As soon as the connection task finishes, set it to `None`.
182+
// - If it was successful, set the connection to active
183+
// - If it failed, it will be re-tried until the backoff limit is reached.
184+
self.state.conn_task = None;
185+
186+
match result.inner {
187+
Ok(io) => {
188+
tracing::info!("connected");
189+
190+
let metered = MeteredIo::new(io, self.transport_stats.clone());
191+
let framed = Framed::new(metered, reqrep::Codec::new());
192+
self.state.conn_ctl = ConnectionState::Active { channel: framed };
193+
}
194+
Err(e) => {
195+
tracing::error!(?e, "failed to connect");
196+
}
197+
}
198+
}
199+
}
200+
201+
// If the connection is inactive, try to connect to the server or poll the backoff
202+
// timer if we're already trying to connect.
203+
if let ConnectionState::Inactive { backoff, .. } = &mut self.state.conn_ctl {
204+
let Poll::Ready(item) = backoff.poll_next_unpin(cx) else {
205+
return Poll::Pending;
206+
};
207+
208+
let _span = tracing::info_span!(parent: &self.span, "connect").entered();
209+
210+
if let Some(duration) = item {
211+
if self.state.conn_task.is_none() {
212+
tracing::debug!(backoff = ?duration, "trying connection");
213+
self.try_connect();
214+
} else {
215+
tracing::debug!(
216+
backoff = ?duration,
217+
"not retrying as there is already a connection task"
218+
);
219+
}
220+
} else {
221+
tracing::error!("exceeded maximum number of retries, terminating connection");
222+
return Poll::Ready(None);
223+
}
224+
}
225+
226+
if let ConnectionState::Active { ref mut channel } = self.state.conn_ctl {
227+
return Poll::Ready(Some(channel));
228+
}
229+
}
230+
}
231+
}
232+
233+
pub struct ServerOptions {}
234+
235+
impl<T, A> ConnectionManager<T, A, ServerConnection<T, A>>
236+
where
237+
T: Transport<A>,
238+
A: Address,
239+
{
240+
pub(crate) fn new(
241+
options: ServerOptions,
242+
transport: T,
243+
addr: A,
244+
conn: Conn<T::Io, T::Stats, A>,
245+
transport_stats: Arc<ArcSwap<T::Stats>>,
246+
span: tracing::Span,
247+
) -> Self {
248+
let conn = ServerConnection { addr, accept_task: None, conn };
249+
250+
Self { state: conn, transport, transport_stats, span }
251+
}
252+
}
253+
36254
/// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection.
37255
pub(crate) struct ConnManager<T: Transport<A>, A: Address> {
38256
/// Options for the connection manager.
39-
options: ConnOptions,
257+
options: ClientOptions,
40258
/// The connection task which handles the connection to the server.
41259
conn_task: Option<WithSpan<ConnTask<T::Io, T::Error>>>,
42260
/// The transport controller, wrapped in a [`ConnectionState`] for backoff.
@@ -89,7 +307,7 @@ where
89307
A: Address,
90308
{
91309
pub(crate) fn new(
92-
options: ConnOptions,
310+
options: ClientOptions,
93311
transport: T,
94312
addr: A,
95313
conn_ctl: ConnCtl<T::Io, T::Stats, A>,

msg-socket/src/req/driver.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,22 @@ use msg_wire::{
2727
reqrep,
2828
};
2929

30+
/// Type state for a client connection.
31+
struct ClientConnection<T, A>
32+
where
33+
T: Transport<A>,
34+
A: Address,
35+
{
36+
conn_manager: ConnManager<T, A>,
37+
}
38+
3039
/// The request socket driver. Endless future that drives
3140
/// the socket forward.
32-
pub(crate) struct ReqDriver<T: Transport<A>, A: Address> {
41+
pub(crate) struct ReqDriver<T, A>
42+
where
43+
T: Transport<A>,
44+
A: Address,
45+
{
3346
/// Options shared with the socket.
3447
pub(crate) options: Arc<ReqOptions>,
3548
/// State shared with the socket.

msg-socket/src/req/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl SendCommand {
6868

6969
/// Options for the connection manager.
7070
#[derive(Debug, Clone)]
71-
pub struct ConnOptions {
71+
pub struct ClientOptions {
7272
/// Optional authentication token.
7373
pub auth_token: Option<Bytes>,
7474
/// The backoff duration for the underlying transport on reconnections.
@@ -77,7 +77,7 @@ pub struct ConnOptions {
7777
pub retry_attempts: Option<usize>,
7878
}
7979

80-
impl Default for ConnOptions {
80+
impl Default for ClientOptions {
8181
fn default() -> Self {
8282
Self {
8383
auth_token: None,
@@ -96,7 +96,7 @@ impl Default for ConnOptions {
9696
#[derive(Debug, Clone)]
9797
pub struct ReqOptions {
9898
/// Options for the connection manager.
99-
pub conn: ConnOptions,
99+
pub conn: ClientOptions,
100100
/// Timeout duration for requests.
101101
pub timeout: Duration,
102102
/// Wether to block on initial connection to the target.
@@ -211,7 +211,7 @@ impl ReqOptions {
211211
impl Default for ReqOptions {
212212
fn default() -> Self {
213213
Self {
214-
conn: ConnOptions::default(),
214+
conn: ClientOptions::default(),
215215
timeout: Duration::from_secs(5),
216216
blocking_connect: false,
217217
min_compress_size: 8192,

0 commit comments

Comments
 (0)