Skip to content

Commit 430466d

Browse files
committed
async ttrpc: Fix the bug caused by oversized packets
Fix #198 Signed-off-by: Tim Zhang <[email protected]>
1 parent b445d5b commit 430466d

File tree

6 files changed

+168
-83
lines changed

6 files changed

+168
-83
lines changed

src/asynchronous/client.rs

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ use tokio::{self, sync::mpsc, task};
1717
use crate::common::client_connect;
1818
use crate::error::{Error, Result};
1919
use crate::proto::{
20-
Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN,
21-
MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
20+
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, FLAG_REMOTE_CLOSED,
21+
FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
2222
};
2323
use crate::r#async::connection::*;
2424
use crate::r#async::shutdown;
@@ -97,6 +97,7 @@ impl Client {
9797
};
9898

9999
let msg = result?;
100+
100101
let res = Response::decode(msg.payload)
101102
.map_err(err_to_others_err!(e, "Unpack response error "))?;
102103

@@ -223,6 +224,58 @@ impl WriterDelegate for ClientWriter {
223224
}
224225
}
225226

227+
async fn get_resp_tx(
228+
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
229+
header: &MessageHeader,
230+
) -> Option<ResultSender> {
231+
let resp_tx = match header.type_ {
232+
MESSAGE_TYPE_RESPONSE => match req_map.lock().unwrap().remove(&header.stream_id) {
233+
Some(tx) => tx,
234+
None => {
235+
debug!("Receiver got unknown response packet {:?}", header);
236+
return None;
237+
}
238+
},
239+
MESSAGE_TYPE_DATA => {
240+
if (header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
241+
match req_map.lock().unwrap().remove(&header.stream_id) {
242+
Some(tx) => tx,
243+
None => {
244+
debug!("Receiver got unknown data packet {:?}", header);
245+
return None;
246+
}
247+
}
248+
} else {
249+
match req_map.lock().unwrap().get(&header.stream_id) {
250+
Some(tx) => tx.clone(),
251+
None => {
252+
debug!("Receiver got unknown data packet {:?}", header);
253+
return None;
254+
}
255+
}
256+
}
257+
}
258+
_ => {
259+
let resp_tx = match req_map.lock().unwrap().remove(&header.stream_id) {
260+
Some(tx) => tx,
261+
None => {
262+
debug!("Receiver got unknown packet {:?}", header);
263+
return None;
264+
}
265+
};
266+
resp_tx
267+
.send(Err(Error::Others(format!(
268+
"Receiver got malformed packet {header:?}"
269+
))))
270+
.await
271+
.unwrap_or_else(|_e| error!("The request has returned"));
272+
return None;
273+
}
274+
};
275+
276+
Some(resp_tx)
277+
}
278+
226279
struct ClientReader {
227280
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
228281
shutdown_waiter: shutdown::Waiter,
@@ -252,59 +305,27 @@ impl ReaderDelegate for ClientReader {
252305

253306
async fn exit(&self) {}
254307

308+
async fn handle_err(&self, header: MessageHeader, e: Error) {
309+
let req_map = self.streams.clone();
310+
tokio::spawn(async move {
311+
if let Some(resp_tx) = get_resp_tx(req_map, &header).await {
312+
resp_tx
313+
.send(Err(e))
314+
.await
315+
.unwrap_or_else(|_e| error!("The request has returned"));
316+
}
317+
});
318+
}
319+
255320
async fn handle_msg(&self, msg: GenMessage) {
256321
let req_map = self.streams.clone();
257322
tokio::spawn(async move {
258-
let resp_tx = match msg.header.type_ {
259-
MESSAGE_TYPE_RESPONSE => {
260-
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
261-
Some(tx) => tx,
262-
None => {
263-
debug!("Receiver got unknown response packet {:?}", msg);
264-
return;
265-
}
266-
}
267-
}
268-
MESSAGE_TYPE_DATA => {
269-
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
270-
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
271-
Some(tx) => tx.clone(),
272-
None => {
273-
debug!("Receiver got unknown data packet {:?}", msg);
274-
return;
275-
}
276-
}
277-
} else {
278-
match req_map.lock().unwrap().get(&msg.header.stream_id) {
279-
Some(tx) => tx.clone(),
280-
None => {
281-
debug!("Receiver got unknown data packet {:?}", msg);
282-
return;
283-
}
284-
}
285-
}
286-
}
287-
_ => {
288-
let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) {
289-
Some(tx) => tx,
290-
None => {
291-
debug!("Receiver got unknown packet {:?}", msg);
292-
return;
293-
}
294-
};
295-
resp_tx
296-
.send(Err(Error::Others(format!(
297-
"Receiver got malformed packet {msg:?}"
298-
))))
299-
.await
300-
.unwrap_or_else(|_e| error!("The request has returned"));
301-
return;
302-
}
303-
};
304-
resp_tx
305-
.send(Ok(msg))
306-
.await
307-
.unwrap_or_else(|_e| error!("The request has returned"));
323+
if let Some(resp_tx) = get_resp_tx(req_map, &msg.header).await {
324+
resp_tx
325+
.send(Ok(msg))
326+
.await
327+
.unwrap_or_else(|_e| error!("The request has returned"));
328+
}
308329
});
309330
}
310331
}

src/asynchronous/connection.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use tokio::{
1414
};
1515

1616
use crate::error::Error;
17-
use crate::proto::GenMessage;
17+
use crate::proto::{GenMessage, GenMessageError, MessageHeader};
1818

1919
pub trait Builder {
2020
type Reader;
@@ -36,6 +36,7 @@ pub trait ReaderDelegate {
3636
async fn disconnect(&self, e: Error, task: &mut task::JoinHandle<()>);
3737
async fn exit(&self);
3838
async fn handle_msg(&self, msg: GenMessage);
39+
async fn handle_err(&self, header: MessageHeader, e: Error);
3940
}
4041

4142
pub struct Connection<S, B: Builder> {
@@ -89,7 +90,12 @@ where
8990
trace!("Got Message {:?}", msg);
9091
reader_delegate.handle_msg(msg).await;
9192
}
92-
Err(e) => {
93+
Err(GenMessageError::ReturnError(header, e)) => {
94+
trace!("Read msg err (can be return): {:?}", e);
95+
reader_delegate.handle_err(header, e).await;
96+
}
97+
98+
Err(GenMessageError::InternalError(e)) => {
9399
trace!("Read msg err: {:?}", e);
94100
reader_delegate.disconnect(e, &mut writer_task).await;
95101
break;

src/asynchronous/server.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,10 @@ impl ReaderDelegate for ServerReader {
386386
}
387387
});
388388
}
389+
390+
async fn handle_err(&self, header: MessageHeader, e: Error) {
391+
self.context().handle_err(header, e).await
392+
}
389393
}
390394

391395
impl ServerReader {
@@ -410,6 +414,14 @@ struct HandlerContext {
410414
}
411415

412416
impl HandlerContext {
417+
async fn handle_err(&self, header: MessageHeader, e: Error) {
418+
Self::respond(self.tx.clone(), header.stream_id, e.into())
419+
.await
420+
.map_err(|e| {
421+
error!("respond error got error {:?}", e);
422+
})
423+
.ok();
424+
}
413425
async fn handle_msg(&self, msg: GenMessage) {
414426
let stream_id = msg.header.stream_id;
415427

src/proto.rs

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,27 @@ pub(crate) fn check_oversize(len: usize, return_rpc_error: bool) -> TtResult<()>
4545
Ok(())
4646
}
4747

48+
// Discard the unwanted message body
49+
#[cfg(feature = "async")]
50+
async fn discard_message_body(
51+
mut reader: impl tokio::io::AsyncReadExt + Unpin,
52+
header: &MessageHeader,
53+
) -> TtResult<()> {
54+
let mut need_discard = header.length as usize;
55+
56+
while need_discard > 0 {
57+
let once_discard = std::cmp::min(DEFAULT_PAGE_SIZE, need_discard);
58+
let mut content = vec![0; once_discard];
59+
reader
60+
.read_exact(&mut content)
61+
.await
62+
.map_err(|e| Error::Socket(e.to_string()))?;
63+
need_discard -= once_discard;
64+
}
65+
66+
Ok(())
67+
}
68+
4869
/// Message header of ttrpc.
4970
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
5071
pub struct MessageHeader {
@@ -174,6 +195,18 @@ pub struct GenMessage {
174195
pub payload: Vec<u8>,
175196
}
176197

198+
#[derive(Debug, PartialEq)]
199+
pub enum GenMessageError {
200+
InternalError(Error),
201+
ReturnError(MessageHeader, Error),
202+
}
203+
204+
impl From<Error> for GenMessageError {
205+
fn from(e: Error) -> Self {
206+
Self::InternalError(e)
207+
}
208+
}
209+
177210
#[cfg(feature = "async")]
178211
impl GenMessage {
179212
/// Encodes a MessageHeader to writer.
@@ -193,19 +226,16 @@ impl GenMessage {
193226
}
194227

195228
/// Decodes a MessageHeader from reader.
196-
pub async fn read_from(mut reader: impl tokio::io::AsyncReadExt + Unpin) -> TtResult<Self> {
229+
pub async fn read_from(
230+
mut reader: impl tokio::io::AsyncReadExt + Unpin,
231+
) -> std::result::Result<Self, GenMessageError> {
197232
let header = MessageHeader::read_from(&mut reader)
198233
.await
199234
.map_err(|e| Error::Socket(e.to_string()))?;
200235

201-
if header.length > MESSAGE_LENGTH_MAX as u32 {
202-
return Err(get_rpc_status(
203-
Code::INVALID_ARGUMENT,
204-
format!(
205-
"message length {} exceed maximum message size of {}",
206-
header.length, MESSAGE_LENGTH_MAX
207-
),
208-
));
236+
if let Err(e) = check_oversize(header.length as usize, true) {
237+
discard_message_body(reader, &header).await?;
238+
return Err(GenMessageError::ReturnError(header, e));
209239
}
210240

211241
let mut content = vec![0; header.length as usize];
@@ -328,14 +358,12 @@ where
328358
.await
329359
.map_err(|e| Error::Socket(e.to_string()))?;
330360

331-
if header.length > MESSAGE_LENGTH_MAX as u32 {
332-
return Err(get_rpc_status(
333-
Code::INVALID_ARGUMENT,
334-
format!(
335-
"message length {} exceed maximum message size of {}",
336-
header.length, MESSAGE_LENGTH_MAX
337-
),
338-
));
361+
if check_oversize(header.length as usize, true).is_err() {
362+
discard_message_body(reader, &header).await?;
363+
return Ok(Self {
364+
header,
365+
payload: C::decode("").map_err(err_to_others_err!(e, "Decode payload failed."))?,
366+
});
339367
}
340368

341369
let mut content = vec![0; header.length as usize];
@@ -447,11 +475,21 @@ mod tests {
447475
#[cfg(feature = "async")]
448476
#[tokio::test]
449477
async fn async_gen_message() {
478+
// Test packet which exceeds maximum message size
450479
let mut buf = Vec::from(MESSAGE_HEADER);
451-
buf.extend_from_slice(&PROTOBUF_REQUEST);
452-
let res = GenMessage::read_from(&*buf).await;
453-
// exceed maximum message size
454-
assert!(matches!(res, Err(Error::RpcStatus(_))));
480+
let header = MessageHeader::read_from(&*buf).await.expect("read header");
481+
buf.append(&mut vec![0x0; header.length as usize]);
482+
483+
match GenMessage::read_from(&*buf).await {
484+
Err(GenMessageError::ReturnError(h, Error::RpcStatus(s))) => {
485+
if h != header || s.code() != crate::proto::Code::INVALID_ARGUMENT {
486+
panic!("got invalid error when the size exceeds limit");
487+
}
488+
}
489+
_ => {
490+
panic!("got invalid error when the size exceeds limit");
491+
}
492+
}
455493

456494
let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER);
457495
buf.extend_from_slice(&PROTOBUF_REQUEST);
@@ -477,11 +515,17 @@ mod tests {
477515
#[cfg(feature = "async")]
478516
#[tokio::test]
479517
async fn async_message() {
518+
// Test packet which exceeds maximum message size
480519
let mut buf = Vec::from(MESSAGE_HEADER);
481-
buf.extend_from_slice(&PROTOBUF_REQUEST);
482-
let res = Message::<Request>::read_from(&*buf).await;
483-
// exceed maximum message size
484-
assert!(matches!(res, Err(Error::RpcStatus(_))));
520+
let header = MessageHeader::read_from(&*buf).await.expect("read header");
521+
buf.append(&mut vec![0x0; header.length as usize]);
522+
523+
let gen = Message::<Request>::read_from(&*buf)
524+
.await
525+
.expect("read message");
526+
527+
assert_eq!(gen.header, header);
528+
assert_eq!(protobuf::Message::compute_size(&gen.payload), 0);
485529

486530
let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER);
487531
buf.extend_from_slice(&PROTOBUF_REQUEST);

src/sync/client.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ use std::sync::{Arc, Mutex};
2424
use std::thread;
2525
use std::time::Duration;
2626

27-
use crate::common::check_oversize;
2827
use crate::error::{Error, Result};
29-
use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
28+
use crate::proto::{
29+
check_oversize, Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE,
30+
};
3031
use crate::sync::channel::{read_message, write_message};
3132
use crate::sync::sys::ClientConnection;
3233

src/sync/utils.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55

6-
use crate::common::check_oversize;
76
use crate::error::{Error, Result};
8-
use crate::proto::{Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
7+
use crate::proto::{
8+
check_oversize, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE,
9+
};
910
use std::collections::HashMap;
1011

1112
/// Response message through a channel.

0 commit comments

Comments
 (0)