11#![ cfg( any( feature = "native-tls-client" , feature = "rustls-client" ) ) ]
22
3- use bytes:: Bytes ;
4- use http_body_util:: Empty ;
3+ use bytes:: { Buf , Bytes } ;
4+ use http_body_util:: { BodyExt , Empty , combinators :: BoxBody } ;
55use hyper:: {
66 Request , Response , StatusCode , Uri , Version ,
77 body:: { Body , Incoming } ,
88 client, header,
99} ;
1010use hyper_util:: rt:: { TokioExecutor , TokioIo } ;
11- use std:: task:: { Context , Poll } ;
11+ use std:: {
12+ collections:: HashMap ,
13+ future:: poll_fn,
14+ sync:: Arc ,
15+ task:: { Context , Poll } ,
16+ } ;
17+ use tokio:: sync:: Mutex ;
1218use tokio:: { net:: TcpStream , task:: JoinHandle } ;
1319
1420#[ cfg( all( feature = "native-tls-client" , feature = "rustls-client" ) ) ]
@@ -55,6 +61,96 @@ pub struct Upgraded {
5561 /// A socket to Server
5662 pub server : TokioIo < hyper:: upgrade:: Upgraded > ,
5763}
64+
65+ type DynError = Box < dyn std:: error:: Error + Send + Sync > ;
66+ type PooledBody = BoxBody < Bytes , DynError > ;
67+ type Http1Sender = hyper:: client:: conn:: http1:: SendRequest < PooledBody > ;
68+ type Http2Sender = hyper:: client:: conn:: http2:: SendRequest < PooledBody > ;
69+
70+ #[ derive( Clone , Copy , Debug , Eq , PartialEq , Hash ) ]
71+ enum ConnectionProtocol {
72+ Http1 ,
73+ Http2 ,
74+ }
75+
76+ #[ derive( Clone , Debug , Eq , PartialEq , Hash ) ]
77+ struct ConnectionKey {
78+ host : String ,
79+ port : u16 ,
80+ is_tls : bool ,
81+ protocol : ConnectionProtocol ,
82+ }
83+
84+ impl ConnectionKey {
85+ fn new ( host : String , port : u16 , is_tls : bool , protocol : ConnectionProtocol ) -> Self {
86+ Self {
87+ host,
88+ port,
89+ is_tls,
90+ protocol,
91+ }
92+ }
93+
94+ fn from_uri ( uri : & Uri , protocol : ConnectionProtocol ) -> Result < Self , Error > {
95+ let ( host, port, is_tls) = host_port ( uri) ?;
96+ Ok ( ConnectionKey :: new ( host, port, is_tls, protocol) )
97+ }
98+ }
99+
100+ #[ derive( Clone , Default ) ]
101+ struct ConnectionPool {
102+ http1 : Arc < Mutex < HashMap < ConnectionKey , Vec < Http1Sender > > > > ,
103+ http2 : Arc < Mutex < HashMap < ConnectionKey , Http2Sender > > > ,
104+ }
105+
106+ impl ConnectionPool {
107+ async fn take_http1 ( & self , key : & ConnectionKey ) -> Option < Http1Sender > {
108+ let mut guard = self . http1 . lock ( ) . await ;
109+ let entry = guard. get_mut ( key) ?;
110+ while let Some ( mut conn) = entry. pop ( ) {
111+ if sender_alive_http1 ( & mut conn) . await {
112+ return Some ( conn) ;
113+ }
114+ }
115+ if entry. is_empty ( ) {
116+ guard. remove ( key) ;
117+ }
118+ None
119+ }
120+
121+ async fn put_http1 ( & self , key : ConnectionKey , sender : Http1Sender ) {
122+ let mut guard = self . http1 . lock ( ) . await ;
123+ guard. entry ( key) . or_default ( ) . push ( sender) ;
124+ }
125+
126+ async fn get_http2 ( & self , key : & ConnectionKey ) -> Option < Http2Sender > {
127+ let mut guard = self . http2 . lock ( ) . await ;
128+ let mut sender = guard. get ( key) . cloned ( ) ?;
129+
130+ let alive = sender_alive_http2 ( & mut sender) . await ;
131+
132+ if alive {
133+ Some ( sender)
134+ } else {
135+ guard. remove ( key) ;
136+ None
137+ }
138+ }
139+
140+ async fn insert_http2_if_absent ( & self , key : ConnectionKey , sender : Http2Sender ) {
141+ let mut guard = self . http2 . lock ( ) . await ;
142+ guard. entry ( key) . or_insert ( sender) ;
143+ }
144+ }
145+
146+ async fn sender_alive_http1 ( sender : & mut Http1Sender ) -> bool {
147+ poll_fn ( |cx| sender. poll_ready ( cx) ) . await . is_ok ( )
148+ }
149+
150+ async fn sender_alive_http2 ( sender : & mut Http2Sender ) -> bool {
151+ poll_fn ( |cx| sender. poll_ready ( cx) ) . await . is_ok ( )
152+ }
153+
58154#[ derive( Clone ) ]
59155/// Default HTTP client for this crate
60156pub struct DefaultClient {
@@ -71,6 +167,8 @@ pub struct DefaultClient {
71167 /// If true, send_request will returns an Upgraded struct when the response is an upgrade
72168 /// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
73169 pub with_upgrades : bool ,
170+
171+ pool : ConnectionPool ,
74172}
75173impl Default for DefaultClient {
76174 fn default ( ) -> Self {
@@ -102,6 +200,7 @@ impl DefaultClient {
102200 tls_connector_no_alpn : tokio_native_tls:: TlsConnector :: from ( tls_connector_no_alpn) ,
103201 tls_connector_alpn_h2 : tokio_native_tls:: TlsConnector :: from ( tls_connector_alpn_h2) ,
104202 with_upgrades : false ,
203+ pool : ConnectionPool :: default ( ) ,
105204 } )
106205 }
107206
@@ -135,6 +234,7 @@ impl DefaultClient {
135234 tls_connector_alpn_h2,
136235 ) ) ,
137236 with_upgrades : false ,
237+ pool : ConnectionPool :: default ( ) ,
138238 } )
139239 }
140240
@@ -175,17 +275,54 @@ impl DefaultClient {
175275 Error ,
176276 >
177277 where
178- B : Body + Unpin + Send + ' static ,
179- B :: Data : Send ,
180- B :: Error : Into < Box < dyn std :: error :: Error + Send + Sync > > ,
278+ B : Body < Data = Bytes > + Send + Sync + ' static ,
279+ B :: Data : Send + Buf ,
280+ B :: Error : Into < DynError > ,
181281 {
182- let mut send_request = self . connect ( req. uri ( ) , req. version ( ) ) . await ?;
282+ let target_uri = req. uri ( ) . clone ( ) ;
283+ let mut send_request = if req. version ( ) == Version :: HTTP_2 {
284+ match ConnectionKey :: from_uri ( & target_uri, ConnectionProtocol :: Http2 ) {
285+ Ok ( pool_key) => {
286+ if let Some ( conn) = self . pool . get_http2 ( & pool_key) . await {
287+ SendRequest :: Http2 ( conn)
288+ } else {
289+ self . connect ( req. uri ( ) , req. version ( ) , Some ( pool_key) )
290+ . await ?
291+ }
292+ }
293+ Err ( err) => {
294+ tracing:: warn!(
295+ "ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool" ,
296+ err
297+ ) ;
298+ self . connect ( req. uri ( ) , req. version ( ) , None ) . await ?
299+ }
300+ }
301+ } else {
302+ match ConnectionKey :: from_uri ( & target_uri, ConnectionProtocol :: Http1 ) {
303+ Ok ( pool_key) => {
304+ if let Some ( conn) = self . pool . take_http1 ( & pool_key) . await {
305+ SendRequest :: Http1 ( conn)
306+ } else {
307+ self . connect ( req. uri ( ) , req. version ( ) , Some ( pool_key) )
308+ . await ?
309+ }
310+ }
311+ Err ( err) => {
312+ tracing:: warn!(
313+ "ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool" ,
314+ err
315+ ) ;
316+ self . connect ( req. uri ( ) , req. version ( ) , None ) . await ?
317+ }
318+ }
319+ } ;
183320
184321 let ( req_parts, req_body) = req. into_parts ( ) ;
185322
186- let res = send_request
187- . send_request ( Request :: from_parts ( req_parts . clone ( ) , req_body ) )
188- . await ?;
323+ let boxed_req = Request :: from_parts ( req_parts . clone ( ) , to_boxed_body ( req_body ) ) ;
324+
325+ let res = send_request . send_request ( boxed_req ) . await ?;
189326
190327 if res. status ( ) == StatusCode :: SWITCHING_PROTOCOLS {
191328 let ( res_parts, res_body) = res. into_parts ( ) ;
@@ -221,36 +358,41 @@ impl DefaultClient {
221358
222359 Ok ( ( Response :: from_parts ( res_parts, res_body) , upgrade) )
223360 } else {
361+ match send_request {
362+ SendRequest :: Http1 ( sender) => {
363+ if let Ok ( pool_key) =
364+ ConnectionKey :: from_uri ( & target_uri, ConnectionProtocol :: Http1 )
365+ {
366+ self . pool . put_http1 ( pool_key, sender) . await ;
367+ } else {
368+ // If we couldn't build a pool key, skip pooling.
369+ }
370+ }
371+ SendRequest :: Http2 ( _) => {
372+ // For HTTP/2 the pool retains a shared sender; no action needed.
373+ }
374+ }
224375 Ok ( ( res, None ) )
225376 }
226377 }
227378
228- async fn connect < B > ( & self , uri : & Uri , http_version : Version ) -> Result < SendRequest < B > , Error >
229- where
230- B : Body + Unpin + Send + ' static ,
231- B :: Data : Send ,
232- B :: Error : Into < Box < dyn std:: error:: Error + Send + Sync > > ,
233- {
234- let host = uri
235- . host ( )
236- . ok_or_else ( || Error :: InvalidHost ( Box :: new ( uri. clone ( ) ) ) ) ?;
237- let port =
238- uri. port_u16 ( )
239- . unwrap_or ( if uri. scheme ( ) == Some ( & hyper:: http:: uri:: Scheme :: HTTPS ) {
240- 443
241- } else {
242- 80
243- } ) ;
379+ async fn connect (
380+ & self ,
381+ uri : & Uri ,
382+ http_version : Version ,
383+ key : Option < ConnectionKey > ,
384+ ) -> Result < SendRequest , Error > {
385+ let ( host, port, is_tls) = host_port ( uri) ?;
244386
245- let tcp = TcpStream :: connect ( ( host, port) ) . await ?;
387+ let tcp = TcpStream :: connect ( ( host. as_str ( ) , port) ) . await ?;
246388 // This is actually needed to some servers
247389 let _ = tcp. set_nodelay ( true ) ;
248390
249- if uri . scheme ( ) == Some ( & hyper :: http :: uri :: Scheme :: HTTPS ) {
391+ if is_tls {
250392 #[ cfg( feature = "native-tls-client" ) ]
251393 let tls = self
252394 . tls_connector ( http_version)
253- . connect ( host, tcp)
395+ . connect ( & host, tcp)
254396 . await
255397 . map_err ( |err| Error :: TlsConnectError ( Box :: new ( uri. clone ( ) ) , err) ) ?;
256398 #[ cfg( feature = "rustls-client" ) ]
@@ -284,6 +426,14 @@ impl DefaultClient {
284426
285427 tokio:: spawn ( conn) ;
286428
429+ if let Some ( ref k) = key
430+ && matches ! ( k. protocol, ConnectionProtocol :: Http2 )
431+ {
432+ self . pool
433+ . insert_http2_if_absent ( k. clone ( ) , sender. clone ( ) )
434+ . await ;
435+ }
436+
287437 Ok ( SendRequest :: Http2 ( sender) )
288438 } else {
289439 let ( sender, conn) = client:: conn:: http1:: Builder :: new ( )
@@ -310,18 +460,15 @@ impl DefaultClient {
310460 }
311461}
312462
313- enum SendRequest < B > {
314- Http1 ( hyper :: client :: conn :: http1 :: SendRequest < B > ) ,
315- Http2 ( hyper :: client :: conn :: http2 :: SendRequest < B > ) ,
463+ enum SendRequest {
464+ Http1 ( Http1Sender ) ,
465+ Http2 ( Http2Sender ) ,
316466}
317467
318- impl < B > SendRequest < B >
319- where
320- B : Body + ' static ,
321- {
468+ impl SendRequest {
322469 async fn send_request (
323470 & mut self ,
324- mut req : Request < B > ,
471+ mut req : Request < PooledBody > ,
325472 ) -> Result < Response < Incoming > , hyper:: Error > {
326473 match self {
327474 SendRequest :: Http1 ( sender) => {
@@ -357,13 +504,13 @@ where
357504 }
358505}
359506
360- impl < B > SendRequest < B > {
507+ impl SendRequest {
361508 #[ allow( dead_code) ]
362509 // TODO: connection pooling
363510 fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , hyper:: Error > > {
364511 match self {
365512 SendRequest :: Http1 ( sender) => sender. poll_ready ( cx) ,
366- SendRequest :: Http2 ( sender ) => sender . poll_ready ( cx ) ,
513+ SendRequest :: Http2 ( _sender ) => Poll :: Ready ( Ok ( ( ) ) ) ,
367514 }
368515 }
369516}
@@ -375,3 +522,24 @@ fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::Inv
375522 * req. uri_mut ( ) = Uri :: from_parts ( parts) ?;
376523 Ok ( ( ) )
377524}
525+
526+ fn to_boxed_body < B > ( body : B ) -> PooledBody
527+ where
528+ B : Body < Data = Bytes > + Send + Sync + ' static ,
529+ B :: Data : Send + Buf ,
530+ B :: Error : Into < DynError > ,
531+ {
532+ body. map_err ( |err| err. into ( ) ) . boxed ( )
533+ }
534+
535+ fn host_port ( uri : & Uri ) -> Result < ( String , u16 , bool ) , Error > {
536+ let host = uri
537+ . host ( )
538+ . ok_or_else ( || Error :: InvalidHost ( Box :: new ( uri. clone ( ) ) ) ) ?
539+ . to_string ( ) ;
540+ let is_tls = uri. scheme ( ) == Some ( & hyper:: http:: uri:: Scheme :: HTTPS ) ;
541+ let port = uri. port_u16 ( ) . unwrap_or ( if is_tls { 443 } else { 80 } ) ;
542+ Ok ( ( host, port, is_tls) )
543+ }
544+
545+ impl DefaultClient { }
0 commit comments