@@ -5,13 +5,14 @@ use std::{
55 task:: { Context , Poll } ,
66} ;
77
8+ use arc_swap:: ArcSwap ;
89use bytes:: Bytes ;
910use futures:: { Future , FutureExt , SinkExt , StreamExt } ;
1011use msg_common:: span:: { EnterSpan as _, WithSpan } ;
1112use tokio_util:: codec:: Framed ;
1213use tracing:: Instrument ;
1314
14- use crate :: { ConnOptions , ConnectionState , ExponentialBackoff } ;
15+ use crate :: { ClientOptions , ConnectionState , ExponentialBackoff } ;
1516
1617use msg_transport:: { Address , MeteredIo , Transport } ;
1718use msg_wire:: { auth, reqrep} ;
@@ -33,10 +34,227 @@ pub(crate) type Conn<Io, S, A> = Framed<MeteredIo<Io, S, A>, reqrep::Codec>;
3334/// A connection controller that manages the connection to a server with an exponential backoff.
3435pub ( crate ) type ConnCtl < Io , S , A > = ConnectionState < Conn < Io , S , A > , ExponentialBackoff , A > ;
3536
37+ /// A connection manager for managing client OR server connections.
38+ /// The type parameter `S` contains the connection state, including its "side" (client / server).
39+ pub ( crate ) struct ConnectionManager < T , A , S >
40+ where
41+ T : Transport < A > ,
42+ A : Address ,
43+ {
44+ /// The connection state, including its "side" (client / server).
45+ state : S ,
46+ /// The transport used for the connection.
47+ transport : T ,
48+ /// Transport stats for metering IO.
49+ transport_stats : Arc < ArcSwap < T :: Stats > > ,
50+
51+ /// Connection manager tracing span.
52+ span : tracing:: Span ,
53+ }
54+
55+ /// A client connection to a remote server.
56+ pub ( crate ) struct ClientConnection < T , A >
57+ where
58+ T : Transport < A > ,
59+ A : Address ,
60+ {
61+ /// Options for the connection manager.
62+ options : ClientOptions ,
63+ /// The address of the remote.
64+ addr : A ,
65+ /// The connection task which handles the connection to the server.
66+ conn_task : Option < WithSpan < ConnTask < T :: Io , T :: Error > > > ,
67+ /// The transport controller, wrapped in a [`ConnectionState`] for backoff.
68+ /// The [`Framed`] object can send and receive messages from the socket.
69+ conn_ctl : ConnCtl < T :: Io , T :: Stats , A > ,
70+ }
71+
72+ impl < T , A > ClientConnection < T , A >
73+ where
74+ T : Transport < A > ,
75+ A : Address ,
76+ {
77+ /// Reset the connection state to inactive, so that it will be re-tried.
78+ ///
79+ /// This is done when the connection is closed or an error occurs.
80+ #[ inline]
81+ pub ( crate ) fn reset_connection ( & mut self ) {
82+ self . conn_ctl = ConnectionState :: Inactive {
83+ addr : self . addr . clone ( ) ,
84+ backoff : ExponentialBackoff :: from ( & self . options ) ,
85+ } ;
86+ }
87+
88+ /// Returns a mutable reference to the connection channel if it is active.
89+ #[ inline]
90+ pub fn active_connection ( & mut self ) -> Option < & mut Conn < T :: Io , T :: Stats , A > > {
91+ if let ConnectionState :: Active { ref mut channel } = self . conn_ctl {
92+ Some ( channel)
93+ } else {
94+ None
95+ }
96+ }
97+ }
98+
99+ /// A local server connection. Manages the connection lifecycle:
100+ /// - Accepting incoming connections.
101+ /// - Handling established connections.
102+ pub ( crate ) struct ServerConnection < T , A >
103+ where
104+ T : Transport < A > ,
105+ A : Address ,
106+ {
107+ /// The local address.
108+ addr : A ,
109+ /// The accept task which handles accepting an incoming connection.
110+ accept_task : Option < WithSpan < T :: Accept > > ,
111+ /// The inbound connection.
112+ conn : Conn < T :: Io , T :: Stats , A > ,
113+ }
114+
115+ impl < T , A > ConnectionManager < T , A , ClientConnection < T , A > >
116+ where
117+ T : Transport < A > ,
118+ A : Address ,
119+ {
120+ pub ( crate ) fn new (
121+ options : ClientOptions ,
122+ transport : T ,
123+ addr : A ,
124+ conn_ctl : ConnCtl < T :: Io , T :: Stats , A > ,
125+ transport_stats : Arc < ArcSwap < T :: Stats > > ,
126+ span : tracing:: Span ,
127+ ) -> Self {
128+ let conn = ClientConnection { options, addr, conn_task : None , conn_ctl } ;
129+
130+ Self { state : conn, transport, transport_stats, span }
131+ }
132+
133+ /// Start the connection task to the server, handling authentication if necessary.
134+ /// The result will be polled by the driver and re-tried according to the backoff policy.
135+ fn try_connect ( & mut self ) {
136+ let connect = self . transport . connect ( self . state . addr . clone ( ) ) ;
137+ let token = self . state . options . auth_token . clone ( ) ;
138+
139+ let task = async move {
140+ let io = connect. await ?;
141+
142+ let Some ( token) = token else {
143+ return Ok ( io) ;
144+ } ;
145+
146+ authentication_handshake :: < T , A > ( io, token) . await
147+ }
148+ . in_current_span ( ) ;
149+
150+ // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`]
151+ self . state . conn_task = Some ( WithSpan :: current ( Box :: pin ( task) ) ) ;
152+ }
153+
154+ /// Reset the connection state to inactive, so that it will be re-tried.
155+ ///
156+ /// This is done when the connection is closed or an error occurs.
157+ #[ inline]
158+ pub ( crate ) fn reset_connection ( & mut self ) {
159+ self . state . reset_connection ( ) ;
160+ }
161+
162+ /// Poll connection management logic: connection task, backoff, and retry logic.
163+ /// Loops until the connection is active, then returns a mutable reference to the channel.
164+ ///
165+ /// Note: this is not a `Future` impl because we want to return a reference; doing it in
166+ /// a `Future` would require lifetime headaches or unsafe code.
167+ ///
168+ /// Returns:
169+ /// * `Poll::Ready(Some(&mut channel))` if the connection is active
170+ /// * `Poll::Ready(None)` if we should terminate (max retries exceeded)
171+ /// * `Poll::Pending` if we need to wait for backoff
172+ #[ allow( clippy:: type_complexity) ]
173+ pub ( crate ) fn poll (
174+ & mut self ,
175+ cx : & mut Context < ' _ > ,
176+ ) -> Poll < Option < & mut Conn < T :: Io , T :: Stats , A > > > {
177+ loop {
178+ // Poll the active connection task, if any
179+ if let Some ( ref mut conn_task) = self . state . conn_task {
180+ if let Poll :: Ready ( result) = conn_task. poll_unpin ( cx) . enter ( ) {
181+ // As soon as the connection task finishes, set it to `None`.
182+ // - If it was successful, set the connection to active
183+ // - If it failed, it will be re-tried until the backoff limit is reached.
184+ self . state . conn_task = None ;
185+
186+ match result. inner {
187+ Ok ( io) => {
188+ tracing:: info!( "connected" ) ;
189+
190+ let metered = MeteredIo :: new ( io, self . transport_stats . clone ( ) ) ;
191+ let framed = Framed :: new ( metered, reqrep:: Codec :: new ( ) ) ;
192+ self . state . conn_ctl = ConnectionState :: Active { channel : framed } ;
193+ }
194+ Err ( e) => {
195+ tracing:: error!( ?e, "failed to connect" ) ;
196+ }
197+ }
198+ }
199+ }
200+
201+ // If the connection is inactive, try to connect to the server or poll the backoff
202+ // timer if we're already trying to connect.
203+ if let ConnectionState :: Inactive { backoff, .. } = & mut self . state . conn_ctl {
204+ let Poll :: Ready ( item) = backoff. poll_next_unpin ( cx) else {
205+ return Poll :: Pending ;
206+ } ;
207+
208+ let _span = tracing:: info_span!( parent: & self . span, "connect" ) . entered ( ) ;
209+
210+ if let Some ( duration) = item {
211+ if self . state . conn_task . is_none ( ) {
212+ tracing:: debug!( backoff = ?duration, "trying connection" ) ;
213+ self . try_connect ( ) ;
214+ } else {
215+ tracing:: debug!(
216+ backoff = ?duration,
217+ "not retrying as there is already a connection task"
218+ ) ;
219+ }
220+ } else {
221+ tracing:: error!( "exceeded maximum number of retries, terminating connection" ) ;
222+ return Poll :: Ready ( None ) ;
223+ }
224+ }
225+
226+ if let ConnectionState :: Active { ref mut channel } = self . state . conn_ctl {
227+ return Poll :: Ready ( Some ( channel) ) ;
228+ }
229+ }
230+ }
231+ }
232+
/// Options for a server-side connection manager.
///
/// There are currently no configurable settings; the type is a placeholder so that the server
/// constructor has a stable signature.
#[derive(Debug, Clone, Default)]
pub struct ServerOptions {}
234+
235+ impl < T , A > ConnectionManager < T , A , ServerConnection < T , A > >
236+ where
237+ T : Transport < A > ,
238+ A : Address ,
239+ {
240+ pub ( crate ) fn new (
241+ options : ServerOptions ,
242+ transport : T ,
243+ addr : A ,
244+ conn : Conn < T :: Io , T :: Stats , A > ,
245+ transport_stats : Arc < ArcSwap < T :: Stats > > ,
246+ span : tracing:: Span ,
247+ ) -> Self {
248+ let conn = ServerConnection { addr, accept_task : None , conn } ;
249+
250+ Self { state : conn, transport, transport_stats, span }
251+ }
252+ }
253+
36254/// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection.
37255pub ( crate ) struct ConnManager < T : Transport < A > , A : Address > {
38256 /// Options for the connection manager.
39- options : ConnOptions ,
257+ options : ClientOptions ,
40258 /// The connection task which handles the connection to the server.
41259 conn_task : Option < WithSpan < ConnTask < T :: Io , T :: Error > > > ,
42260 /// The transport controller, wrapped in a [`ConnectionState`] for backoff.
89307 A : Address ,
90308{
91309 pub ( crate ) fn new (
92- options : ConnOptions ,
310+ options : ClientOptions ,
93311 transport : T ,
94312 addr : A ,
95313 conn_ctl : ConnCtl < T :: Io , T :: Stats , A > ,
0 commit comments