Skip to content

Commit 58741fe

Browse files
committed
Change tls_config to return Arc
Since we ultimately need an Arc anyway, let's return one in the first place to avoid a clone later.
1 parent 96777ad commit 58741fe

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ use std::io;
100100
use std::io::prelude::*;
101101
use std::iter;
102102
use std::net;
103+
use std::sync::Arc;
103104

104105
use myc::constants::CapabilityFlags;
105106

@@ -189,7 +190,7 @@ pub trait MysqlShim<W: Read + Write> {
189190
}
190191

191192
/// Provides the TLS configuration, if we want to support TLS.
192-
fn tls_config(&self) -> Option<&rustls::ServerConfig> {
193+
fn tls_config(&self) -> Option<Arc<rustls::ServerConfig>> {
193194
None
194195
}
195196
}
@@ -298,7 +299,7 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
298299
)
299300
})?;
300301

301-
self.rw.switch_to_tls(&config)?;
302+
self.rw.switch_to_tls(config)?;
302303

303304
let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
304305
io::Error::new(

src/packet.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use byteorder::{ByteOrder, LittleEndian};
2-
use rustls::ServerConnection;
2+
use rustls::{ServerConfig, ServerConnection};
33
use std::io;
44
use std::io::prelude::*;
55

@@ -68,7 +68,7 @@ impl<W: Read + Write> PacketConn<W> {
6868
self.maybe_end_packet()
6969
}
7070

71-
pub fn switch_to_tls(&mut self, config: &rustls::ServerConfig) -> io::Result<()> {
71+
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
7272
assert!(self.remaining() == 0); // otherwise we've read ahead into the TLS handshake and will be in trouble.
7373

7474
self.rw.switch_to_tls(config)
@@ -171,6 +171,7 @@ impl AsRef<[u8]> for Packet {
171171
}
172172

173173
use std::ops::Deref;
174+
use std::sync::Arc;
174175

175176
use crate::tls;
176177
impl Deref for Packet {
@@ -250,7 +251,7 @@ impl<T: Read + Write> SwitchableConn<T> {
250251
SwitchableConn(Some(EitherConn::Plain(rw)))
251252
}
252253

253-
pub fn switch_to_tls(&mut self, config: &rustls::ServerConfig) -> io::Result<()> {
254+
pub fn switch_to_tls(&mut self, config: Arc<ServerConfig>) -> io::Result<()> {
254255
let replacement = match self.0.take() {
255256
Some(EitherConn::Plain(plain)) => {
256257
Ok(EitherConn::TLS(tls::create_stream(plain, config)?))

src/tls.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use std::io;
22
use std::io::{Read, Write};
33
use std::sync::Arc;
44

5-
use rustls::{self, ServerConnection};
5+
use rustls::{self, ServerConfig, ServerConnection};
66

77
pub fn create_stream<T: Read + Write + Sized>(
88
sock: T,
9-
config: &rustls::ServerConfig,
9+
config: Arc<ServerConfig>,
1010
) -> Result<rustls::StreamOwned<ServerConnection, T>, io::Error> {
11-
let conn = ServerConnection::new(Arc::new(config.clone())).unwrap();
11+
let conn = ServerConnection::new(config).unwrap();
1212
let stream = rustls::StreamOwned { conn, sock };
1313
Ok(stream)
1414
}

tests/main.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use rustls::ServerConfig;
1515
use std::error::Error;
1616
use std::io;
1717
use std::net;
18+
use std::sync::Arc;
1819
use std::thread;
1920

2021
use msql_srv::{
@@ -29,7 +30,7 @@ struct TestingShim<Q, P, E, I> {
2930
on_p: P,
3031
on_e: E,
3132
on_i: I,
32-
server_tls: Option<rustls::ServerConfig>,
33+
server_tls: Option<Arc<rustls::ServerConfig>>,
3334
client_tls: Option<SslOpts>,
3435
}
3536

@@ -74,8 +75,8 @@ where
7475
(self.on_q)(query, results)
7576
}
7677

77-
fn tls_config(&self) -> Option<&rustls::ServerConfig> {
78-
self.server_tls.as_ref()
78+
fn tls_config(&self) -> Option<Arc<rustls::ServerConfig>> {
79+
self.server_tls.as_ref().map(Arc::clone)
7980
}
8081
}
8182

@@ -114,7 +115,7 @@ where
114115
fn with_server_tls(mut self) -> Self {
115116
let cert = generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
116117

117-
self.server_tls = Some(
118+
self.server_tls = Some(Arc::new(
118119
ServerConfig::builder()
119120
.with_safe_defaults()
120121
.with_no_client_auth()
@@ -123,7 +124,7 @@ where
123124
PrivateKey(cert.get_key_pair().serialize_der()),
124125
)
125126
.unwrap(),
126-
);
127+
));
127128

128129
self
129130
}

0 commit comments

Comments
 (0)