Skip to content

Commit 1d0a72c

Browse files
committed
Merge branch 'master' into further_tls_functionality
2 parents 1841ab6 + 492f0a0 commit 1d0a72c

File tree

5 files changed

+365
-231
lines changed

5 files changed

+365
-231
lines changed

Cargo.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,15 @@ nom = "7"
3131
mysql_common = { version = "0.28.0", features = ["chrono"] }
3232
byteorder = "1"
3333
chrono = "0.4"
34-
time = "0.2.25"
3534
rustls = {version = "0.20.0", optional=true}
3635

3736
[dev-dependencies]
3837
postgres = "0.19.1"
39-
mysql = "18"
40-
mysql_async = "0.20.0"
38+
mysql = "22"
39+
mysql_async = "0.29.0"
4140
slab = "0.4.2"
42-
tokio = "0.1.19"
43-
futures = "0.1.26"
41+
tokio = { version = "1.15.0", features = ["full"] }
42+
futures = "0.3.0"
4443
rcgen = "0.8.14"
4544
tempfile = "3.3.0"
4645
native-tls = "0.2.8"

src/packet.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ impl<W: Read + Write> PacketConn<W> {
7777

7878
#[cfg(feature = "tls")]
7979
pub fn switch_to_tls(&mut self, config: std::sync::Arc<ServerConfig>) -> io::Result<()> {
80-
assert_eq!(self.remaining, 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
81-
82-
self.rw.switch_to_tls(config)
80+
let res = self
81+
.rw
82+
.switch_to_tls(config, &self.bytes[self.bytes.len() - self.remaining..]);
83+
self.remaining = 0;
84+
res
8385
}
8486

8587
#[cfg(feature = "tls")]
@@ -103,15 +105,7 @@ impl<R: Read + Write> PacketConn<R> {
103105

104106
loop {
105107
if self.remaining != 0 {
106-
let bytes = {
107-
// NOTE: this is all sorts of unfortunate. what we really want to do is to give
108-
// &self.bytes[self.start..] to `packet()`, and the lifetimes should all work
109-
// out. however, without NLL, borrowck doesn't realize that self.bytes is no
110-
// longer borrowed after the match, and so can be mutated.
111-
let bytes = &self.bytes[self.start..];
112-
unsafe { ::std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) }
113-
};
114-
match packet(bytes) {
108+
match packet(&self.bytes[self.start..]) {
115109
Ok((rest, p)) => {
116110
self.remaining = rest.len();
117111
return Ok(Some(p));

src/tls.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io;
1+
use std::io::{self, Chain, Cursor};
22
use std::io::{Read, Write};
33
use std::sync::Arc;
44

@@ -17,7 +17,7 @@ pub(crate) struct SwitchableConn<T: Read + Write>(pub(crate) Option<EitherConn<T
1717

1818
pub(crate) enum EitherConn<T: Read + Write> {
1919
Plain(T),
20-
Tls(rustls::StreamOwned<ServerConnection, T>),
20+
Tls(rustls::StreamOwned<ServerConnection, PrependedReader<T>>),
2121
}
2222

2323
impl<T: Read + Write> Read for SwitchableConn<T> {
@@ -50,9 +50,16 @@ impl<T: Read + Write> SwitchableConn<T> {
5050
SwitchableConn(Some(EitherConn::Plain(rw)))
5151
}
5252

53-
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
53+
pub fn switch_to_tls(
54+
&mut self,
55+
config: Arc<ServerConfig>,
56+
to_prepend: &[u8],
57+
) -> io::Result<()> {
5458
let replacement = match self.0.take() {
55-
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(plain, config)?)),
59+
Some(EitherConn::Plain(plain)) => Ok(EitherConn::Tls(create_stream(
60+
PrependedReader::new(to_prepend, plain),
61+
config,
62+
)?)),
5663
Some(EitherConn::Tls(_)) => Err(io::Error::new(
5764
io::ErrorKind::Other,
5865
"tls variant found when plain was expected",
@@ -64,3 +71,48 @@ impl<T: Read + Write> SwitchableConn<T> {
6471
Ok(())
6572
}
6673
}
74+
75+
pub(crate) struct PrependedReader<RW: Read + Write> {
76+
inner: Chain<Cursor<Vec<u8>>, RW>,
77+
}
78+
79+
impl<RW: Read + Write> PrependedReader<RW> {
80+
fn new(prepended: &[u8], rw: RW) -> PrependedReader<RW> {
81+
PrependedReader {
82+
inner: Cursor::new(prepended.to_vec()).chain(rw),
83+
}
84+
}
85+
}
86+
87+
impl<RW: Read + Write> Read for PrependedReader<RW> {
88+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
89+
self.inner.read(buf)
90+
}
91+
}
92+
93+
impl<RW: Read + Write> Write for PrependedReader<RW> {
94+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
95+
self.inner.get_mut().1.write(buf)
96+
}
97+
98+
fn flush(&mut self) -> io::Result<()> {
99+
self.inner.get_mut().1.flush()
100+
}
101+
}
102+
103+
#[cfg(test)]
104+
mod tests {
105+
use std::io::{Cursor, Read};
106+
107+
use super::PrependedReader;
108+
109+
#[test]
110+
fn test_bufreader_replace() {
111+
let mut rw = Cursor::new(vec![1, 2, 3]);
112+
let mut br = PrependedReader::new(&[0, 1, 2], &mut rw);
113+
let mut out = Vec::new();
114+
br.read_to_end(&mut out).unwrap();
115+
116+
assert_eq!(&out, &[0, 1, 2, 1, 2, 3]);
117+
}
118+
}

0 commit comments

Comments
 (0)