@@ -3,6 +3,7 @@ use std::io;
33use std:: path:: Path ;
44use std:: pin:: Pin ;
55use std:: task:: { ready, Context , Poll } ;
6+ use std:: time:: Duration ;
67
78pub use buffered:: { BufferedSocket , WriteBuffer } ;
89use bytes:: BufMut ;
@@ -12,6 +13,25 @@ use crate::io::ReadBuf;
1213
1314mod buffered;
1415
16+ /// Configuration for TCP keepalive probes on a connection.
17+ ///
18+ /// All fields default to `None`, meaning the OS default is used.
19+ /// Constructing a `KeepaliveConfig::default()` and passing it enables keepalive
20+ /// with OS defaults for all parameters.
21+ #[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
22+ pub struct KeepaliveConfig {
23+ /// Time the connection must be idle before keepalive probes begin.
24+ /// `None` means the OS default.
25+ pub idle : Option < Duration > ,
26+ /// Interval between keepalive probes.
27+ /// `None` means the OS default.
28+ pub interval : Option < Duration > ,
29+ /// Maximum number of failed probes before the connection is dropped.
30+ /// Only supported on Unix; ignored on other platforms.
31+ /// `None` means the OS default.
32+ pub retries : Option < u32 > ,
33+ }
34+
1535pub trait Socket : Send + Sync + Unpin + ' static {
1636 fn try_read ( & mut self , buf : & mut dyn ReadBuf ) -> io:: Result < usize > ;
1737
@@ -181,23 +201,63 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
181201 }
182202}
183203
204+ #[ cfg( any( feature = "_rt-tokio" , feature = "_rt-async-io" ) ) ]
205+ fn build_tcp_keepalive ( config : & KeepaliveConfig ) -> socket2:: TcpKeepalive {
206+ let mut ka = socket2:: TcpKeepalive :: new ( ) ;
207+
208+ if let Some ( idle) = config. idle {
209+ ka = ka. with_time ( idle) ;
210+ }
211+
212+ // socket2's `with_interval` is unavailable on these platforms.
213+ #[ cfg( not( any(
214+ target_os = "haiku" ,
215+ target_os = "openbsd" ,
216+ target_os = "redox" ,
217+ target_os = "solaris" ,
218+ ) ) ) ]
219+ if let Some ( interval) = config. interval {
220+ ka = ka. with_interval ( interval) ;
221+ }
222+
223+ // socket2's `with_retries` is unavailable on these platforms.
224+ #[ cfg( not( any(
225+ target_os = "haiku" ,
226+ target_os = "openbsd" ,
227+ target_os = "redox" ,
228+ target_os = "solaris" ,
229+ target_os = "windows" ,
230+ ) ) ) ]
231+ if let Some ( retries) = config. retries {
232+ ka = ka. with_retries ( retries) ;
233+ }
234+
235+ ka
236+ }
237+
184238pub async fn connect_tcp < Ws : WithSocket > (
185239 host : & str ,
186240 port : u16 ,
241+ keepalive : Option < & KeepaliveConfig > ,
187242 with_socket : Ws ,
188243) -> crate :: Result < Ws :: Output > {
189244 #[ cfg( feature = "_rt-tokio" ) ]
190245 if crate :: rt:: rt_tokio:: available ( ) {
191- return Ok ( with_socket
192- . with_socket ( tokio:: net:: TcpStream :: connect ( ( host, port) ) . await ?)
193- . await ) ;
246+ let stream = tokio:: net:: TcpStream :: connect ( ( host, port) ) . await ?;
247+
248+ if let Some ( ka) = keepalive {
249+ let sock = socket2:: SockRef :: from ( & stream) ;
250+ sock. set_tcp_keepalive ( & build_tcp_keepalive ( ka) ) ?;
251+ }
252+
253+ return Ok ( with_socket. with_socket ( stream) . await ) ;
194254 }
195255
196256 cfg_if ! {
197257 if #[ cfg( feature = "_rt-async-io" ) ] {
198- Ok ( with_socket. with_socket( connect_tcp_async_io( host, port) . await ?) . await )
258+ Ok ( with_socket. with_socket( connect_tcp_async_io( host, port, keepalive ) . await ?) . await )
199259 } else {
200- crate :: rt:: missing_rt( ( host, port, with_socket) )
260+ crate :: rt:: missing_rt( ( host, port, keepalive , with_socket) )
201261 }
202262 }
203263}
@@ -208,15 +268,26 @@ pub async fn connect_tcp<Ws: WithSocket>(
208268///
209269/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
210270#[ cfg( feature = "_rt-async-io" ) ]
211- async fn connect_tcp_async_io ( host : & str , port : u16 ) -> crate :: Result < impl Socket > {
271+ async fn connect_tcp_async_io (
272+ host : & str ,
273+ port : u16 ,
274+ keepalive : Option < & KeepaliveConfig > ,
275+ ) -> crate :: Result < impl Socket > {
212276 use async_io:: Async ;
213277 use std:: net:: { IpAddr , TcpStream , ToSocketAddrs } ;
214278
215279 // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
216280 let host = host. trim_matches ( & [ '[' , ']' ] [ ..] ) ;
217281
218282 if let Ok ( addr) = host. parse :: < IpAddr > ( ) {
219- return Ok ( Async :: < TcpStream > :: connect ( ( addr, port) ) . await ?) ;
283+ let stream = Async :: < TcpStream > :: connect ( ( addr, port) ) . await ?;
284+
285+ if let Some ( ka) = keepalive {
286+ let sock = socket2:: SockRef :: from ( stream. get_ref ( ) ) ;
287+ sock. set_tcp_keepalive ( & build_tcp_keepalive ( ka) ) ?;
288+ }
289+
290+ return Ok ( stream) ;
220291 }
221292
222293 let host = host. to_string ( ) ;
@@ -232,7 +303,14 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socke
232303 // Loop through all the Socket Addresses that the hostname resolves to
233304 for socket_addr in addresses {
234305 match Async :: < TcpStream > :: connect ( socket_addr) . await {
235- Ok ( stream) => return Ok ( stream) ,
306+ Ok ( stream) => {
307+ if let Some ( ka) = keepalive {
308+ let sock = socket2:: SockRef :: from ( stream. get_ref ( ) ) ;
309+ sock. set_tcp_keepalive ( & build_tcp_keepalive ( ka) ) ?;
310+ }
311+
312+ return Ok ( stream) ;
313+ }
236314 Err ( e) => last_err = Some ( e) ,
237315 }
238316 }
0 commit comments