// connection_pool.rs — forked from kvc0/protosocket (198 lines, 7.46 KB)
use std::{
cell::RefCell,
future::Future,
pin::{pin, Pin},
sync::Mutex,
task::{Context, Poll},
};
use futures::FutureExt;
use rand::{Rng, SeedableRng};
use crate::{client::RpcClient, Message};
/// A connection strategy for protosocket rpc clients.
///
/// This is called asynchronously by the connection pool to create new connections.
pub trait ClientConnector: Clone {
    /// The request message type sent over connections produced by this connector.
    type Request: Message;
    /// The response message type received over connections produced by this connector.
    type Response: Message;
    /// Connect to the server and return a new RpcClient. See [`crate::client::connect`] for
    /// the typical way to connect. This is called by the ConnectionPool when it needs a new
    /// connection.
    ///
    /// Your returned future needs to be `'static`, and your connector needs to be cheap to
    /// clone. One easy way to do that is to just impl ClientConnector on `Arc<YourConnectorType>`
    /// instead of directly on `YourConnectorType`.
    ///
    /// If you have rolling credentials, initialization messages, changing endpoints, or other
    /// adaptive connection logic, this is the place to do it or consult those sources of truth.
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}
/// A connection pool for protosocket rpc clients.
///
/// Protosocket-rpc connections are shared and multiplexed, so this vends cloned handles.
/// You can hold onto a handle from the pool for as long as you want. There is a small
/// synchronization cost to getting a handle from a pool, so caching is a good idea - but
/// if you want to load balance a lot, you can just make a pool with as many "slots" as you
/// want to dilute any contention on connection state locks. The locks are typically held for
/// the time it takes to clone an `Arc`, so it's usually nanosecond-scale synchronization,
/// per connection. So if you have several connections, you'll rarely contend.
#[derive(Debug)]
pub struct ConnectionPool<Connector: ClientConnector> {
    // Factory used to (re)establish a connection when a slot is empty or dead.
    connector: Connector,
    // One independently-locked slot per desired connection; each slot tracks its
    // own lifecycle so callers on different slots never contend on the same lock.
    connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
}
impl<Connector: ClientConnector> ConnectionPool<Connector> {
    /// Create a new connection pool.
    ///
    /// It will try to maintain `connection_count` healthy connections.
    ///
    /// # Panics
    ///
    /// Panics if `connection_count` is 0. An empty pool can never vend a
    /// connection; without this check the failure would only surface later as
    /// an opaque `key % 0` panic in [`Self::get_connection_for_key`] or an
    /// empty-range panic in [`Self::get_connection`].
    pub fn new(connector: Connector, connection_count: usize) -> Self {
        assert!(
            connection_count != 0,
            "ConnectionPool requires at least one connection slot"
        );
        Self {
            connector,
            connections: (0..connection_count)
                .map(|_| Mutex::new(ConnectionState::Disconnected))
                .collect(),
        }
    }

    /// Get a consistent connection from the pool for a given key.
    ///
    /// The same key always maps to the same slot (`key % slot_count`), so
    /// callers that shard work by key get a stable connection assignment.
    pub async fn get_connection_for_key(
        &self,
        key: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let slot = key % self.connections.len();
        self.get_connection_by_slot(slot).await
    }

    /// Get a connection from the pool.
    ///
    /// Picks a slot uniformly at random using a cheap thread-local RNG.
    pub async fn get_connection(
        &self,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        thread_local! {
            static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
        }
        // The RefCell is thread-local and the borrow is confined to this closure,
        // so it cannot be borrowed anywhere else concurrently.
        let slot = THREAD_LOCAL_SMALL_RANDOM
            .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
        self.get_connection_by_slot(slot).await
    }

    /// Return a live client for `slot`, (re)connecting if necessary.
    ///
    /// At most one connect task is spawned per slot at a time: concurrent
    /// callers that find the slot `Connecting` await the same `Shared` future.
    async fn get_connection_by_slot(
        &self,
        slot: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let connection_state = &self.connections[slot];
        // The connection state requires a mutex, so I need to keep await out of the scope to satisfy clippy (and for paranoia).
        let connecting_handle = loop {
            let mut state = connection_state.lock().expect("internal mutex must work");
            break match &mut *state {
                ConnectionState::Connected(shared_connection) => {
                    if shared_connection.is_alive() {
                        return Ok(shared_connection.clone());
                    } else {
                        // The cached connection died: clear the slot and loop
                        // around to start a fresh connect.
                        *state = ConnectionState::Disconnected;
                        continue;
                    }
                }
                ConnectionState::Connecting(join_handle) => join_handle.clone(),
                ConnectionState::Disconnected => {
                    let connector = self.connector.clone();
                    // Spawn the connect so it makes progress even if this
                    // particular caller's future is dropped mid-await.
                    let load = SpawnedConnect {
                        inner: tokio::task::spawn(connector.connect()),
                    }
                    .shared();
                    *state = ConnectionState::Connecting(load.clone());
                    continue;
                }
            };
        };
        // The slot lock is released here; it is now safe to await the shared connect.
        match connecting_handle.await {
            Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
            Err(connect_error) => {
                // Connect failed: reset the slot so the next caller retries
                // instead of awaiting a future that already failed.
                let mut state = connection_state.lock().expect("internal mutex must work");
                *state = ConnectionState::Disconnected;
                Err(connect_error)
            }
        }
    }
}
/// Reconcile a freshly-connected client with its pool slot.
///
/// Exactly one of two things happens under the slot lock:
/// * Another task already installed a client that is still alive — we keep
///   that one and drop ours, so all waiters converge on a single connection.
/// * Otherwise (slot is `Connecting`, `Disconnected`, or holds a dead client)
///   — we install our new client. Dropping a `Connecting` shared handle is
///   fine: any task still awaiting it is notified when the spawned connect
///   task completes, and will reconcile against this slot in turn, favoring
///   the client installed here.
fn reconcile_client_slot<Request, Response>(
    connection_state: &Mutex<ConnectionState<Request, Response>>,
    client: RpcClient<Request, Response>,
) -> RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    let mut slot = connection_state.lock().expect("internal mutex must work");
    // Guard clause: a live client already won the race; use it and drop ours.
    if let ConnectionState::Connected(existing) = &*slot {
        if existing.is_alive() {
            return existing.clone();
        }
    }
    // Every other case — mid-connect, disconnected, or a dead cached client —
    // is resolved the same way: install the client we just connected.
    *slot = ConnectionState::Connected(client.clone());
    client
}
// Wrapper around a spawned connect task's JoinHandle. Its `Future` impl (below)
// flattens the join result into a plain `crate::Result`, which makes it usable
// with `futures::FutureExt::shared` in the pool's `Connecting` state.
struct SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    // Handle to the tokio task running `ClientConnector::connect`.
    inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
}
impl<Request, Response> Future for SpawnedConnect<Request, Response>
where
Request: Message,
Response: Message,
{
type Output = crate::Result<RpcClient<Request, Response>>;
fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
match pin!(&mut self.inner).poll(context) {
Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
Poll::Pending => Poll::Pending,
}
}
}
// Lifecycle of a single connection slot in the pool.
#[derive(Debug)]
enum ConnectionState<Request, Response>
where
    Request: Message,
    Response: Message,
{
    // A connect task is in flight; concurrent callers clone and await this
    // shared future rather than spawning their own connect.
    Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
    // A connected client handle (it may have died since — callers check
    // liveness before reusing it).
    Connected(RpcClient<Request, Response>),
    // No connection, and none currently being established.
    Disconnected,
}