Skip to content

Commit e25e286

Browse files
committed
Make TLS optional
This introduces a feature in Cargo.toml to enable TLS support to be optional. Various code has either been conditioned on the new feature, or moved into the tls module if it's only ever used for TLS.
1 parent cee582d commit e25e286

File tree

5 files changed

+102
-77
lines changed

5 files changed

+102
-77
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ categories = ["api-bindings", "network-programming", "database-implementations"]
1717

1818
license = "MIT/Apache-2.0"
1919

20+
[features]
21+
default = ["tls"]
22+
tls = ["rustls"]
23+
2024
[badges]
2125
azure-devops = { project = "jonhoo/jonhoo", pipeline = "msql-srv", build = "27" }
2226
codecov = { repository = "jonhoo/msql-srv", branch = "master", service = "github" }
@@ -28,7 +32,7 @@ mysql_common = "0.22"
2832
byteorder = "1"
2933
chrono = "0.4"
3034
time = "0.2.25"
31-
rustls = "0.20.0"
35+
rustls = {version = "0.20.0", optional=true}
3236

3337
[dev-dependencies]
3438
postgres = "0.19.1"

src/lib.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ use std::io;
100100
use std::io::prelude::*;
101101
use std::iter;
102102
use std::net;
103-
use std::sync::Arc;
104103

105104
use myc::constants::CapabilityFlags;
106105

@@ -111,6 +110,7 @@ mod errorcodes;
111110
mod packet;
112111
mod params;
113112
mod resultset;
113+
#[cfg(feature = "tls")]
114114
mod tls;
115115
mod value;
116116
mod writers;
@@ -190,7 +190,8 @@ pub trait MysqlShim<W: Read + Write> {
190190
}
191191

192192
/// Provides the TLS configuration, if we want to support TLS.
193-
fn tls_config(&self) -> Option<Arc<rustls::ServerConfig>> {
193+
#[cfg(feature = "tls")]
194+
fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
194195
None
195196
}
196197
}
@@ -238,6 +239,7 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
238239
}
239240

240241
fn init(&mut self) -> Result<(), B::Error> {
242+
#[cfg(feature = "tls")]
241243
let tls_conf = self.shim.tls_config();
242244

243245
self.rw.write_all(&[10])?; // protocol 10
@@ -248,6 +250,7 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
248250
self.rw.write_all(&[0x08, 0x00, 0x00, 0x00])?; // TODO: connection ID
249251
self.rw.write_all(&b";X,po_k}\0"[..])?; // auth seed
250252
let capabilities = &mut [0x00, 0x42]; // 4.1 proto
253+
#[cfg(feature = "tls")]
251254
if tls_conf.is_some() {
252255
capabilities[1] |= 0x08; // SSL support flag
253256
}
@@ -293,6 +296,16 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
293296

294297
self.rw.set_seq(seq + 1);
295298

299+
#[cfg(not(feature = "tls"))]
300+
if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
301+
return Err(io::Error::new(
302+
io::ErrorKind::InvalidData,
303+
"client requested SSL despite us not advertising support for it",
304+
)
305+
.into());
306+
}
307+
308+
#[cfg(feature = "tls")]
296309
if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
297310
let config = tls_conf.ok_or_else(|| {
298311
io::Error::new(

src/packet.rs

Lines changed: 15 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
use byteorder::{ByteOrder, LittleEndian};
2-
use rustls::{ServerConfig, ServerConnection};
2+
#[cfg(feature = "tls")]
3+
use rustls::ServerConfig;
34
use std::io;
45
use std::io::prelude::*;
56

67
const U24_MAX: usize = 16_777_215;
78

89
pub struct PacketConn<RW: Read + Write> {
9-
rw: SwitchableConn<RW>,
10+
#[cfg(feature = "tls")]
11+
rw: tls::SwitchableConn<RW>,
12+
#[cfg(not(feature = "tls"))]
13+
rw: RW,
1014

1115
// read variables
1216
bytes: Vec<u8>,
@@ -38,14 +42,17 @@ impl<W: Read + Write> Write for PacketConn<W> {
3842

3943
impl<RW: Read + Write> PacketConn<RW> {
4044
pub fn new(rw: RW) -> Self {
45+
#[cfg(feature = "tls")]
46+
let rw = tls::SwitchableConn::new(rw);
47+
4148
PacketConn {
4249
bytes: Vec::new(),
4350
start: 0,
4451
remaining: 0,
4552

4653
to_write: vec![0, 0, 0, 0],
4754
seq: 0,
48-
rw: SwitchableConn::new(rw),
55+
rw,
4956
}
5057
}
5158
}
@@ -68,8 +75,9 @@ impl<W: Read + Write> PacketConn<W> {
6875
self.maybe_end_packet()
6976
}
7077

71-
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
72-
assert_eq!(self.remaining(), 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
78+
#[cfg(feature = "tls")]
79+
pub fn switch_to_tls(&mut self, config: std::sync::Arc<ServerConfig>) -> io::Result<()> {
80+
assert_eq!(self.remaining, 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
7381

7482
self.rw.switch_to_tls(config)
7583
}
@@ -134,10 +142,6 @@ impl<R: Read + Write> PacketConn<R> {
134142
}
135143
}
136144
}
137-
138-
pub fn remaining(&self) -> usize {
139-
self.remaining
140-
}
141145
}
142146

143147
pub fn fullpacket(i: &[u8]) -> nom::IResult<&[u8], (u8, &[u8])> {
@@ -171,9 +175,10 @@ impl AsRef<[u8]> for Packet {
171175
}
172176

173177
use std::ops::Deref;
174-
use std::sync::Arc;
175178

179+
#[cfg(feature = "tls")]
176180
use crate::tls;
181+
177182
impl Deref for Packet {
178183
type Target = [u8];
179184
fn deref(&self) -> &Self::Target {
@@ -214,60 +219,6 @@ fn packet(i: &[u8]) -> nom::IResult<&[u8], (u8, Packet)> {
214219
)(i)
215220
}
216221

217-
pub(crate) struct SwitchableConn<T: Read + Write>(Option<EitherConn<T>>);
218-
219-
pub(crate) enum EitherConn<T: Read + Write> {
220-
Plain(T),
221-
TLS(rustls::StreamOwned<ServerConnection, T>),
222-
}
223-
224-
impl<T: Read + Write> Read for SwitchableConn<T> {
225-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
226-
match &mut self.0.as_mut().unwrap() {
227-
EitherConn::Plain(p) => p.read(buf),
228-
EitherConn::TLS(t) => t.read(buf),
229-
}
230-
}
231-
}
232-
233-
impl<T: Read + Write> Write for SwitchableConn<T> {
234-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
235-
match &mut self.0.as_mut().unwrap() {
236-
EitherConn::Plain(p) => p.write(buf),
237-
EitherConn::TLS(t) => t.write(buf),
238-
}
239-
}
240-
241-
fn flush(&mut self) -> io::Result<()> {
242-
match &mut self.0.as_mut().unwrap() {
243-
EitherConn::Plain(p) => p.flush(),
244-
EitherConn::TLS(t) => t.flush(),
245-
}
246-
}
247-
}
248-
249-
impl<T: Read + Write> SwitchableConn<T> {
250-
pub fn new(rw: T) -> SwitchableConn<T> {
251-
SwitchableConn(Some(EitherConn::Plain(rw)))
252-
}
253-
254-
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
255-
let replacement = match self.0.take() {
256-
Some(EitherConn::Plain(plain)) => {
257-
Ok(EitherConn::TLS(tls::create_stream(plain, config)?))
258-
}
259-
Some(EitherConn::TLS(_)) => Err(io::Error::new(
260-
io::ErrorKind::Other,
261-
"tls variant found when plain was expected",
262-
)),
263-
None => unreachable!(),
264-
}?;
265-
266-
self.0 = Some(replacement);
267-
Ok(())
268-
}
269-
}
270-
271222
#[cfg(test)]
272223
mod tests {
273224
use super::*;

src/tls.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,55 @@ pub fn create_stream<T: Read + Write + Sized>(
1212
let stream = rustls::StreamOwned { conn, sock };
1313
Ok(stream)
1414
}
15+
16+
pub(crate) struct SwitchableConn<T: Read + Write>(Option<EitherConn<T>>);
17+
18+
pub(crate) enum EitherConn<T: Read + Write> {
19+
Plain(T),
20+
TLS(rustls::StreamOwned<ServerConnection, T>),
21+
}
22+
23+
impl<T: Read + Write> Read for SwitchableConn<T> {
24+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
25+
match &mut self.0.as_mut().unwrap() {
26+
EitherConn::Plain(p) => p.read(buf),
27+
EitherConn::TLS(t) => t.read(buf),
28+
}
29+
}
30+
}
31+
32+
impl<T: Read + Write> Write for SwitchableConn<T> {
33+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
34+
match &mut self.0.as_mut().unwrap() {
35+
EitherConn::Plain(p) => p.write(buf),
36+
EitherConn::TLS(t) => t.write(buf),
37+
}
38+
}
39+
40+
fn flush(&mut self) -> io::Result<()> {
41+
match &mut self.0.as_mut().unwrap() {
42+
EitherConn::Plain(p) => p.flush(),
43+
EitherConn::TLS(t) => t.flush(),
44+
}
45+
}
46+
}
47+
48+
impl<T: Read + Write> SwitchableConn<T> {
49+
pub fn new(rw: T) -> SwitchableConn<T> {
50+
SwitchableConn(Some(EitherConn::Plain(rw)))
51+
}
52+
53+
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
54+
let replacement = match self.0.take() {
55+
Some(EitherConn::Plain(plain)) => Ok(EitherConn::TLS(create_stream(plain, config)?)),
56+
Some(EitherConn::TLS(_)) => Err(io::Error::new(
57+
io::ErrorKind::Other,
58+
"tls variant found when plain was expected",
59+
)),
60+
None => unreachable!(),
61+
}?;
62+
63+
self.0 = Some(replacement);
64+
Ok(())
65+
}
66+
}

tests/main.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@ use mysql::prelude::*;
88
use mysql::DriverError;
99
use mysql::OptsBuilder;
1010
use mysql::SslOpts;
11-
use rcgen::generate_simple_self_signed;
12-
use rustls::Certificate;
13-
use rustls::PrivateKey;
14-
use rustls::ServerConfig;
11+
#[cfg(feature = "tls")]
12+
use rustls::{Certificate, PrivateKey, ServerConfig};
1513
use std::error::Error;
1614
use std::io;
1715
use std::net;
18-
use std::sync::Arc;
1916
use std::thread;
2017

2118
use msql_srv::{
@@ -30,7 +27,8 @@ struct TestingShim<Q, P, E, I> {
3027
on_p: P,
3128
on_e: E,
3229
on_i: I,
33-
server_tls: Option<Arc<rustls::ServerConfig>>,
30+
#[cfg(feature = "tls")]
31+
server_tls: Option<std::sync::Arc<rustls::ServerConfig>>,
3432
client_tls: Option<SslOpts>,
3533
}
3634

@@ -75,8 +73,9 @@ where
7573
(self.on_q)(query, results)
7674
}
7775

78-
fn tls_config(&self) -> Option<Arc<rustls::ServerConfig>> {
79-
self.server_tls.as_ref().map(Arc::clone)
76+
#[cfg(feature = "tls")]
77+
fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
78+
self.server_tls.as_ref().map(std::sync::Arc::clone)
8079
}
8180
}
8281

@@ -97,6 +96,7 @@ where
9796
on_p,
9897
on_e,
9998
on_i,
99+
#[cfg(feature = "tls")]
100100
server_tls: None,
101101
client_tls: None,
102102
}
@@ -112,10 +112,11 @@ where
112112
self
113113
}
114114

115+
#[cfg(feature = "tls")]
115116
fn with_server_tls(mut self) -> Self {
116-
let cert = generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
117+
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
117118

118-
self.server_tls = Some(Arc::new(
119+
self.server_tls = Some(std::sync::Arc::new(
119120
ServerConfig::builder()
120121
.with_safe_defaults()
121122
.with_no_client_auth()
@@ -180,6 +181,8 @@ fn it_connects() {
180181
}
181182

182183
#[test]
184+
#[cfg(feature = "tls")]
185+
183186
fn it_connects_tls_server_only() {
184187
// Client can connect ok without SSL when SSL is enabled on the server.
185188
TestingShim::new(
@@ -193,6 +196,8 @@ fn it_connects_tls_server_only() {
193196
}
194197

195198
#[test]
199+
#[cfg(feature = "tls")]
200+
196201
fn it_connects_tls_both() {
197202
// SSL connection when ssl enabled on server and used by client
198203
TestingShim::new(

0 commit comments

Comments
 (0)