Skip to content

Commit c4ec63c

Browse files
authored
Merge pull request #38 from mjgarton/fix_buffering_during_tls_neg
Fix buffering during switch to TLS
2 parents 3fcc1ec + 2a4fe76 commit c4ec63c

File tree

3 files changed

+170
-7
lines changed

3 files changed

+170
-7
lines changed

src/packet.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ impl<W: Read + Write> PacketConn<W> {
7777

7878
#[cfg(feature = "tls")]
7979
pub fn switch_to_tls(&mut self, config: std::sync::Arc<ServerConfig>) -> io::Result<()> {
80-
assert_eq!(self.remaining, 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
81-
82-
self.rw.switch_to_tls(config)
80+
let res = self
81+
.rw
82+
.switch_to_tls(config, &self.bytes[self.bytes.len() - self.remaining..]);
83+
self.remaining = 0;
84+
res
8385
}
8486
}
8587

src/tls.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io;
1+
use std::io::{self, Chain, Cursor};
22
use std::io::{Read, Write};
33
use std::sync::Arc;
44

@@ -17,7 +17,7 @@ pub(crate) struct SwitchableConn<T: Read + Write>(Option<EitherConn<T>>);
1717

1818
pub(crate) enum EitherConn<T: Read + Write> {
1919
Plain(T),
20-
Tls(rustls::StreamOwned<ServerConnection, T>),
20+
Tls(rustls::StreamOwned<ServerConnection, PrependedReader<T>>),
2121
}
2222

2323
impl<T: Read + Write> Read for SwitchableConn<T> {
@@ -50,9 +50,16 @@ impl<T: Read + Write> SwitchableConn<T> {
5050
SwitchableConn(Some(EitherConn::Plain(rw)))
5151
}
5252

53-
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
53+
pub fn switch_to_tls(
54+
&mut self,
55+
config: Arc<ServerConfig>,
56+
to_prepend: &[u8],
57+
) -> io::Result<()> {
5458
let replacement = match self.0.take() {
55-
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(plain, config)?)),
59+
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(
60+
PrependedReader::new(to_prepend, plain),
61+
config,
62+
)?)),
5663
Some(EitherConn::Tls(_)) => Err(io::Error::new(
5764
io::ErrorKind::Other,
5865
"tls variant found when plain was expected",
@@ -64,3 +71,48 @@ impl<T: Read + Write> SwitchableConn<T> {
6471
Ok(())
6572
}
6673
}
74+
75+
pub(crate) struct PrependedReader<RW: Read + Write> {
76+
inner: Chain<Cursor<Vec<u8>>, RW>,
77+
}
78+
79+
impl<RW: Read + Write> PrependedReader<RW> {
80+
fn new(prepended: &[u8], rw: RW) -> PrependedReader<RW> {
81+
PrependedReader {
82+
inner: Cursor::new(prepended.to_vec()).chain(rw),
83+
}
84+
}
85+
}
86+
87+
impl<RW: Read + Write> Read for PrependedReader<RW> {
88+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
89+
self.inner.read(buf)
90+
}
91+
}
92+
93+
impl<RW: Read + Write> Write for PrependedReader<RW> {
94+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
95+
self.inner.get_mut().1.write(buf)
96+
}
97+
98+
fn flush(&mut self) -> io::Result<()> {
99+
self.inner.get_mut().1.flush()
100+
}
101+
}
102+
103+
#[cfg(test)]
104+
mod tests {
105+
use std::io::{Cursor, Read};
106+
107+
use super::PrependedReader;
108+
109+
#[test]
110+
fn test_bufreader_replace() {
111+
let mut rw = Cursor::new(vec![1, 2, 3]);
112+
let mut br = PrependedReader::new(&[0, 1, 2], &mut rw);
113+
let mut out = Vec::new();
114+
br.read_to_end(&mut out).unwrap();
115+
116+
assert_eq!(&out, &[0, 1, 2, 1, 2, 3]);
117+
}
118+
}

tests/main.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ use mysql::SslOpts;
1212
use rustls::{Certificate, PrivateKey, ServerConfig};
1313
use std::error::Error;
1414
use std::io;
15+
use std::io::Read;
16+
use std::io::Write;
1517
use std::net;
1618
use std::thread;
19+
use std::time::Duration;
1720

1821
use msql_srv::{
1922
Column, ErrorKind, InitWriter, MysqlIntermediary, MysqlShim, ParamParser, QueryResultWriter,
@@ -210,6 +213,112 @@ fn it_connects_tls_both() {
210213
.test(|_| {})
211214
}
212215

216+
#[test]
217+
#[cfg(feature = "tls")]
218+
fn it_connects_tls_both_with_delayed_server_read() {
219+
// This test is to ensure correctly handle the case when we read both the pre-TLS data as well
220+
// as (at least part of) the TLS handshake into our the buffer. When that happens, we need to
221+
// ensure we correctly pass that TLS part of the data to rustls so that is can handle the TLS
222+
// handshake properly.
223+
use std::{marker::PhantomData, sync::Arc};
224+
225+
struct MyShim<RW> {
226+
ph: PhantomData<RW>,
227+
}
228+
229+
impl<RW: Read + Write> MysqlShim<RW> for MyShim<RW> {
230+
type Error = io::Error;
231+
232+
fn on_prepare(
233+
&mut self,
234+
_: &str,
235+
_: StatementMetaWriter<'_, RW>,
236+
) -> Result<(), Self::Error> {
237+
unreachable!()
238+
}
239+
240+
fn on_execute(
241+
&mut self,
242+
_: u32,
243+
_: ParamParser<'_>,
244+
_: QueryResultWriter<'_, RW>,
245+
) -> Result<(), Self::Error> {
246+
unreachable!()
247+
}
248+
249+
fn on_close(&mut self, _: u32) {
250+
unreachable!()
251+
}
252+
253+
fn on_query(&mut self, _: &str, _: QueryResultWriter<'_, RW>) -> Result<(), Self::Error> {
254+
unreachable!()
255+
}
256+
257+
fn tls_config(&self) -> Option<Arc<ServerConfig>> {
258+
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
259+
260+
Some(std::sync::Arc::new(
261+
ServerConfig::builder()
262+
.with_safe_defaults()
263+
.with_no_client_auth()
264+
.with_single_cert(
265+
vec![Certificate(cert.serialize_der().unwrap())],
266+
PrivateKey(cert.get_key_pair().serialize_der()),
267+
)
268+
.unwrap(),
269+
))
270+
}
271+
}
272+
273+
let shim = MyShim {
274+
ph: PhantomData::default(),
275+
};
276+
277+
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
278+
let port = listener.local_addr().unwrap().port();
279+
let jh = thread::spawn(move || {
280+
let (s, _) = listener.accept().unwrap();
281+
let s = DelayedReadRW {
282+
s,
283+
read_delay: Duration::from_millis(200),
284+
};
285+
MysqlIntermediary::run_on(shim, s)
286+
});
287+
288+
let db = mysql::Conn::new(
289+
OptsBuilder::default()
290+
.ip_or_hostname(Some("localhost"))
291+
.tcp_port(port)
292+
.ssl_opts(Some(
293+
SslOpts::default().with_danger_accept_invalid_certs(true),
294+
)),
295+
)
296+
.unwrap();
297+
drop(db);
298+
jh.join().unwrap().unwrap();
299+
}
300+
301+
struct DelayedReadRW<RW: Read + Write> {
302+
s: RW,
303+
read_delay: Duration,
304+
}
305+
306+
impl<RW: Read + Write> Read for DelayedReadRW<RW> {
307+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
308+
thread::sleep(self.read_delay);
309+
self.s.read(buf)
310+
}
311+
}
312+
313+
impl<RW: Read + Write> Write for DelayedReadRW<RW> {
314+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
315+
self.s.write(buf)
316+
}
317+
318+
fn flush(&mut self) -> io::Result<()> {
319+
self.s.flush()
320+
}
321+
}
213322
#[test]
214323
fn it_does_not_connect_tls_client_only() {
215324
// Client requesting tls fails as expected when server does not support it.

0 commit comments

Comments
 (0)