Skip to content

Commit 555c412

Browse files
author
Hui Zhu
authored
Merge pull request #203 from Tim-Zhang/fix-over-size-limit-master
[master] Fix the bug caused by oversized packets
2 parents e5e1dbe + 3ef0e4e commit 555c412

File tree

11 files changed

+340
-193
lines changed

11 files changed

+340
-193
lines changed

src/asynchronous/client.rs

Lines changed: 75 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ use tokio::{self, sync::mpsc, task};
1717
use crate::common::client_connect;
1818
use crate::error::{Error, Result};
1919
use crate::proto::{
20-
Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN,
21-
MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
20+
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, FLAG_REMOTE_CLOSED,
21+
FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
2222
};
2323
use crate::r#async::connection::*;
2424
use crate::r#async::shutdown;
@@ -68,7 +68,7 @@ impl Client {
6868
let timeout_nano = req.timeout_nano;
6969
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
7070

71-
let msg: GenMessage = Message::new_request(stream_id, req)
71+
let msg: GenMessage = Message::new_request(stream_id, req)?
7272
.try_into()
7373
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;
7474

@@ -97,6 +97,7 @@ impl Client {
9797
};
9898

9999
let msg = result?;
100+
100101
let res = Response::decode(msg.payload)
101102
.map_err(err_to_others_err!(e, "Unpack response error "))?;
102103

@@ -117,7 +118,7 @@ impl Client {
117118
) -> Result<StreamInner> {
118119
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
119120

120-
let mut msg: GenMessage = Message::new_request(stream_id, req)
121+
let mut msg: GenMessage = Message::new_request(stream_id, req)?
121122
.try_into()
122123
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;
123124

@@ -223,6 +224,58 @@ impl WriterDelegate for ClientWriter {
223224
}
224225
}
225226

227+
async fn get_resp_tx(
228+
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
229+
header: &MessageHeader,
230+
) -> Option<ResultSender> {
231+
let resp_tx = match header.type_ {
232+
MESSAGE_TYPE_RESPONSE => match req_map.lock().unwrap().remove(&header.stream_id) {
233+
Some(tx) => tx,
234+
None => {
235+
debug!("Receiver got unknown response packet {:?}", header);
236+
return None;
237+
}
238+
},
239+
MESSAGE_TYPE_DATA => {
240+
if (header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
241+
match req_map.lock().unwrap().remove(&header.stream_id) {
242+
Some(tx) => tx,
243+
None => {
244+
debug!("Receiver got unknown data packet {:?}", header);
245+
return None;
246+
}
247+
}
248+
} else {
249+
match req_map.lock().unwrap().get(&header.stream_id) {
250+
Some(tx) => tx.clone(),
251+
None => {
252+
debug!("Receiver got unknown data packet {:?}", header);
253+
return None;
254+
}
255+
}
256+
}
257+
}
258+
_ => {
259+
let resp_tx = match req_map.lock().unwrap().remove(&header.stream_id) {
260+
Some(tx) => tx,
261+
None => {
262+
debug!("Receiver got unknown packet {:?}", header);
263+
return None;
264+
}
265+
};
266+
resp_tx
267+
.send(Err(Error::Others(format!(
268+
"Receiver got malformed packet {header:?}"
269+
))))
270+
.await
271+
.unwrap_or_else(|_e| error!("The request has returned"));
272+
return None;
273+
}
274+
};
275+
276+
Some(resp_tx)
277+
}
278+
226279
struct ClientReader {
227280
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
228281
shutdown_waiter: shutdown::Waiter,
@@ -252,59 +305,27 @@ impl ReaderDelegate for ClientReader {
252305

253306
async fn exit(&self) {}
254307

308+
async fn handle_err(&self, header: MessageHeader, e: Error) {
309+
let req_map = self.streams.clone();
310+
tokio::spawn(async move {
311+
if let Some(resp_tx) = get_resp_tx(req_map, &header).await {
312+
resp_tx
313+
.send(Err(e))
314+
.await
315+
.unwrap_or_else(|_e| error!("The request has returned"));
316+
}
317+
});
318+
}
319+
255320
async fn handle_msg(&self, msg: GenMessage) {
256321
let req_map = self.streams.clone();
257322
tokio::spawn(async move {
258-
let resp_tx = match msg.header.type_ {
259-
MESSAGE_TYPE_RESPONSE => {
260-
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
261-
Some(tx) => tx,
262-
None => {
263-
debug!("Receiver got unknown response packet {:?}", msg);
264-
return;
265-
}
266-
}
267-
}
268-
MESSAGE_TYPE_DATA => {
269-
if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
270-
match req_map.lock().unwrap().remove(&msg.header.stream_id) {
271-
Some(tx) => tx.clone(),
272-
None => {
273-
debug!("Receiver got unknown data packet {:?}", msg);
274-
return;
275-
}
276-
}
277-
} else {
278-
match req_map.lock().unwrap().get(&msg.header.stream_id) {
279-
Some(tx) => tx.clone(),
280-
None => {
281-
debug!("Receiver got unknown data packet {:?}", msg);
282-
return;
283-
}
284-
}
285-
}
286-
}
287-
_ => {
288-
let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) {
289-
Some(tx) => tx,
290-
None => {
291-
debug!("Receiver got unknown packet {:?}", msg);
292-
return;
293-
}
294-
};
295-
resp_tx
296-
.send(Err(Error::Others(format!(
297-
"Receiver got malformed packet {msg:?}"
298-
))))
299-
.await
300-
.unwrap_or_else(|_e| error!("The request has returned"));
301-
return;
302-
}
303-
};
304-
resp_tx
305-
.send(Ok(msg))
306-
.await
307-
.unwrap_or_else(|_e| error!("The request has returned"));
323+
if let Some(resp_tx) = get_resp_tx(req_map, &msg.header).await {
324+
resp_tx
325+
.send(Ok(msg))
326+
.await
327+
.unwrap_or_else(|_e| error!("The request has returned"));
328+
}
308329
});
309330
}
310331
}

src/asynchronous/connection.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use tokio::{
1414
};
1515

1616
use crate::error::Error;
17-
use crate::proto::GenMessage;
17+
use crate::proto::{GenMessage, GenMessageError, MessageHeader};
1818

1919
pub trait Builder {
2020
type Reader;
@@ -36,6 +36,7 @@ pub trait ReaderDelegate {
3636
async fn disconnect(&self, e: Error, task: &mut task::JoinHandle<()>);
3737
async fn exit(&self);
3838
async fn handle_msg(&self, msg: GenMessage);
39+
async fn handle_err(&self, header: MessageHeader, e: Error);
3940
}
4041

4142
pub struct Connection<S, B: Builder> {
@@ -89,7 +90,12 @@ where
8990
trace!("Got Message {:?}", msg);
9091
reader_delegate.handle_msg(msg).await;
9192
}
92-
Err(e) => {
93+
Err(GenMessageError::ReturnError(header, e)) => {
94+
trace!("Read msg err (can be return): {:?}", e);
95+
reader_delegate.handle_err(header, e).await;
96+
}
97+
98+
Err(GenMessageError::InternalError(e)) => {
9399
trace!("Read msg err: {:?}", e);
94100
reader_delegate.disconnect(e, &mut writer_task).await;
95101
break;

src/asynchronous/server.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use async_trait::async_trait;
1818
use futures::stream::Stream;
1919
use futures::StreamExt as _;
2020
use nix::unistd;
21+
use protobuf::Message as _;
2122
use tokio::{
2223
self,
2324
io::{AsyncRead, AsyncWrite},
@@ -35,8 +36,8 @@ use crate::common::{self, Domain};
3536
use crate::context;
3637
use crate::error::{get_status, Error, Result};
3738
use crate::proto::{
38-
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA,
39-
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
39+
check_oversize, Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status,
40+
FLAG_NO_DATA, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST,
4041
};
4142
use crate::r#async::connection::*;
4243
use crate::r#async::shutdown;
@@ -386,6 +387,10 @@ impl ReaderDelegate for ServerReader {
386387
}
387388
});
388389
}
390+
391+
async fn handle_err(&self, header: MessageHeader, e: Error) {
392+
self.context().handle_err(header, e).await
393+
}
389394
}
390395

391396
impl ServerReader {
@@ -410,6 +415,14 @@ struct HandlerContext {
410415
}
411416

412417
impl HandlerContext {
418+
async fn handle_err(&self, header: MessageHeader, e: Error) {
419+
Self::respond(self.tx.clone(), header.stream_id, e.into())
420+
.await
421+
.map_err(|e| {
422+
error!("respond error got error {:?}", e);
423+
})
424+
.ok();
425+
}
413426
async fn handle_msg(&self, msg: GenMessage) {
414427
let stream_id = msg.header.stream_id;
415428

@@ -426,8 +439,13 @@ impl HandlerContext {
426439
match msg.header.type_ {
427440
MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await {
428441
Ok(opt_msg) => match opt_msg {
429-
Some(msg) => {
430-
Self::respond(self.tx.clone(), stream_id, msg)
442+
Some(mut resp) => {
443+
// Server: check size before sending to client
444+
if let Err(e) = check_oversize(resp.compute_size() as usize, true) {
445+
resp = e.into();
446+
}
447+
448+
Self::respond(self.tx.clone(), stream_id, resp)
431449
.await
432450
.map_err(|e| {
433451
error!("respond got error {:?}", e);

src/asynchronous/stream.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ impl StreamSender {
418418
header,
419419
payload: buf,
420420
};
421+
422+
msg.check()?;
423+
421424
_send(&self.tx, msg).await?;
422425

423426
Ok(())
@@ -447,6 +450,7 @@ impl StreamReceiver {
447450
return Err(Error::RemoteClosed);
448451
}
449452
let msg = _recv(&mut self.rx).await?;
453+
450454
let payload = match msg.header.type_ {
451455
MESSAGE_TYPE_RESPONSE => {
452456
debug_assert_eq!(self.kind, Kind::Client);

src/common.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
//! Common functions.
88
9-
use crate::error::{Error, Result};
109
#[cfg(any(
1110
feature = "async",
1211
not(any(target_os = "linux", target_os = "android"))
@@ -16,6 +15,8 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag};
1615
use nix::sys::socket::*;
1716
use std::os::unix::io::RawFd;
1817

18+
use crate::error::{Error, Result};
19+
1920
#[derive(Debug, Clone, Copy, PartialEq)]
2021
pub(crate) enum Domain {
2122
Unix,

src/error.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
//! Error and Result of ttrpc and relevant functions, macros.
1616
17-
use crate::proto::{Code, Status};
17+
use crate::proto::{Code, Response, Status};
1818
use std::result;
1919
use thiserror::Error;
2020

@@ -48,6 +48,20 @@ pub enum Error {
4848
Others(String),
4949
}
5050

51+
impl From<Error> for Response {
52+
fn from(e: Error) -> Self {
53+
let status = if let Error::RpcStatus(stat) = e {
54+
stat
55+
} else {
56+
get_status(Code::UNKNOWN, e)
57+
};
58+
59+
let mut res = Response::new();
60+
res.set_status(status);
61+
res
62+
}
63+
}
64+
5165
/// A specialized Result type for ttrpc.
5266
pub type Result<T> = result::Result<T, Error>;
5367

0 commit comments

Comments
 (0)