Skip to content

Commit bad5e57

Browse files
authored
meshtls: Move TLS e2e tests into the meshtls crate (#1366)
We'll want to run the same end-to-end tests against all meshtls backends. This change moves the rustls `tls_accept` tests into meshtls crate.
1 parent 78df81f commit bad5e57

File tree

5 files changed

+84
-62
lines changed

5 files changed

+84
-62
lines changed

Cargo.lock

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,34 +1022,36 @@ name = "linkerd-meshtls"
10221022
version = "0.1.0"
10231023
dependencies = [
10241024
"futures",
1025+
"linkerd-conditional",
10251026
"linkerd-error",
10261027
"linkerd-identity",
10271028
"linkerd-io",
10281029
"linkerd-meshtls-rustls",
1030+
"linkerd-proxy-transport",
10291031
"linkerd-stack",
10301032
"linkerd-tls",
1033+
"linkerd-tls-test-util",
1034+
"linkerd-tracing",
10311035
"pin-project",
1036+
"tokio",
1037+
"tracing",
10321038
]
10331039

10341040
[[package]]
10351041
name = "linkerd-meshtls-rustls"
10361042
version = "0.1.0"
10371043
dependencies = [
10381044
"futures",
1039-
"linkerd-conditional",
10401045
"linkerd-error",
10411046
"linkerd-identity",
10421047
"linkerd-io",
1043-
"linkerd-proxy-transport",
10441048
"linkerd-stack",
10451049
"linkerd-tls",
10461050
"linkerd-tls-test-util",
1047-
"linkerd-tracing",
10481051
"ring",
10491052
"thiserror",
10501053
"tokio",
10511054
"tokio-rustls",
1052-
"tower",
10531055
"tracing",
10541056
"webpki",
10551057
]

linkerd/meshtls/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,11 @@ linkerd-meshtls-rustls = { path = "rustls", optional = true }
2020
linkerd-stack = { path = "../stack" }
2121
linkerd-tls = { path = "../tls" }
2222
pin-project = "1"
23+
24+
[dev-dependencies]
25+
linkerd-conditional = { path = "../conditional" }
26+
linkerd-proxy-transport = { path = "../proxy/transport" }
27+
linkerd-tls-test-util = { path = "../tls/test-util" }
28+
linkerd-tracing = { path = "../tracing", features = ["ansi"] }
29+
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
30+
tracing = "0.1"

linkerd/meshtls/rustls/Cargo.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,4 @@ tracing = "0.1"
2525
webpki = "0.21"
2626

2727
[dev-dependencies]
28-
linkerd-conditional = { path = "../../conditional" }
29-
linkerd-proxy-transport = { path = "../../proxy/transport" }
3028
linkerd-tls-test-util = { path = "../../tls/test-util" }
31-
linkerd-tracing = { path = "../../tracing", features = ["ansi"] }
32-
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
33-
tower = { version = "0.4.10", default-features = false, features = ["util"] }

linkerd/meshtls/tests/rustls.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#![cfg(feature = "rustls")]
2+
#![deny(warnings, rust_2018_idioms)]
3+
#![forbid(unsafe_code)]
4+
5+
mod util;
6+
7+
use linkerd_meshtls::Mode;
8+
9+
#[tokio::test(flavor = "current_thread")]
10+
async fn plaintext() {
11+
util::plaintext(Mode::Rustls).await;
12+
}
13+
14+
#[tokio::test(flavor = "current_thread")]
15+
async fn proxy_to_proxy_tls_works() {
16+
util::proxy_to_proxy_tls_works(Mode::Rustls).await;
17+
}
18+
19+
#[tokio::test(flavor = "current_thread")]
20+
async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match() {
21+
util::proxy_to_proxy_tls_pass_through_when_identity_does_not_match(Mode::Rustls).await;
22+
}

linkerd/meshtls/rustls/tests/tls_accept.rs renamed to linkerd/meshtls/tests/util.rs

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,29 @@
1-
#![cfg(test)]
2-
3-
// These are basically integration tests for the `connection` submodule, but
4-
// they cannot be "real" integration tests because `connection` isn't a public
5-
// interface and because `connection` exposes a `#[cfg(test)]`-only API for use
6-
// by these tests.
1+
#![deny(warnings)]
2+
#![forbid(unsafe_code)]
73

84
use futures::prelude::*;
95
use linkerd_conditional::Conditional;
106
use linkerd_error::Infallible;
117
use linkerd_identity::{Credentials, DerX509, Name};
128
use linkerd_io::{self as io, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13-
use linkerd_meshtls_rustls as meshtls;
9+
use linkerd_meshtls as meshtls;
1410
use linkerd_proxy_transport::{
1511
addrs::*,
1612
listen::{Addrs, Bind, BindTcp},
1713
ConnectTcp, Keepalive, ListenAddr,
1814
};
19-
use linkerd_stack::{ExtractParam, InsertParam, NewService, Param};
15+
use linkerd_stack::{
16+
layer::Layer, service_fn, ExtractParam, InsertParam, NewService, Param, ServiceExt,
17+
};
2018
use linkerd_tls as tls;
2119
use linkerd_tls_test_util as test_util;
2220
use std::{future::Future, net::SocketAddr, sync::mpsc, time::Duration};
2321
use tokio::net::TcpStream;
24-
use tower::{
25-
layer::Layer,
26-
util::{service_fn, ServiceExt},
27-
};
28-
use tracing::instrument::Instrument;
29-
30-
type ServerConn<T, I> = (
31-
(tls::ConditionalServerTls, T),
32-
io::EitherIo<meshtls::ServerIo<tls::server::DetectIo<I>>, tls::server::DetectIo<I>>,
33-
);
34-
35-
fn load(ent: &test_util::Entity) -> (meshtls::creds::Store, meshtls::NewClient, meshtls::Server) {
36-
let roots_pem = std::str::from_utf8(ent.trust_anchors).expect("valid PEM");
37-
let (mut store, rx) = meshtls::creds::watch(
38-
ent.name.parse().unwrap(),
39-
roots_pem,
40-
ent.key,
41-
b"fake CSR data",
42-
)
43-
.expect("credentials must be readable");
22+
use tracing::Instrument;
4423

45-
let expiry = std::time::SystemTime::now() + Duration::from_secs(600);
46-
store
47-
.set_certificate(DerX509(ent.crt.to_vec()), vec![], expiry)
48-
.expect("certificate must be valid");
49-
50-
(store, rx.new_client(), rx.server())
51-
}
52-
53-
#[tokio::test(flavor = "current_thread")]
54-
async fn plaintext() {
55-
let (_foo, _, server_tls) = load(&test_util::FOO_NS1);
56-
let (_bar, client_tls, _) = load(&test_util::BAR_NS1);
24+
pub async fn plaintext(mode: meshtls::Mode) {
25+
let (_foo, _, server_tls) = load(mode, &test_util::FOO_NS1);
26+
let (_bar, client_tls, _) = load(mode, &test_util::BAR_NS1);
5727
let (client_result, server_result) = run_test(
5828
client_tls,
5929
Conditional::None(tls::NoClientTls::NotProvidedByServiceDiscovery),
@@ -76,10 +46,9 @@ async fn plaintext() {
7646
assert_eq!(&server_result.result.expect("ping")[..], PING);
7747
}
7848

79-
#[tokio::test(flavor = "current_thread")]
80-
async fn proxy_to_proxy_tls_works() {
81-
let (_foo, _, server_tls) = load(&test_util::FOO_NS1);
82-
let (_bar, client_tls, _) = load(&test_util::BAR_NS1);
49+
pub async fn proxy_to_proxy_tls_works(mode: meshtls::Mode) {
50+
let (_foo, _, server_tls) = load(mode, &test_util::FOO_NS1);
51+
let (_bar, client_tls, _) = load(mode, &test_util::BAR_NS1);
8352
let server_id = tls::ServerId(test_util::FOO_NS1.name.parse().unwrap());
8453
let (client_result, server_result) = run_test(
8554
client_tls.clone(),
@@ -107,13 +76,12 @@ async fn proxy_to_proxy_tls_works() {
10776
assert_eq!(&server_result.result.expect("ping")[..], PING);
10877
}
10978

110-
#[tokio::test(flavor = "current_thread")]
111-
async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match() {
112-
let (_foo, _, server_tls) = load(&test_util::FOO_NS1);
79+
pub async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match(mode: meshtls::Mode) {
80+
let (_foo, _, server_tls) = load(mode, &test_util::FOO_NS1);
11381

11482
// Misuse the client's identity instead of the server's identity. Any
11583
// identity other than `server_tls.server_identity` would work.
116-
let (_bar, client_tls, _) = load(&test_util::BAR_NS1);
84+
let (_bar, client_tls, _) = load(mode, &test_util::BAR_NS1);
11785
let sni = test_util::BAR_NS1.name.parse::<Name>().unwrap();
11886

11987
let (client_result, server_result) = run_test(
@@ -138,6 +106,33 @@ async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match() {
138106
assert_eq!(&server_result.result.unwrap()[..], START_OF_TLS);
139107
}
140108

109+
type ServerConn<T, I> = (
110+
(tls::ConditionalServerTls, T),
111+
io::EitherIo<meshtls::ServerIo<tls::server::DetectIo<I>>, tls::server::DetectIo<I>>,
112+
);
113+
114+
fn load(
115+
mode: meshtls::Mode,
116+
ent: &test_util::Entity,
117+
) -> (meshtls::creds::Store, meshtls::NewClient, meshtls::Server) {
118+
let roots_pem = std::str::from_utf8(ent.trust_anchors).expect("valid PEM");
119+
let (mut store, rx) = mode
120+
.watch(
121+
ent.name.parse().unwrap(),
122+
roots_pem,
123+
ent.key,
124+
b"fake CSR data",
125+
)
126+
.expect("credentials must be readable");
127+
128+
let expiry = std::time::SystemTime::now() + Duration::from_secs(600);
129+
store
130+
.set_certificate(DerX509(ent.crt.to_vec()), vec![], expiry)
131+
.expect("certificate must be valid");
132+
133+
(store, rx.new_client(), rx.server())
134+
}
135+
141136
struct Transported<I, R> {
142137
tls: Option<I>,
143138

@@ -150,7 +145,8 @@ struct ServerParams {
150145
identity: meshtls::Server,
151146
}
152147

153-
type ClientIo = io::EitherIo<io::ScopedIo<TcpStream>, meshtls::ClientIo<io::ScopedIo<TcpStream>>>;
148+
type ClientIo =
149+
io::EitherIo<io::ScopedIo<TcpStream>, linkerd_meshtls::ClientIo<io::ScopedIo<TcpStream>>>;
154150

155151
/// Runs a test for a single TCP connection. `client` processes the connection
156152
/// on the client side and `server` processes the connection on the server
@@ -159,7 +155,7 @@ async fn run_test<C, CF, CR, S, SF, SR>(
159155
client_tls: meshtls::NewClient,
160156
client_server_id: Conditional<tls::ServerId, tls::NoClientTls>,
161157
client: C,
162-
server_id: meshtls::Server,
158+
server_tls: meshtls::Server,
163159
server: S,
164160
) -> (
165161
Transported<tls::ConditionalClientTls, CR>,
@@ -184,7 +180,7 @@ where
184180

185181
let detect = tls::NewDetectTls::<meshtls::Server, _, _>::new(
186182
ServerParams {
187-
identity: server_id,
183+
identity: server_tls,
188184
},
189185
move |meta: (tls::ConditionalServerTls, Addrs)| {
190186
let server = server.clone();
@@ -233,7 +229,6 @@ where
233229
// type, e.g. `Arc<Mutex>`, but using a channel simplifies the code and
234230
// parallels the server side.
235231
let (sender, receiver) = mpsc::channel::<Transported<tls::ConditionalClientTls, CR>>();
236-
let sender_clone = sender.clone();
237232

238233
let tls = Some(client_server_id.clone().map(Into::into));
239234
let client = async move {
@@ -243,7 +238,7 @@ where
243238
.await;
244239
match conn {
245240
Err(e) => {
246-
sender_clone
241+
sender
247242
.send(Transported {
248243
tls: None,
249244
result: Err(e),

0 commit comments

Comments
 (0)