Skip to content

Commit 4cafce1

Browse files
committed
async: Refactor connection handling.
The client and server handle connections almost identically. Both use a sender side and a receiver side to handle the connection. Their respective differences are implemented using the delegate. Signed-off-by: wllenyj <[email protected]>
1 parent 9f97207 commit 4cafce1

File tree

4 files changed

+390
-291
lines changed

4 files changed

+390
-291
lines changed

src/asynchronous/client.rs

Lines changed: 145 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ use std::convert::TryInto;
88
use std::os::unix::io::RawFd;
99
use std::sync::{Arc, Mutex};
1010

11+
use async_trait::async_trait;
1112
use nix::unistd::close;
12-
use tokio::{self, io::split, sync::mpsc, sync::Notify};
13+
use tokio::{self, sync::mpsc, task};
1314

1415
use crate::common::client_connect;
1516
use crate::error::{Error, Result};
1617
use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE};
18+
use crate::r#async::connection::*;
19+
use crate::r#async::shutdown;
1720
use crate::r#async::stream::{ResultReceiver, ResultSender};
1821
use crate::r#async::utils;
1922

@@ -36,122 +39,12 @@ impl Client {
3639
pub fn new(fd: RawFd) -> Client {
3740
let stream = utils::new_unix_stream_from_raw_fd(fd);
3841

39-
let (mut reader, mut writer) = split(stream);
40-
let (req_tx, mut rx): (RequestSender, RequestReceiver) = mpsc::channel(100);
42+
let (req_tx, rx): (RequestSender, RequestReceiver) = mpsc::channel(100);
4143

42-
let req_map = Arc::new(Mutex::new(HashMap::new()));
43-
let req_map2 = req_map.clone();
44-
45-
let notify = Arc::new(Notify::new());
46-
let notify2 = notify.clone();
47-
48-
// Request sender
49-
let request_sender = tokio::spawn(async move {
50-
let mut stream_id: u32 = 1;
51-
52-
while let Some((mut msg, resp_tx)) = rx.recv().await {
53-
let current_stream_id = stream_id;
54-
msg.header.set_stream_id(current_stream_id);
55-
stream_id += 2;
56-
57-
{
58-
let mut map = req_map2.lock().unwrap();
59-
map.insert(current_stream_id, resp_tx.clone());
60-
}
61-
62-
if let Err(e) = msg.write_to(&mut writer).await {
63-
error!("write_message got error: {:?}", e);
64-
65-
{
66-
let mut map = req_map2.lock().unwrap();
67-
map.remove(&current_stream_id);
68-
}
44+
let delegate = ClientBuilder { rx: Some(rx) };
6945

70-
let e = Error::Socket(format!("{:?}", e));
71-
resp_tx
72-
.send(Err(e))
73-
.await
74-
.unwrap_or_else(|_e| error!("The request has returned"));
75-
76-
break; // The stream is dead, exit the loop.
77-
}
78-
}
79-
80-
// rx.recv will abort when client.req_tx and client is dropped.
81-
// notify the response-receiver to quit at this time.
82-
notify.notify_one();
83-
});
84-
85-
// Response receiver
86-
tokio::spawn(async move {
87-
loop {
88-
tokio::select! {
89-
_ = notify2.notified() => {
90-
break;
91-
}
92-
res = GenMessage::read_from(&mut reader) => {
93-
match res {
94-
Ok(msg) => {
95-
trace!("Got Message body {:?}", msg.payload);
96-
let req_map = req_map.clone();
97-
tokio::spawn(async move {
98-
let resp_tx2;
99-
{
100-
let mut map = req_map.lock().unwrap();
101-
let resp_tx = match map.get(&msg.header.stream_id) {
102-
Some(tx) => tx,
103-
None => {
104-
debug!(
105-
"Receiver got unknown packet {:?}",
106-
msg
107-
);
108-
return;
109-
}
110-
};
111-
112-
resp_tx2 = resp_tx.clone();
113-
map.remove(&msg.header.stream_id); // Forget the result, just remove.
114-
}
115-
116-
if msg.header.type_ != MESSAGE_TYPE_RESPONSE {
117-
resp_tx2
118-
.send(Err(Error::Others(format!(
119-
"Recver got malformed packet {:?}",
120-
msg
121-
))))
122-
.await
123-
.unwrap_or_else(|_e| error!("The request has returned"));
124-
return;
125-
}
126-
127-
resp_tx2.send(Ok(msg)).await.unwrap_or_else(|_e| error!("The request has returned"));
128-
});
129-
}
130-
Err(e) => {
131-
debug!("Connection closed by the ttRPC server: {}", e);
132-
133-
// Abort the request sender task to prevent incoming RPC requests
134-
// from being processed.
135-
request_sender.abort();
136-
let _ = request_sender.await;
137-
138-
// Take all items out of `req_map`.
139-
let mut map = std::mem::take(&mut *req_map.lock().unwrap());
140-
// Terminate outstanding RPC requests with the error.
141-
for (_stream_id, resp_tx) in map.drain() {
142-
if let Err(_e) = resp_tx.send(Err(e.clone())).await {
143-
warn!("Failed to terminate pending RPC: \
144-
the request has returned");
145-
}
146-
}
147-
148-
break;
149-
}
150-
}
151-
}
152-
};
153-
}
154-
});
46+
let conn = Connection::new(stream, delegate);
47+
tokio::spawn(async move { conn.run().await });
15548

15649
Client { req_tx }
15750
}
@@ -208,3 +101,140 @@ impl Drop for ClientClose {
208101
trace!("All client is droped");
209102
}
210103
}
104+
105+
#[derive(Debug)]
106+
struct ClientBuilder {
107+
rx: Option<RequestReceiver>,
108+
}
109+
110+
impl Builder for ClientBuilder {
111+
type Reader = ClientReader;
112+
type Writer = ClientWriter;
113+
114+
fn build(&mut self) -> (Self::Reader, Self::Writer) {
115+
let (notifier, waiter) = shutdown::new();
116+
let req_map = Arc::new(Mutex::new(HashMap::new()));
117+
(
118+
ClientReader {
119+
shutdown_waiter: waiter,
120+
req_map: req_map.clone(),
121+
},
122+
ClientWriter {
123+
stream_id: 1,
124+
rx: self.rx.take().unwrap(),
125+
shutdown_notifier: notifier,
126+
req_map,
127+
},
128+
)
129+
}
130+
}
131+
132+
struct ClientWriter {
133+
stream_id: u32,
134+
rx: RequestReceiver,
135+
shutdown_notifier: shutdown::Notifier,
136+
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
137+
}
138+
139+
#[async_trait]
140+
impl WriterDelegate for ClientWriter {
141+
async fn recv(&mut self) -> Option<GenMessage> {
142+
if let Some((mut msg, resp_tx)) = self.rx.recv().await {
143+
let current_stream_id = self.stream_id;
144+
msg.header.set_stream_id(current_stream_id);
145+
self.stream_id += 2;
146+
{
147+
let mut map = self.req_map.lock().unwrap();
148+
map.insert(current_stream_id, resp_tx);
149+
}
150+
return Some(msg);
151+
} else {
152+
return None;
153+
}
154+
}
155+
156+
async fn disconnect(&self, msg: &GenMessage, e: Error) {
157+
let resp_tx = {
158+
let mut map = self.req_map.lock().unwrap();
159+
map.remove(&msg.header.stream_id)
160+
};
161+
162+
if let Some(resp_tx) = resp_tx {
163+
let e = Error::Socket(format!("{:?}", e));
164+
resp_tx
165+
.send(Err(e))
166+
.await
167+
.unwrap_or_else(|_e| error!("The request has returned"));
168+
}
169+
}
170+
171+
async fn exit(&self) {
172+
self.shutdown_notifier.shutdown();
173+
}
174+
}
175+
176+
struct ClientReader {
177+
shutdown_waiter: shutdown::Waiter,
178+
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
179+
}
180+
181+
#[async_trait]
182+
impl ReaderDelegate for ClientReader {
183+
async fn wait_shutdown(&self) {
184+
self.shutdown_waiter.wait_shutdown().await
185+
}
186+
187+
async fn disconnect(&self, e: Error, sender: &mut task::JoinHandle<()>) {
188+
// Abort the request sender task to prevent incoming RPC requests
189+
// from being processed.
190+
sender.abort();
191+
let _ = sender.await;
192+
193+
// Take all items out of `req_map`.
194+
let mut map = std::mem::take(&mut *self.req_map.lock().unwrap());
195+
// Terminate outstanding RPC requests with the error.
196+
for (_stream_id, resp_tx) in map.drain() {
197+
if let Err(_e) = resp_tx.send(Err(e.clone())).await {
198+
warn!("Failed to terminate pending RPC: the request has returned");
199+
}
200+
}
201+
}
202+
203+
async fn exit(&self) {}
204+
205+
async fn handle_msg(&self, msg: GenMessage) {
206+
let req_map = self.req_map.clone();
207+
tokio::spawn(async move {
208+
let resp_tx2;
209+
{
210+
let mut map = req_map.lock().unwrap();
211+
let resp_tx = match map.get(&msg.header.stream_id) {
212+
Some(tx) => tx,
213+
None => {
214+
debug!("Receiver got unknown packet {:?}", msg);
215+
return;
216+
}
217+
};
218+
219+
resp_tx2 = resp_tx.clone();
220+
map.remove(&msg.header.stream_id); // Forget the result, just remove.
221+
}
222+
223+
if msg.header.type_ != MESSAGE_TYPE_RESPONSE {
224+
resp_tx2
225+
.send(Err(Error::Others(format!(
226+
"Recver got malformed packet {:?}",
227+
msg
228+
))))
229+
.await
230+
.unwrap_or_else(|_e| error!("The request has returned"));
231+
return;
232+
}
233+
234+
resp_tx2
235+
.send(Ok(msg))
236+
.await
237+
.unwrap_or_else(|_e| error!("The request has returned"));
238+
});
239+
}
240+
}

0 commit comments

Comments
 (0)