Skip to content

Commit 6b467d3

Browse files
committed
Fix buffering during switch to TLS
This fixes a race bug where the intereaction between our buffering and the TLS handshake can cause problems. There was an assert in the code to ensure we don't read "too far" into the stream before we enter the TLS handshake. However, it turns out that sometimes (when the client sends the tls handshake information quickly enough after the pre-tls data) we can end up hitting this assertion. This adds a test for this scenario, by using a new `DelayedReadRW` wrapper in the server which reliably triggers this. The fix introduces a new `PrependedReader` type that allows us to push any extra data we may have buffered into it (together with the underlying connection) before passing it in to rustls for the handshake.
1 parent 6fe6041 commit 6b467d3

File tree

3 files changed

+166
-7
lines changed

3 files changed

+166
-7
lines changed

src/packet.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ impl<W: Read + Write> PacketConn<W> {
7777

7878
#[cfg(feature = "tls")]
7979
pub fn switch_to_tls(&mut self, config: std::sync::Arc<ServerConfig>) -> io::Result<()> {
80-
assert_eq!(self.remaining, 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
81-
82-
self.rw.switch_to_tls(config)
80+
let res = self
81+
.rw
82+
.switch_to_tls(config, &self.bytes[self.bytes.len() - self.remaining..]);
83+
self.remaining = 0;
84+
res
8385
}
8486
}
8587

src/tls.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io;
1+
use std::io::{self, Chain, Cursor};
22
use std::io::{Read, Write};
33
use std::sync::Arc;
44

@@ -17,7 +17,7 @@ pub(crate) struct SwitchableConn<T: Read + Write>(Option<EitherConn<T>>);
1717

1818
pub(crate) enum EitherConn<T: Read + Write> {
1919
Plain(T),
20-
Tls(rustls::StreamOwned<ServerConnection, T>),
20+
Tls(rustls::StreamOwned<ServerConnection, PrependedReader<T>>),
2121
}
2222

2323
impl<T: Read + Write> Read for SwitchableConn<T> {
@@ -50,9 +50,16 @@ impl<T: Read + Write> SwitchableConn<T> {
5050
SwitchableConn(Some(EitherConn::Plain(rw)))
5151
}
5252

53-
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
53+
pub fn switch_to_tls(
54+
&mut self,
55+
config: Arc<ServerConfig>,
56+
to_prepend: &[u8],
57+
) -> io::Result<()> {
5458
let replacement = match self.0.take() {
55-
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(plain, config)?)),
59+
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(
60+
PrependedReader::new(to_prepend, plain),
61+
config,
62+
)?)),
5663
Some(EitherConn::Tls(_)) => Err(io::Error::new(
5764
io::ErrorKind::Other,
5865
"tls variant found when plain was expected",
@@ -64,3 +71,48 @@ impl<T: Read + Write> SwitchableConn<T> {
6471
Ok(())
6572
}
6673
}
74+
75+
pub(crate) struct PrependedReader<RW: Read + Write> {
76+
inner: Chain<Cursor<Vec<u8>>, RW>,
77+
}
78+
79+
impl<RW: Read + Write> PrependedReader<RW> {
80+
fn new(prepended: &[u8], rw: RW) -> PrependedReader<RW> {
81+
PrependedReader {
82+
inner: Cursor::new(prepended.to_vec()).chain(rw),
83+
}
84+
}
85+
}
86+
87+
impl<RW: Read + Write> Read for PrependedReader<RW> {
88+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
89+
self.inner.read(buf)
90+
}
91+
}
92+
93+
impl<RW: Read + Write> Write for PrependedReader<RW> {
94+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
95+
self.inner.get_mut().1.write(buf)
96+
}
97+
98+
fn flush(&mut self) -> io::Result<()> {
99+
self.inner.get_mut().1.flush()
100+
}
101+
}
102+
103+
#[cfg(test)]
104+
mod tests {
105+
use std::io::{Cursor, Read};
106+
107+
use super::PrependedReader;
108+
109+
#[test]
110+
fn test_bufreader_replace() {
111+
let mut rw = Cursor::new(vec![1, 2, 3]);
112+
let mut br = PrependedReader::new(&[0, 1, 2], &mut rw);
113+
let mut out = Vec::new();
114+
br.read_to_end(&mut out).unwrap();
115+
116+
assert_eq!(&out, &[0, 1, 2, 1, 2, 3]);
117+
}
118+
}

tests/main.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ use mysql::SslOpts;
1212
use rustls::{Certificate, PrivateKey, ServerConfig};
1313
use std::error::Error;
1414
use std::io;
15+
use std::io::Read;
16+
use std::io::Write;
1517
use std::net;
1618
use std::thread;
19+
use std::time::Duration;
1720

1821
use msql_srv::{
1922
Column, ErrorKind, InitWriter, MysqlIntermediary, MysqlShim, ParamParser, QueryResultWriter,
@@ -210,6 +213,108 @@ fn it_connects_tls_both() {
210213
.test(|_| {})
211214
}
212215

216+
#[test]
217+
#[cfg(feature = "tls")]
218+
fn it_connects_tls_both_with_delayed_server_read() {
219+
use std::{marker::PhantomData, sync::Arc};
220+
221+
struct MyShim<RW> {
222+
ph: PhantomData<RW>,
223+
}
224+
225+
impl<RW: Read + Write> MysqlShim<RW> for MyShim<RW> {
226+
type Error = io::Error;
227+
228+
fn on_prepare(
229+
&mut self,
230+
_: &str,
231+
_: StatementMetaWriter<'_, RW>,
232+
) -> Result<(), Self::Error> {
233+
unreachable!()
234+
}
235+
236+
fn on_execute(
237+
&mut self,
238+
_: u32,
239+
_: ParamParser<'_>,
240+
_: QueryResultWriter<'_, RW>,
241+
) -> Result<(), Self::Error> {
242+
unreachable!()
243+
}
244+
245+
fn on_close(&mut self, _: u32) {
246+
unreachable!()
247+
}
248+
249+
fn on_query(&mut self, _: &str, _: QueryResultWriter<'_, RW>) -> Result<(), Self::Error> {
250+
unreachable!()
251+
}
252+
253+
fn tls_config(&self) -> Option<Arc<ServerConfig>> {
254+
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
255+
256+
Some(std::sync::Arc::new(
257+
ServerConfig::builder()
258+
.with_safe_defaults()
259+
.with_no_client_auth()
260+
.with_single_cert(
261+
vec![Certificate(cert.serialize_der().unwrap())],
262+
PrivateKey(cert.get_key_pair().serialize_der()),
263+
)
264+
.unwrap(),
265+
))
266+
}
267+
}
268+
269+
let shim = MyShim {
270+
ph: PhantomData::default(),
271+
};
272+
273+
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
274+
let port = listener.local_addr().unwrap().port();
275+
let jh = thread::spawn(move || {
276+
let (s, _) = listener.accept().unwrap();
277+
let s = DelayedReadRW {
278+
s,
279+
read_delay: Duration::from_millis(200),
280+
};
281+
MysqlIntermediary::run_on(shim, s)
282+
});
283+
284+
let db = mysql::Conn::new(
285+
OptsBuilder::default()
286+
.ip_or_hostname(Some("localhost"))
287+
.tcp_port(port)
288+
.ssl_opts(Some(
289+
SslOpts::default().with_danger_accept_invalid_certs(true),
290+
)),
291+
)
292+
.unwrap();
293+
drop(db);
294+
jh.join().unwrap().unwrap();
295+
}
296+
297+
struct DelayedReadRW<RW: Read + Write> {
298+
s: RW,
299+
read_delay: Duration,
300+
}
301+
302+
impl<RW: Read + Write> Read for DelayedReadRW<RW> {
303+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
304+
thread::sleep(self.read_delay);
305+
self.s.read(buf)
306+
}
307+
}
308+
309+
impl<RW: Read + Write> Write for DelayedReadRW<RW> {
310+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
311+
self.s.write(buf)
312+
}
313+
314+
fn flush(&mut self) -> io::Result<()> {
315+
self.s.flush()
316+
}
317+
}
213318
#[test]
214319
fn it_does_not_connect_tls_client_only() {
215320
// Client requesting tls fails as expected when server does not support it.

0 commit comments

Comments
 (0)