Skip to content

Commit bc3b903

Browse files
committed
refactor: use dashmap
dashmap is in effect a better `RwLock<HashMap<K, V>>`. https://crates.io/crates/dashmap
1 parent c2b22bc commit bc3b903

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"
3131

3232
[dependencies]
3333
async-trait = "0.1.37"
34+
dashmap = "3.11.10"
3435
http-types = "2.3.0"
3536
log = "0.4.7"
3637

src/h1/mod.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").
22
3-
use std::collections::HashMap;
43
use std::net::SocketAddr;
54
use std::{fmt::Debug, sync::Arc};
65

76
use async_h1::client;
87
use async_std::net::TcpStream;
9-
use async_std::sync::Mutex;
8+
use dashmap::DashMap;
109
use deadpool::managed::Pool;
1110
use http_types::StatusCode;
1211

@@ -27,13 +26,13 @@ use tls::{TlsConnWrapper, TlsConnection};
2726
// random benchmarks and see whatever gave decent perf vs resource use.
2827
static MAX_CONCURRENT_CONNECTIONS: usize = 50;
2928

30-
type HttpPool = HashMap<SocketAddr, Pool<TcpStream, std::io::Error>>;
31-
type HttpsPool = HashMap<SocketAddr, Pool<TlsStream<TcpStream>, Error>>;
29+
type HttpPool = DashMap<SocketAddr, Pool<TcpStream, std::io::Error>>;
30+
type HttpsPool = DashMap<SocketAddr, Pool<TlsStream<TcpStream>, Error>>;
3231

3332
/// Async-h1 based HTTP Client, with connecton pooling ("Keep-Alive").
3433
pub struct H1Client {
35-
http_pool: Arc<Mutex<HttpPool>>,
36-
https_pool: Arc<Mutex<HttpsPool>>,
34+
http_pools: Arc<HttpPool>,
35+
https_pools: Arc<HttpsPool>,
3736
}
3837

3938
impl Debug for H1Client {
@@ -52,17 +51,17 @@ impl H1Client {
5251
/// Create a new instance.
5352
pub fn new() -> Self {
5453
Self {
55-
http_pool: Arc::new(Mutex::new(HashMap::new())),
56-
https_pool: Arc::new(Mutex::new(HashMap::new())),
54+
http_pools: Arc::new(DashMap::new()),
55+
https_pools: Arc::new(DashMap::new()),
5756
}
5857
}
5958
}
6059

6160
#[async_trait]
6261
impl HttpClient for H1Client {
6362
async fn send(&self, mut req: Request) -> Result<Response, Error> {
64-
let http_pool = self.http_pool.clone();
65-
let https_pool = self.https_pool.clone();
63+
let http_pools = self.http_pools.clone();
64+
let https_pools = self.https_pools.clone();
6665
req.insert_header("Connection", "keep-alive");
6766

6867
// Insert host
@@ -95,18 +94,16 @@ impl HttpClient for H1Client {
9594

9695
match scheme {
9796
"http" => {
98-
let mut hash = http_pool.lock().await;
99-
let pool = if let Some(pool) = hash.get(&addr) {
97+
let pool = if let Some(pool) = http_pools.get(&addr) {
10098
pool
10199
} else {
102100
let manager = TcpConnection::new(addr);
103101
let pool =
104102
Pool::<TcpStream, std::io::Error>::new(manager, MAX_CONCURRENT_CONNECTIONS);
105-
hash.insert(addr, pool);
106-
hash.get(&addr).unwrap()
103+
http_pools.insert(addr, pool);
104+
http_pools.get(&addr).unwrap()
107105
};
108106
let pool = pool.clone();
109-
std::mem::drop(hash);
110107
let stream = pool.get().await?;
111108
req.set_peer_addr(stream.peer_addr().ok());
112109
req.set_local_addr(stream.local_addr().ok());
@@ -118,20 +115,18 @@ impl HttpClient for H1Client {
118115
// client::connect(stream, req).await
119116
}
120117
"https" => {
121-
let mut hash = https_pool.lock().await;
122-
let pool = if let Some(pool) = hash.get(&addr) {
118+
let pool = if let Some(pool) = https_pools.get(&addr) {
123119
pool
124120
} else {
125121
let manager = TlsConnection::new(host.clone(), addr);
126122
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
127123
manager,
128124
MAX_CONCURRENT_CONNECTIONS,
129125
);
130-
hash.insert(addr, pool);
131-
hash.get(&addr).unwrap()
126+
https_pools.insert(addr, pool);
127+
https_pools.get(&addr).unwrap()
132128
};
133129
let pool = pool.clone();
134-
std::mem::drop(hash);
135130
let stream = pool.get().await.unwrap(); // TODO: remove unwrap
136131
req.set_peer_addr(stream.get_ref().peer_addr().ok());
137132
req.set_local_addr(stream.get_ref().local_addr().ok());

0 commit comments

Comments
 (0)