1
- //! http-client implementation for async-h1.
1
+ //! http-client implementation for async-h1, with connecton pooling ("Keep-Alive") .
2
2
3
- use super :: { async_trait, Error , HttpClient , Request , Response } ;
3
+ use std:: collections:: HashMap ;
4
+ use std:: net:: SocketAddr ;
5
+ use std:: { fmt:: Debug , sync:: Arc } ;
4
6
5
7
use async_h1:: client;
8
+ use async_std:: net:: TcpStream ;
9
+ use async_std:: sync:: Mutex ;
10
+ use deadpool:: managed:: Pool ;
6
11
use http_types:: StatusCode ;
7
12
8
- /// Async-h1 based HTTP Client.
9
- #[ derive( Debug ) ]
13
+ #[ cfg( not( feature = "h1_client_rustls" ) ) ]
14
+ use async_native_tls:: TlsStream ;
15
+ #[ cfg( feature = "h1_client_rustls" ) ]
16
+ use async_tls:: client:: TlsStream ;
17
+
18
+ use super :: { async_trait, Error , HttpClient , Request , Response } ;
19
+
20
+ mod tcp;
21
+ mod tls;
22
+
23
+ use tcp:: { TcpConnWrapper , TcpConnection } ;
24
+ use tls:: { TlsConnWrapper , TlsConnection } ;
25
+
26
+ // TODO: Move this to a parameter. This current number is based on a few
27
+ // random benchmarks and see whatever gave decent perf vs resource use.
28
+ static MAX_CONCURRENT_CONNECTIONS : usize = 50 ;
29
+
30
+ type HttpPool = HashMap < SocketAddr , Pool < TcpStream , std:: io:: Error > > ;
31
+ type HttpsPool = HashMap < SocketAddr , Pool < TlsStream < TcpStream > , Error > > ;
32
+
33
+ /// Async-h1 based HTTP Client, with connecton pooling ("Keep-Alive").
10
34
pub struct H1Client {
11
- _priv : ( ) ,
35
+ http_pool : Arc < Mutex < HttpPool > > ,
36
+ https_pool : Arc < Mutex < HttpsPool > > ,
37
+ }
38
+
39
+ impl Debug for H1Client {
40
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
41
+ f. write_str ( "H1Client" )
42
+ }
12
43
}
13
44
14
45
impl Default for H1Client {
@@ -20,13 +51,20 @@ impl Default for H1Client {
20
51
impl H1Client {
21
52
/// Create a new instance.
22
53
pub fn new ( ) -> Self {
23
- Self { _priv : ( ) }
54
+ Self {
55
+ http_pool : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
56
+ https_pool : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
57
+ }
24
58
}
25
59
}
26
60
27
61
#[ async_trait]
28
62
impl HttpClient for H1Client {
29
63
async fn send ( & self , mut req : Request ) -> Result < Response , Error > {
64
+ let http_pool = self . http_pool . clone ( ) ;
65
+ let https_pool = self . https_pool . clone ( ) ;
66
+ req. insert_header ( "Connection" , "keep-alive" ) ;
67
+
30
68
// Insert host
31
69
let host = req
32
70
. url ( )
@@ -57,40 +95,62 @@ impl HttpClient for H1Client {
57
95
58
96
match scheme {
59
97
"http" => {
60
- let stream = async_std:: net:: TcpStream :: connect ( addr) . await ?;
98
+ let mut hash = http_pool. lock ( ) . await ;
99
+ let pool = if let Some ( pool) = hash. get ( & addr) {
100
+ pool
101
+ } else {
102
+ let manager = TcpConnection :: new ( addr) ;
103
+ let pool =
104
+ Pool :: < TcpStream , std:: io:: Error > :: new ( manager, MAX_CONCURRENT_CONNECTIONS ) ;
105
+ hash. insert ( addr, pool) ;
106
+ hash. get ( & addr) . unwrap ( )
107
+ } ;
108
+ let pool = pool. clone ( ) ;
109
+ std:: mem:: drop ( hash) ;
110
+ let stream = pool. get ( ) . await ?;
61
111
req. set_peer_addr ( stream. peer_addr ( ) . ok ( ) ) ;
62
112
req. set_local_addr ( stream. local_addr ( ) . ok ( ) ) ;
63
- client:: connect ( stream, req) . await
113
+ client:: connect ( TcpConnWrapper :: new ( stream) , req) . await
114
+
115
+ // let stream = async_std::net::TcpStream::connect(addr).await?;
116
+ // req.set_peer_addr(stream.peer_addr().ok());
117
+ // req.set_local_addr(stream.local_addr().ok());
118
+ // client::connect(stream, req).await
64
119
}
65
120
"https" => {
66
- let raw_stream = async_std:: net:: TcpStream :: connect ( addr) . await ?;
67
- req. set_peer_addr ( raw_stream. peer_addr ( ) . ok ( ) ) ;
68
- req. set_local_addr ( raw_stream. local_addr ( ) . ok ( ) ) ;
69
- let tls_stream = add_tls ( host, raw_stream) . await ?;
70
- client:: connect ( tls_stream, req) . await
121
+ let mut hash = https_pool. lock ( ) . await ;
122
+ let pool = if let Some ( pool) = hash. get ( & addr) {
123
+ pool
124
+ } else {
125
+ let manager = TlsConnection :: new ( host. clone ( ) , addr) ;
126
+ let pool = Pool :: < TlsStream < TcpStream > , Error > :: new (
127
+ manager,
128
+ MAX_CONCURRENT_CONNECTIONS ,
129
+ ) ;
130
+ hash. insert ( addr, pool) ;
131
+ hash. get ( & addr) . unwrap ( )
132
+ } ;
133
+ let pool = pool. clone ( ) ;
134
+ std:: mem:: drop ( hash) ;
135
+ let stream = pool. get ( ) . await . unwrap ( ) ; // TODO: remove unwrap
136
+ req. set_peer_addr ( stream. get_ref ( ) . peer_addr ( ) . ok ( ) ) ;
137
+ req. set_local_addr ( stream. get_ref ( ) . local_addr ( ) . ok ( ) ) ;
138
+
139
+ client:: connect ( TlsConnWrapper :: new ( stream) , req) . await
140
+
141
+ // let raw_stream = async_std::net::TcpStream::connect(addr).await?;
142
+ // req.set_peer_addr(raw_stream.peer_addr().ok());
143
+ // req.set_local_addr(raw_stream.local_addr().ok());
144
+
145
+ // let stream = async_native_tls::connect(host, raw_stream).await?;
146
+
147
+ // client::connect(stream, req).await
71
148
}
72
149
_ => unreachable ! ( ) ,
73
150
}
74
151
}
75
152
}
76
153
77
- #[ cfg( not( feature = "h1_client_rustls" ) ) ]
78
- async fn add_tls (
79
- host : String ,
80
- stream : async_std:: net:: TcpStream ,
81
- ) -> Result < async_native_tls:: TlsStream < async_std:: net:: TcpStream > , async_native_tls:: Error > {
82
- async_native_tls:: connect ( host, stream) . await
83
- }
84
-
85
- #[ cfg( feature = "h1_client_rustls" ) ]
86
- async fn add_tls (
87
- host : String ,
88
- stream : async_std:: net:: TcpStream ,
89
- ) -> std:: io:: Result < async_tls:: client:: TlsStream < async_std:: net:: TcpStream > > {
90
- let connector = async_tls:: TlsConnector :: default ( ) ;
91
- connector. connect ( host, stream) . await
92
- }
93
-
94
154
#[ cfg( test) ]
95
155
mod tests {
96
156
use super :: * ;
0 commit comments