Skip to content

Commit b66144b

Browse files
committed
sync ttrpc: Fix the bug caused by oversized packets
Fixes: #198 Signed-off-by: Tim Zhang <[email protected]>
1 parent 8fed599 commit b66144b

File tree

6 files changed

+89
-57
lines changed

6 files changed

+89
-57
lines changed

src/common.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
//! Common functions.
88
9-
use crate::error::{Error, Result};
109
#[cfg(any(
1110
feature = "async",
1211
not(any(target_os = "linux", target_os = "android"))
@@ -16,6 +15,8 @@ use nix::fcntl::{fcntl, FcntlArg, OFlag};
1615
use nix::sys::socket::*;
1716
use std::os::unix::io::RawFd;
1817

18+
use crate::error::{Error, Result};
19+
1920
#[derive(Debug, Clone, Copy, PartialEq)]
2021
pub(crate) enum Domain {
2122
Unix,

src/proto.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ pub use compiled::ttrpc::*;
1313
use byteorder::{BigEndian, ByteOrder};
1414
use protobuf::{CodedInputStream, CodedOutputStream};
1515

16-
#[cfg(feature = "async")]
1716
use crate::error::{get_rpc_status, Error, Result as TtResult};
1817

1918
pub const MESSAGE_HEADER_LENGTH: usize = 10;
2019
pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;
20+
pub const DEFAULT_PAGE_SIZE: usize = 4 << 10;
2121

2222
pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
2323
pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;
@@ -27,6 +27,24 @@ pub const FLAG_REMOTE_CLOSED: u8 = 0x1;
2727
pub const FLAG_REMOTE_OPEN: u8 = 0x2;
2828
pub const FLAG_NO_DATA: u8 = 0x4;
2929

30+
pub(crate) fn check_oversize(len: usize, return_rpc_error: bool) -> TtResult<()> {
31+
if len > MESSAGE_LENGTH_MAX {
32+
let msg = format!(
33+
"message length {} exceed maximum message size of {}",
34+
len, MESSAGE_LENGTH_MAX
35+
);
36+
let e = if return_rpc_error {
37+
get_rpc_status(Code::INVALID_ARGUMENT, msg)
38+
} else {
39+
Error::Others(msg)
40+
};
41+
42+
return Err(e);
43+
}
44+
45+
Ok(())
46+
}
47+
3048
/// Message header of ttrpc.
3149
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
3250
pub struct MessageHeader {

src/sync/channel.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
16-
use crate::error::{get_rpc_status, sock_error_msg, Error, Result};
15+
use crate::error::{sock_error_msg, Error, Result};
16+
use crate::proto::{check_oversize, MessageHeader, DEFAULT_PAGE_SIZE, MESSAGE_HEADER_LENGTH};
1717
use crate::sync::sys::PipeConnection;
18-
use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX};
1918

2019
fn read_count(conn: &PipeConnection, count: usize) -> Result<Vec<u8>> {
2120
let mut v: Vec<u8> = vec![0; count];
@@ -51,7 +50,7 @@ fn write_count(conn: &PipeConnection, buf: &[u8], count: usize) -> Result<usize>
5150
}
5251

5352
loop {
54-
match conn.write(&buf[len..]){
53+
match conn.write(&buf[len..]) {
5554
Ok(l) => {
5655
len += l;
5756
if len == count {
@@ -67,6 +66,18 @@ fn write_count(conn: &PipeConnection, buf: &[u8], count: usize) -> Result<usize>
6766
Ok(len)
6867
}
6968

69+
fn discard_count(conn: &PipeConnection, count: usize) -> Result<()> {
70+
let mut need_discard = count;
71+
72+
while need_discard > 0 {
73+
let once_discard = std::cmp::min(DEFAULT_PAGE_SIZE, need_discard);
74+
read_count(conn, once_discard)?;
75+
need_discard -= once_discard;
76+
}
77+
78+
Ok(())
79+
}
80+
7081
fn read_message_header(conn: &PipeConnection) -> Result<MessageHeader> {
7182
let buf = read_count(conn, MESSAGE_HEADER_LENGTH)?;
7283
let size = buf.len();
@@ -82,18 +93,14 @@ fn read_message_header(conn: &PipeConnection) -> Result<MessageHeader> {
8293
Ok(mh)
8394
}
8495

85-
pub fn read_message(conn: &PipeConnection) -> Result<(MessageHeader, Vec<u8>)> {
96+
pub fn read_message(conn: &PipeConnection) -> Result<(MessageHeader, Result<Vec<u8>>)> {
8697
let mh = read_message_header(conn)?;
8798
trace!("Got Message header {:?}", mh);
8899

89-
if mh.length > MESSAGE_LENGTH_MAX as u32 {
90-
return Err(get_rpc_status(
91-
Code::INVALID_ARGUMENT,
92-
format!(
93-
"message length {} exceed maximum message size of {}",
94-
mh.length, MESSAGE_LENGTH_MAX
95-
),
96-
));
100+
let mh_len = mh.length as usize;
101+
if let Err(e) = check_oversize(mh_len, true) {
102+
discard_count(conn, mh_len)?;
103+
return Ok((mh, Err(e)));
97104
}
98105

99106
let buf = read_count(conn, mh.length as usize)?;
@@ -106,7 +113,7 @@ pub fn read_message(conn: &PipeConnection) -> Result<(MessageHeader, Vec<u8>)> {
106113
}
107114
trace!("Got Message body {:?}", buf);
108115

109-
Ok((mh, buf))
116+
Ok((mh, Ok(buf)))
110117
}
111118

112119
fn write_message_header(conn: &PipeConnection, mh: MessageHeader) -> Result<()> {

src/sync/client.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,9 @@ impl Client {
111111
}
112112
}
113113

114-
let mh;
115-
let buf;
116114
match read_message(&receiver_connection) {
117-
Ok((x, y)) => {
118-
mh = x;
119-
buf = y;
115+
Ok((mh, buf)) => {
116+
trans_resp(recver_map_orig.clone(), mh, buf);
120117
}
121118
Err(x) => match x {
122119
Error::Socket(y) => {
@@ -138,8 +135,6 @@ impl Client {
138135
}
139136
},
140137
};
141-
142-
trans_resp(recver_map_orig.clone(), mh, Ok(buf));
143138
}
144139

145140
let _ = receiver_client

src/sync/server.rs

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::sync::{Arc, Mutex};
2626
use std::thread;
2727
use std::thread::JoinHandle;
2828

29-
use super::utils::response_to_channel;
29+
use super::utils::{response_error_to_channel, response_to_channel};
3030
use crate::context;
3131
use crate::error::{get_status, Error, Result};
3232
use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST};
@@ -43,8 +43,8 @@ const DEFAULT_WAIT_THREAD_COUNT_MAX: usize = 5;
4343

4444
type MessageSender = Sender<(MessageHeader, Vec<u8>)>;
4545
type MessageReceiver = Receiver<(MessageHeader, Vec<u8>)>;
46-
type WorkloadSender = crossbeam::channel::Sender<(MessageHeader, Vec<u8>)>;
47-
type WorkloadReceiver = crossbeam::channel::Receiver<(MessageHeader, Vec<u8>)>;
46+
type WorkloadSender = crossbeam::channel::Sender<(MessageHeader, Result<Vec<u8>>)>;
47+
type WorkloadReceiver = crossbeam::channel::Receiver<(MessageHeader, Result<Vec<u8>>)>;
4848

4949
/// A ttrpc Server (sync).
5050
pub struct Server {
@@ -134,20 +134,22 @@ fn start_method_handler_thread(
134134
let mh;
135135
let buf;
136136
match result {
137-
Ok((x, y)) => {
137+
Ok((x, Ok(y))) => {
138138
mh = x;
139139
buf = y;
140140
}
141+
Ok((mh, Err(e))) => {
142+
if let Err(x) = response_error_to_channel(mh.stream_id, e, res_tx.clone()) {
143+
debug!("response_error_to_channel get error {:?}", x);
144+
quit_connection(quit, control_tx);
145+
break;
146+
}
147+
continue;
148+
}
141149
Err(x) => match x {
142150
crossbeam::channel::RecvError => {
143151
trace!("workload_rx recv error");
144-
quit.store(true, Ordering::SeqCst);
145-
// the workload tx would be dropped and
146-
// the connection dealing main thread would
147-
// have exited.
148-
control_tx
149-
.send(())
150-
.unwrap_or_else(|err| trace!("Failed to send {:?}", err));
152+
quit_connection(quit, control_tx);
151153
trace!("workload_rx recv error, send control_tx");
152154
break;
153155
}
@@ -165,13 +167,7 @@ fn start_method_handler_thread(
165167
res.set_status(status);
166168
if let Err(x) = response_to_channel(mh.stream_id, res, res_tx.clone()) {
167169
debug!("response_to_channel get error {:?}", x);
168-
quit.store(true, Ordering::SeqCst);
169-
// the client connection would be closed and
170-
// the connection dealing main thread would have
171-
// exited.
172-
control_tx
173-
.send(())
174-
.unwrap_or_else(|err| trace!("Failed to send {:?}", err));
170+
quit_connection(quit, control_tx);
175171
break;
176172
}
177173
continue;
@@ -187,13 +183,7 @@ fn start_method_handler_thread(
187183
res.set_status(status);
188184
if let Err(x) = response_to_channel(mh.stream_id, res, res_tx.clone()) {
189185
info!("response_to_channel get error {:?}", x);
190-
quit.store(true, Ordering::SeqCst);
191-
// the client connection would be closed and
192-
// the connection dealing main thread would have
193-
// exited.
194-
control_tx
195-
.send(())
196-
.unwrap_or_else(|err| trace!("Failed to send {:?}", err));
186+
quit_connection(quit, control_tx);
197187
break;
198188
}
199189
continue;
@@ -208,13 +198,7 @@ fn start_method_handler_thread(
208198
};
209199
if let Err(x) = method.handler(ctx, req) {
210200
debug!("method handle {} get error {:?}", path, x);
211-
quit.store(true, Ordering::SeqCst);
212-
// the client connection would be closed and
213-
// the connection dealing main thread would have
214-
// exited.
215-
control_tx
216-
.send(())
217-
.unwrap_or_else(|err| trace!("Failed to send {:?}", err));
201+
quit_connection(quit, control_tx);
218202
break;
219203
}
220204
}
@@ -595,3 +579,13 @@ impl AsRawFd for Server {
595579
self.listeners[0].as_raw_fd()
596580
}
597581
}
582+
583+
fn quit_connection(quit: Arc<AtomicBool>, control_tx: SyncSender<()>) {
584+
quit.store(true, Ordering::SeqCst);
585+
// the client connection would be closed and
586+
// the connection dealing main thread would
587+
// have exited.
588+
control_tx
589+
.send(())
590+
.unwrap_or_else(|err| debug!("Failed to send {:?}", err));
591+
}

src/sync/utils.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55

6-
use crate::error::{Error, Result};
7-
use crate::proto::{MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
6+
use crate::error::{get_status, Error, Result};
7+
use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
88
use protobuf::Message;
99
use std::collections::HashMap;
1010

@@ -33,6 +33,23 @@ pub fn response_to_channel(
3333
Ok(())
3434
}
3535

36+
pub fn response_error_to_channel(
37+
stream_id: u32,
38+
e: Error,
39+
tx: std::sync::mpsc::Sender<(MessageHeader, Vec<u8>)>,
40+
) -> Result<()> {
41+
let status = if let Error::RpcStatus(stat) = e {
42+
stat
43+
} else {
44+
get_status(Code::UNKNOWN, e)
45+
};
46+
47+
let mut res = Response::new();
48+
res.set_status(status);
49+
50+
response_to_channel(stream_id, res, tx)
51+
}
52+
3653
/// Handle request in sync mode.
3754
#[macro_export]
3855
macro_rules! request_handler {

0 commit comments

Comments
 (0)