Skip to content

Commit 10070e6

Browse files
author
mempirate
authored
Clean up socket + durable IO (#7)
2 parents 3f52adc + e74b11a commit 10070e6

File tree

11 files changed

+496
-373
lines changed

11 files changed

+496
-373
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,20 @@ It was built because we needed a Rust-native messaging library like those above.
2020

2121
- [ ] Multiple socket types
2222
- [x] Request/Reply
23-
- [ ] Channel
2423
- [ ] Publish/Subscribe
24+
- [ ] Channel
2525
- [ ] Push/Pull
2626
- [ ] Survey/Respond
2727
- [ ] Stats (RTT, throughput, packet drops etc.)
28-
- [ ] Durable transports (built-in retries and reconnections)
28+
- [x] Request/Reply basic stats
2929
- [ ] Queuing
30-
- [ ] Pluggable transport layer (TCP, UDP, QUIC etc.)
30+
- [ ] Pluggable transport layer
31+
- [x] TCP
32+
- [ ] TLS
33+
- [ ] IPC
34+
- [ ] UDP
35+
- [ ] Inproc
36+
- [x] Durable IO abstraction (built-in retries and reconnections)
3137
- [ ] Simulation modes with [Turmoil](https://github.com/tokio-rs/turmoil)
3238

3339
## Socket Types
@@ -65,6 +71,8 @@ async fn main() {
6571
println!("Response: {:?}", res);
6672
}
6773
```
74+
## MSRV
75+
The minimum supported Rust version is 1.70.
6876

6977
## Contributions & Bug Reports
7078

msg-socket/src/rep/driver.rs

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
use bytes::Bytes;
2+
use futures::{Future, SinkExt, Stream, StreamExt};
3+
use std::{
4+
collections::VecDeque,
5+
net::SocketAddr,
6+
pin::Pin,
7+
sync::Arc,
8+
task::{Context, Poll},
9+
};
10+
use tokio::{
11+
io::{AsyncRead, AsyncWrite},
12+
sync::{mpsc, oneshot},
13+
task::JoinSet,
14+
};
15+
use tokio_stream::{StreamMap, StreamNotifyClose};
16+
use tokio_util::codec::Framed;
17+
18+
use crate::{rep::SocketState, Authenticator, RepError, Request};
19+
use msg_transport::ServerTransport;
20+
use msg_wire::{auth, reqrep};
21+
22+
pub(crate) struct PeerState<T: AsyncRead + AsyncWrite> {
23+
pending_requests: JoinSet<Option<(u32, Bytes)>>,
24+
conn: Framed<T, reqrep::Codec>,
25+
addr: SocketAddr,
26+
egress_queue: VecDeque<reqrep::Message>,
27+
state: Arc<SocketState>,
28+
}
29+
30+
pub(crate) struct RepDriver<T: ServerTransport> {
31+
/// The server transport used to accept incoming connections.
32+
pub(crate) transport: T,
33+
/// The reply socket state, shared with the socket front-end.
34+
pub(crate) state: Arc<SocketState>,
35+
/// [`StreamMap`] of connected peers. The key is the peer's address.
36+
/// Note that when the [`PeerState`] stream ends, it will be silently removed
37+
/// from this map.
38+
pub(crate) peer_states: StreamMap<SocketAddr, StreamNotifyClose<PeerState<T::Io>>>,
39+
/// Sender to the socket front-end. Used to notify the socket of incoming requests.
40+
pub(crate) to_socket: mpsc::Sender<Request>,
41+
/// Optional connection authenticator.
42+
pub(crate) auth: Option<Arc<dyn Authenticator>>,
43+
/// A joinset of authentication tasks.
44+
pub(crate) auth_tasks: JoinSet<Result<AuthResult<T::Io>, RepError>>,
45+
}
46+
47+
pub(crate) struct AuthResult<S: AsyncRead + AsyncWrite> {
48+
id: Bytes,
49+
addr: SocketAddr,
50+
stream: S,
51+
}
52+
53+
impl<T: ServerTransport> Future for RepDriver<T> {
54+
type Output = Result<(), RepError>;
55+
56+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
57+
let this = self.get_mut();
58+
59+
loop {
60+
if let Poll::Ready(Some((peer, msg))) = this.peer_states.poll_next_unpin(cx) {
61+
match msg {
62+
Some(Ok(request)) => {
63+
tracing::debug!("Received request from peer {}", peer);
64+
this.state.stats.increment_rx(request.msg().len());
65+
let _ = this.to_socket.try_send(request);
66+
}
67+
Some(Err(e)) => {
68+
tracing::error!("Error receiving message from peer {}: {:?}", peer, e);
69+
}
70+
None => {
71+
tracing::debug!("Peer {} disconnected", peer);
72+
this.state.stats.decrement_active_clients();
73+
}
74+
}
75+
76+
continue;
77+
}
78+
79+
if let Poll::Ready(Some(Ok(auth))) = this.auth_tasks.poll_join_next(cx) {
80+
match auth {
81+
Ok(auth) => {
82+
// Run custom authenticator
83+
tracing::debug!("Authentication passed for {:?} ({})", auth.id, auth.addr);
84+
this.state.stats.increment_active_clients();
85+
86+
this.peer_states.insert(
87+
auth.addr,
88+
StreamNotifyClose::new(PeerState {
89+
pending_requests: JoinSet::new(),
90+
conn: Framed::new(auth.stream, reqrep::Codec::new()),
91+
addr: auth.addr,
92+
// TODO: pre-allocate according to some options
93+
egress_queue: VecDeque::with_capacity(64),
94+
state: Arc::clone(&this.state),
95+
}),
96+
);
97+
}
98+
Err(e) => {
99+
tracing::error!("Error authenticating client: {:?}", e);
100+
}
101+
}
102+
103+
continue;
104+
}
105+
106+
// Poll the transport for new incoming connections
107+
match this.transport.poll_accept(cx) {
108+
Poll::Ready(Ok((stream, addr))) => {
109+
// If authentication is enabled, start the authentication process
110+
if let Some(ref auth) = this.auth {
111+
let authenticator = Arc::clone(auth);
112+
tracing::debug!("New connection from {}, authenticating", addr);
113+
this.auth_tasks.spawn(async move {
114+
let mut conn = Framed::new(stream, auth::Codec::new_server());
115+
116+
tracing::debug!("Waiting for auth");
117+
// Wait for the response
118+
let auth = conn
119+
.next()
120+
.await
121+
.ok_or(RepError::SocketClosed)?
122+
.map_err(|e| RepError::Auth(e.to_string()))?;
123+
124+
tracing::debug!("Auth received: {:?}", auth);
125+
126+
let auth::Message::Auth(id) = auth else {
127+
conn.send(auth::Message::Reject).await?;
128+
conn.flush().await?;
129+
conn.close().await?;
130+
return Err(RepError::Auth("Invalid auth message".to_string()));
131+
};
132+
133+
// If authentication fails, send a reject message and close the connection
134+
if !authenticator.authenticate(&id) {
135+
conn.send(auth::Message::Reject).await?;
136+
conn.flush().await?;
137+
conn.close().await?;
138+
return Err(RepError::Auth("Authentication failed".to_string()));
139+
}
140+
141+
// Send ack
142+
conn.send(auth::Message::Ack).await?;
143+
conn.flush().await?;
144+
145+
Ok(AuthResult {
146+
id,
147+
addr,
148+
stream: conn.into_inner(),
149+
})
150+
});
151+
} else {
152+
this.state.stats.increment_active_clients();
153+
this.peer_states.insert(
154+
addr,
155+
StreamNotifyClose::new(PeerState {
156+
pending_requests: JoinSet::new(),
157+
conn: Framed::new(stream, reqrep::Codec::new()),
158+
addr,
159+
// TODO: pre-allocate according to some options
160+
egress_queue: VecDeque::with_capacity(64),
161+
state: Arc::clone(&this.state),
162+
}),
163+
);
164+
165+
tracing::debug!("New connection from {}", addr);
166+
}
167+
168+
continue;
169+
}
170+
Poll::Ready(Err(e)) => {
171+
// Errors here are usually about `WouldBlock`
172+
tracing::error!("Error accepting connection: {:?}", e);
173+
174+
continue;
175+
}
176+
Poll::Pending => {}
177+
}
178+
179+
return Poll::Pending;
180+
}
181+
}
182+
}
183+
184+
impl<T: AsyncRead + AsyncWrite + Unpin> Stream for PeerState<T> {
185+
type Item = Result<Request, RepError>;
186+
187+
/// Advances the state of the peer.
188+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189+
let this = self.get_mut();
190+
191+
loop {
192+
// Flush any messages on the outgoing buffer
193+
let _ = this.conn.poll_flush_unpin(cx);
194+
195+
// Then, try to drain the egress queue.
196+
if this.conn.poll_ready_unpin(cx).is_ready() {
197+
if let Some(msg) = this.egress_queue.pop_front() {
198+
let msg_len = msg.size();
199+
match this.conn.start_send_unpin(msg) {
200+
Ok(_) => {
201+
this.state.stats.increment_tx(msg_len);
202+
// We might be able to send more queued messages
203+
continue;
204+
}
205+
Err(e) => {
206+
tracing::error!("Failed to send message to socket: {:?}", e);
207+
// End this stream as we can't send any more messages
208+
return Poll::Ready(None);
209+
}
210+
}
211+
}
212+
}
213+
214+
// Then we check for completed requests, and push them onto the egress queue.
215+
match this.pending_requests.poll_join_next(cx) {
216+
Poll::Ready(Some(Ok(Some((id, payload))))) => {
217+
let msg = reqrep::Message::new(id, payload);
218+
this.egress_queue.push_back(msg);
219+
220+
continue;
221+
}
222+
Poll::Ready(Some(Ok(None))) => {
223+
tracing::error!("Failed to respond to request");
224+
this.state.stats.increment_failed_requests();
225+
226+
continue;
227+
}
228+
Poll::Ready(Some(Err(e))) => {
229+
tracing::error!("Error receiving response: {:?}", e);
230+
this.state.stats.increment_failed_requests();
231+
232+
continue;
233+
}
234+
_ => {}
235+
}
236+
237+
// Finally we accept incoming requests from the peer.
238+
match this.conn.poll_next_unpin(cx) {
239+
Poll::Ready(Some(result)) => {
240+
tracing::trace!("Received message from peer {}: {:?}", this.addr, result);
241+
let msg = result?;
242+
let msg_id = msg.id();
243+
244+
let (tx, rx) = oneshot::channel();
245+
246+
// Spawn a task to listen for the response. On success, return message ID and response.
247+
this.pending_requests
248+
.spawn(async move { rx.await.ok().map(|res| (msg_id, res)) });
249+
250+
let request = Request {
251+
source: this.addr,
252+
response: tx,
253+
msg: msg.into_payload(),
254+
};
255+
256+
return Poll::Ready(Some(Ok(request)));
257+
}
258+
Poll::Ready(None) => {
259+
tracing::debug!("Connection closed");
260+
return Poll::Ready(None);
261+
}
262+
Poll::Pending => {}
263+
}
264+
265+
return Poll::Pending;
266+
}
267+
}
268+
}

0 commit comments

Comments
 (0)