1- use std:: { future:: Future , net:: SocketAddr , sync:: Arc } ;
2-
31use protosocket:: Connection ;
2+ use socket2:: TcpKeepalive ;
3+ use std:: { future:: Future , net:: SocketAddr , sync:: Arc } ;
44use tokio:: { net:: TcpStream , sync:: mpsc} ;
55use tokio_rustls:: rustls:: pki_types:: ServerName ;
66
@@ -172,6 +172,7 @@ pub struct Configuration<TStreamConnector> {
172172 max_buffer_length : usize ,
173173 buffer_allocation_increment : usize ,
174174 max_queued_outbound_messages : usize ,
175+ tcp_keepalive_duration : Option < std:: time:: Duration > ,
175176 stream_connector : TStreamConnector ,
176177}
177178
@@ -185,6 +186,7 @@ where
185186 max_buffer_length : 4 * ( 1 << 20 ) , // 4 MiB
186187 buffer_allocation_increment : 1 << 20 ,
187188 max_queued_outbound_messages : 256 ,
189+ tcp_keepalive_duration : None ,
188190 stream_connector,
189191 }
190192 }
@@ -209,6 +211,13 @@ where
209211 pub fn buffer_allocation_increment ( & mut self , buffer_allocation_increment : usize ) {
210212 self . buffer_allocation_increment = buffer_allocation_increment;
211213 }
214+
215+ /// The duration to set for tcp_keepalive on the underlying socket.
216+ ///
217+ /// Default: None
218+ pub fn tcp_keepalive_duration ( & mut self , tcp_keepalive_duration : Option < std:: time:: Duration > ) {
219+ self . tcp_keepalive_duration = tcp_keepalive_duration;
220+ }
212221}
213222
214223/// Connect a new protosocket rpc client to a server
@@ -233,8 +242,25 @@ where
233242{
234243 log:: trace!( "new client {address}, {configuration:?}" ) ;
235244
236- let stream = tokio:: net:: TcpStream :: connect ( address) . await ?;
237- stream. set_nodelay ( true ) ?;
245+ let socket = socket2:: Socket :: new (
246+ match address {
247+ SocketAddr :: V4 ( _) => socket2:: Domain :: IPV4 ,
248+ SocketAddr :: V6 ( _) => socket2:: Domain :: IPV6 ,
249+ } ,
250+ socket2:: Type :: STREAM ,
251+ None ,
252+ ) ?;
253+
254+ let mut tcp_keepalive = TcpKeepalive :: new ( ) ;
255+ if let Some ( duration) = configuration. tcp_keepalive_duration {
256+ tcp_keepalive = tcp_keepalive. with_time ( duration) ;
257+ }
258+
259+ socket. set_nonblocking ( true ) ?;
260+ socket. set_tcp_nodelay ( true ) ?;
261+ socket. set_tcp_keepalive ( & tcp_keepalive) ?;
262+
263+ let stream = TcpStream :: from_std ( socket. into ( ) ) ?;
238264
239265 let message_reactor: RpcCompletionReactor <
240266 Deserializer :: Message ,
0 commit comments