Skip to content

Commit 69bd03d

Browse files
committed
async: Refactor message encoding/deocding.
Make message encoding/decoding uniform. Signed-off-by: wllenyj <[email protected]>
1 parent 54ab070 commit 69bd03d

File tree

4 files changed

+64
-165
lines changed

4 files changed

+64
-165
lines changed

src/asynchronous/client.rs

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,16 @@ use std::os::unix::io::RawFd;
99
use std::sync::{Arc, Mutex};
1010

1111
use nix::unistd::close;
12-
use tokio::{
13-
self,
14-
io::split,
15-
sync::mpsc::{channel, Receiver, Sender},
16-
sync::Notify,
17-
};
12+
use tokio::{self, io::split, sync::mpsc, sync::Notify};
1813

1914
use crate::common::client_connect;
2015
use crate::error::{Error, Result};
2116
use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE};
17+
use crate::r#async::stream::{ResultReceiver, ResultSender};
2218
use crate::r#async::utils;
2319

24-
type RequestSender = Sender<(GenMessage, Sender<Result<Vec<u8>>>)>;
25-
type RequestReceiver = Receiver<(GenMessage, Sender<Result<Vec<u8>>>)>;
26-
27-
type ResponseSender = Sender<Result<Vec<u8>>>;
28-
type ResponseReceiver = Receiver<Result<Vec<u8>>>;
20+
type RequestSender = mpsc::Sender<(GenMessage, ResultSender)>;
21+
type RequestReceiver = mpsc::Receiver<(GenMessage, ResultSender)>;
2922

3023
/// A ttrpc Client (async).
3124
#[derive(Clone)]
@@ -44,7 +37,7 @@ impl Client {
4437
let stream = utils::new_unix_stream_from_raw_fd(fd);
4538

4639
let (mut reader, mut writer) = split(stream);
47-
let (req_tx, mut rx): (RequestSender, RequestReceiver) = channel(100);
40+
let (req_tx, mut rx): (RequestSender, RequestReceiver) = mpsc::channel(100);
4841

4942
let req_map = Arc::new(Mutex::new(HashMap::new()));
5043
let req_map2 = req_map.clone();
@@ -131,7 +124,7 @@ impl Client {
131124
return;
132125
}
133126

134-
resp_tx2.send(Ok(msg.payload)).await.unwrap_or_else(|_e| error!("The request has returned"));
127+
resp_tx2.send(Ok(msg)).await.unwrap_or_else(|_e| error!("The request has returned"));
135128
});
136129
}
137130
Err(e) => {
@@ -170,7 +163,7 @@ impl Client {
170163
.try_into()
171164
.map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?;
172165

173-
let (tx, mut rx): (ResponseSender, ResponseReceiver) = channel(100);
166+
let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
174167
self.req_tx
175168
.send((msg, tx))
176169
.await
@@ -190,9 +183,9 @@ impl Client {
190183
.ok_or_else(|| Error::Others("Receive packet from receiver error".to_string()))?
191184
};
192185

193-
let buf = result?;
194-
let res =
195-
Response::decode(&buf).map_err(err_to_others_err!(e, "Unpack response error "))?;
186+
let msg = result?;
187+
let res = Response::decode(&msg.payload)
188+
.map_err(err_to_others_err!(e, "Unpack response error "))?;
196189

197190
let status = res.get_status();
198191
if status.get_code() != Code::OK {

src/asynchronous/server.rs

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@ use futures::StreamExt as _;
1717
use nix::unistd;
1818
use tokio::{
1919
self,
20-
io::{split, AsyncRead, AsyncWrite, AsyncWriteExt},
20+
io::{split, AsyncRead, AsyncWrite},
2121
net::UnixListener,
2222
select, spawn,
23-
sync::mpsc::{channel, Receiver, Sender},
23+
sync::mpsc::{channel, Sender},
2424
time::timeout,
2525
};
2626
#[cfg(target_os = "linux")]
2727
use tokio_vsock::VsockListener;
2828

29-
use crate::asynchronous::stream::{receive, respond, respond_with_status};
29+
use crate::asynchronous::stream::{respond, respond_with_status};
3030
use crate::asynchronous::unix_incoming::UnixIncoming;
3131
use crate::common::{self, Domain};
3232
use crate::context;
3333
use crate::error::{get_status, Error, Result};
34-
use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST};
34+
use crate::proto::{Code, GenMessage, MessageHeader, Response, Status, MESSAGE_TYPE_REQUEST};
3535
use crate::r#async::shutdown;
36+
use crate::r#async::stream::{MessageReceiver, MessageSender};
3637
use crate::r#async::utils;
3738
use crate::r#async::{MethodHandler, TtrpcContext};
3839

@@ -244,15 +245,15 @@ async fn spawn_connection_handler<S>(
244245
{
245246
spawn(async move {
246247
let (mut reader, mut writer) = split(stream);
247-
let (tx, mut rx): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = channel(100);
248+
let (tx, mut rx): (MessageSender, MessageReceiver) = channel(100);
248249

249250
let server_shutdown = shutdown_waiter.clone();
250251
let (disconnect_notifier, disconnect_waiter) =
251252
shutdown::with_timeout(DEFAULT_CONN_SHUTDOWN_TIMEOUT);
252253

253254
spawn(async move {
254-
while let Some(buf) = rx.recv().await {
255-
if let Err(e) = writer.write_all(&buf).await {
255+
while let Some(msg) = rx.recv().await {
256+
if let Err(e) = msg.write_to(&mut writer).await {
256257
error!("write_message got error: {:?}", e);
257258
}
258259
}
@@ -264,8 +265,8 @@ async fn spawn_connection_handler<S>(
264265
let handler_shutdown_waiter = disconnect_waiter.clone();
265266

266267
select! {
267-
resp = receive(&mut reader) => {
268-
match resp {
268+
res = GenMessage::read_from(&mut reader) => {
269+
match res {
269270
Ok(message) => {
270271
spawn(async move {
271272
select! {
@@ -304,7 +305,7 @@ async fn do_handle_request(
304305
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
305306
header: MessageHeader,
306307
body: &[u8],
307-
) -> StdResult<(u32, Vec<u8>), Status> {
308+
) -> StdResult<Option<Response>, Status> {
308309
let req = utils::body_to_request(body)?;
309310
let path = utils::get_path(&req.service, &req.method);
310311
let method = methods
@@ -328,6 +329,7 @@ async fn do_handle_request(
328329
.handler(ctx, req)
329330
.await
330331
.map_err(get_unknown_status_and_log_err)
332+
.map(Some)
331333
} else {
332334
timeout(
333335
Duration::from_nanos(req.timeout_nano as u64),
@@ -343,32 +345,39 @@ async fn do_handle_request(
343345
// Handler finished
344346
r.map_err(get_unknown_status_and_log_err)
345347
})
348+
.map(Some)
346349
}
347350
}
348351

349352
async fn handle_request(
350-
tx: Sender<Vec<u8>>,
353+
tx: MessageSender,
351354
fd: RawFd,
352355
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
353-
message: (MessageHeader, Vec<u8>),
356+
message: GenMessage,
354357
) {
355-
let (header, body) = message;
358+
let GenMessage {
359+
header,
360+
payload: body,
361+
} = message;
356362
let stream_id = header.stream_id;
357363

358364
if header.type_ != MESSAGE_TYPE_REQUEST {
359365
return;
360366
}
361367

362368
match do_handle_request(fd, methods, header, &body).await {
363-
Ok((stream_id, resp_body)) => {
364-
if let Err(x) = respond(tx.clone(), stream_id, resp_body).await {
365-
error!("respond got error {:?}", x);
369+
Ok(opt_msg) => match opt_msg {
370+
Some(msg) => {
371+
if let Err(x) = respond(tx.clone(), stream_id, msg).await {
372+
error!("respond got error {:?}", x);
373+
}
366374
}
367-
}
368-
Err(status) => {
369-
if let Err(x) = respond_with_status(tx.clone(), stream_id, status).await {
370-
error!("respond got error {:?}", x);
375+
None => {
376+
unimplemented!();
371377
}
378+
},
379+
Err(status) => {
380+
respond_with_status(tx.clone(), stream_id, status).await;
372381
}
373382
}
374383
}

src/asynchronous/stream.rs

Lines changed: 22 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -3,129 +3,40 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55

6-
use protobuf::Message;
7-
use tokio::io::AsyncReadExt;
6+
use tokio::sync::mpsc;
87

9-
use crate::error::{get_rpc_status, sock_error_msg, Error, Result};
10-
use crate::proto::{
11-
Code, Response, Status, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE,
12-
};
13-
use crate::r#async::utils;
8+
use crate::error::{Error, Result};
9+
use crate::proto::{Codec, GenMessage, Response, Status};
1410
use crate::MessageHeader;
1511

16-
async fn receive_count<T>(reader: &mut T, count: usize) -> Result<Vec<u8>>
17-
where
18-
T: AsyncReadExt + std::marker::Unpin,
19-
{
20-
let mut content = vec![0u8; count];
21-
if let Err(e) = reader.read_exact(&mut content).await {
22-
return Err(Error::Socket(e.to_string()));
23-
}
12+
pub type MessageSender = mpsc::Sender<GenMessage>;
13+
pub type MessageReceiver = mpsc::Receiver<GenMessage>;
2414

25-
Ok(content)
26-
}
27-
28-
async fn receive_header<T>(reader: &mut T) -> Result<MessageHeader>
29-
where
30-
T: AsyncReadExt + std::marker::Unpin,
31-
{
32-
let buf = receive_count(reader, MESSAGE_HEADER_LENGTH).await?;
33-
let size = buf.len();
34-
if size != MESSAGE_HEADER_LENGTH {
35-
return Err(sock_error_msg(
36-
size,
37-
format!("Message header length {} is too small", size),
38-
));
39-
}
40-
41-
let mh = MessageHeader::from(&buf);
42-
43-
Ok(mh)
44-
}
45-
46-
pub(crate) async fn receive<T>(reader: &mut T) -> Result<(MessageHeader, Vec<u8>)>
47-
where
48-
T: AsyncReadExt + std::marker::Unpin,
49-
{
50-
let mh = receive_header(reader).await?;
51-
trace!("Got Message header {:?}", mh);
52-
53-
if mh.length > MESSAGE_LENGTH_MAX as u32 {
54-
return Err(get_rpc_status(
55-
Code::INVALID_ARGUMENT,
56-
format!(
57-
"message length {} exceed maximum message size of {}",
58-
mh.length, MESSAGE_LENGTH_MAX
59-
),
60-
));
61-
}
62-
63-
let buf = receive_count(reader, mh.length as usize).await?;
64-
let size = buf.len();
65-
if size != mh.length as usize {
66-
return Err(sock_error_msg(
67-
size,
68-
format!("Message length {} is not {}", size, mh.length),
69-
));
70-
}
71-
trace!("Got Message body {:?}", buf);
72-
73-
Ok((mh, buf))
74-
}
75-
76-
fn header_to_buf(mh: MessageHeader) -> Vec<u8> {
77-
mh.into()
78-
}
15+
pub type ResultSender = mpsc::Sender<Result<GenMessage>>;
16+
pub type ResultReceiver = mpsc::Receiver<Result<GenMessage>>;
7917

80-
pub(crate) fn to_res_buf(stream_id: u32, mut body: Vec<u8>) -> Vec<u8> {
81-
let header = utils::get_response_header_from_body(stream_id, &body);
82-
let mut buf = header_to_buf(header);
83-
buf.append(&mut body);
84-
85-
buf
86-
}
87-
88-
fn get_response_body(res: &Response) -> Result<Vec<u8>> {
89-
let mut buf = Vec::with_capacity(res.compute_size() as usize);
90-
let mut s = protobuf::CodedOutputStream::vec(&mut buf);
91-
res.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
92-
s.flush().map_err(err_to_others_err!(e, ""))?;
93-
94-
Ok(buf)
95-
}
96-
97-
pub(crate) async fn respond(
98-
tx: tokio::sync::mpsc::Sender<Vec<u8>>,
99-
stream_id: u32,
100-
body: Vec<u8>,
101-
) -> Result<()> {
102-
let buf = to_res_buf(stream_id, body);
18+
pub(crate) async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> {
19+
let payload = resp
20+
.encode()
21+
.map_err(err_to_others_err!(e, "Encode Response failed."))?;
22+
let msg = GenMessage {
23+
header: MessageHeader::new_response(stream_id, payload.len() as u32),
24+
payload,
25+
};
10326

104-
tx.send(buf)
27+
tx.send(msg)
10528
.await
10629
.map_err(err_to_others_err!(e, "Send packet to sender error "))
10730
}
10831

109-
pub(crate) async fn respond_with_status(
110-
tx: tokio::sync::mpsc::Sender<Vec<u8>>,
111-
stream_id: u32,
112-
status: Status,
113-
) -> Result<()> {
32+
pub(crate) async fn respond_with_status(tx: MessageSender, stream_id: u32, status: Status) {
11433
let mut res = Response::new();
11534
res.set_status(status);
116-
let mut body = get_response_body(&res)?;
117-
118-
let mh = MessageHeader {
119-
length: body.len() as u32,
120-
stream_id,
121-
type_: MESSAGE_TYPE_RESPONSE,
122-
flags: 0,
123-
};
12435

125-
let mut buf = header_to_buf(mh);
126-
buf.append(&mut body);
127-
128-
tx.send(buf)
36+
respond(tx, stream_id, res)
12937
.await
130-
.map_err(err_to_others_err!(e, "Send packet to sender error "))
38+
.map_err(|e| {
39+
error!("respond with status got error {:?}", e);
40+
})
41+
.ok();
13142
}

src/asynchronous/utils.rs

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use protobuf::{CodedInputStream, Message};
1212
use tokio::net::UnixStream;
1313

1414
use crate::error::{get_status, Result};
15-
use crate::proto::{Code, MessageHeader, Request, Status, MESSAGE_TYPE_RESPONSE};
15+
use crate::proto::{Code, MessageHeader, Request, Response, Status};
1616

1717
/// Handle request in async mode.
1818
#[macro_export]
@@ -48,12 +48,7 @@ macro_rules! async_request_handler {
4848
},
4949
}
5050

51-
let mut buf = Vec::with_capacity(res.compute_size() as usize);
52-
let mut s = protobuf::CodedOutputStream::vec(&mut buf);
53-
res.write_to(&mut s).map_err(ttrpc::err_to_others!(e, ""))?;
54-
s.flush().map_err(ttrpc::err_to_others!(e, ""))?;
55-
56-
return Ok(($ctx.mh.stream_id, buf));
51+
return Ok(res);
5752
};
5853
}
5954

@@ -88,7 +83,7 @@ macro_rules! async_client_request {
8883
/// Trait that implements handler which is a proxy to the desired method (async).
8984
#[async_trait]
9085
pub trait MethodHandler {
91-
async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<(u32, Vec<u8>)>;
86+
async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<Response>;
9287
}
9388

9489
/// The context of ttrpc (async).
@@ -100,15 +95,6 @@ pub struct TtrpcContext {
10095
pub timeout_nano: i64,
10196
}
10297

103-
pub(crate) fn get_response_header_from_body(stream_id: u32, body: &[u8]) -> MessageHeader {
104-
MessageHeader {
105-
length: body.len() as u32,
106-
stream_id,
107-
type_: MESSAGE_TYPE_RESPONSE,
108-
flags: 0,
109-
}
110-
}
111-
11298
pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream {
11399
let std_stream: std::os::unix::net::UnixStream;
114100
unsafe {

0 commit comments

Comments
 (0)