Skip to content

Commit 0ca6d4f

Browse files
authored
Expose Wasi-TLS handshake error (#10429)
* Expose TLS errors to the guest. * Add a test to check the non-happy flow and verify that the error is properly propagated into the guest. * Update WIT versions
1 parent cb7d881 commit 0ca6d4f

File tree

7 files changed

+163
-61
lines changed

7 files changed

+163
-61
lines changed

ci/vendor-wit.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ make_vendor "wasi-http" "
6161

6262
make_vendor "wasi-tls" "
6363
64-
64+
6565
"
6666

6767
make_vendor "wasi-config" "config@f4d699b"
Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,96 @@
1-
use anyhow::{Context, Result};
1+
use anyhow::{anyhow, Context, Result};
22
use core::str;
3-
use test_programs::wasi::sockets::network::{IpSocketAddress, Network};
3+
use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network};
44
use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
55
use test_programs::wasi::tls::types::ClientHandshake;
66

7-
fn make_tls_request(domain: &str) -> Result<String> {
8-
const PORT: u16 = 443;
7+
const PORT: u16 = 443;
98

9+
fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
1010
let request =
1111
format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n");
1212

1313
let net = Network::default();
1414

15-
let Some(ip) = net
16-
.permissive_blocking_resolve_addresses(domain)
17-
.unwrap()
18-
.first()
19-
.map(|a| a.to_owned())
20-
else {
21-
return Err(anyhow::anyhow!("DNS lookup failed."));
22-
};
23-
2415
let socket = TcpSocket::new(ip.family()).unwrap();
2516
let (tcp_input, tcp_output) = socket
2617
.blocking_connect(&net, IpSocketAddress::new(ip, PORT))
27-
.context("failed to connect")?;
18+
.context("tcp connect failed")?;
2819

2920
let (client_connection, tls_input, tls_output) =
3021
ClientHandshake::new(domain, tcp_input, tcp_output)
3122
.blocking_finish()
32-
.map_err(|_| anyhow::anyhow!("failed to finish handshake"))?;
23+
.context("tls handshake failed")?;
3324

34-
tls_output.blocking_write_util(request.as_bytes()).unwrap();
25+
tls_output
26+
.blocking_write_util(request.as_bytes())
27+
.context("writing http request failed")?;
3528
client_connection
3629
.blocking_close_output(&tls_output)
37-
.map_err(|_| anyhow::anyhow!("failed to close tls connection"))?;
30+
.context("closing tls connection failed")?;
3831
socket.shutdown(ShutdownType::Send)?;
3932
let response = tls_input
4033
.blocking_read_to_end()
41-
.map_err(|_| anyhow::anyhow!("failed to read output"))?;
42-
String::from_utf8(response).context("error converting response")
34+
.context("reading http response failed")?;
35+
36+
if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") {
37+
Ok(())
38+
} else {
39+
Err(anyhow!("server did not respond with 200 OK"))
40+
}
4341
}
4442

45-
fn test_tls_sample_application() {
46-
// since this is testing remote endpoint to ensure system cert store works
43+
/// This test sets up a TCP connection using one domain, and then attempts to
44+
/// perform a TLS handshake using another unrelated domain. This should result
45+
/// in a handshake error.
46+
fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
47+
const BAD_DOMAIN: &'static str = "wrongdomain.localhost";
48+
49+
let net = Network::default();
50+
51+
let socket = TcpSocket::new(ip.family()).unwrap();
52+
let (tcp_input, tcp_output) = socket
53+
.blocking_connect(&net, IpSocketAddress::new(ip, PORT))
54+
.context("tcp connect failed")?;
55+
56+
match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() {
57+
// We're expecting an error regarding the "certificate" is some form or
58+
// another. When we add more TLS backends other than rustls, this naive
59+
// check will likely need to be revisited/expanded:
60+
Err(e) if e.to_debug_string().contains("certificate") => Ok(()),
61+
62+
Err(e) => Err(e.into()),
63+
Ok(_) => panic!("expecting server name mismatch"),
64+
}
65+
}
66+
67+
fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) {
68+
// since this is testing remote endpoints to ensure system cert store works
4769
// the test uses a couple different endpoints to reduce the number of flakes
4870
const DOMAINS: &'static [&'static str] = &["example.com", "api.github.com"];
4971

72+
let net = Network::default();
73+
5074
for &domain in DOMAINS {
51-
match make_tls_request(domain) {
52-
Ok(r) => {
53-
assert!(r.contains("HTTP/1.1 200 OK"));
54-
return;
55-
}
75+
let lookup = net
76+
.permissive_blocking_resolve_addresses(domain)
77+
.unwrap()
78+
.first()
79+
.map(|a| a.to_owned())
80+
.ok_or_else(|| anyhow!("DNS lookup failed."));
81+
82+
match lookup.and_then(|ip| test(&domain, ip)) {
83+
Ok(()) => return,
5684
Err(e) => {
57-
eprintln!("Failed to make TLS request to {domain}: {e}");
85+
eprintln!("test for {domain} failed: {e:#}");
5886
}
5987
}
6088
}
61-
panic!("All TLS requests failed.");
89+
90+
panic!("all tests failed");
6291
}
6392

6493
fn main() {
65-
test_tls_sample_application();
94+
try_live_endpoints(test_tls_sample_application);
95+
try_live_endpoints(test_tls_invalid_certificate);
6696
}

crates/test-programs/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,11 @@ pub mod proxy {
4848
},
4949
});
5050
}
51+
52+
impl std::fmt::Display for wasi::io::error::Error {
53+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54+
f.write_str(&self.to_debug_string())
55+
}
56+
}
57+
58+
impl std::error::Error for wasi::io::error::Error {}

crates/test-programs/src/tls.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use crate::wasi::clocks::monotonic_clock;
2+
use crate::wasi::io::error::Error as IoError;
23
use crate::wasi::io::streams::StreamError;
34
use crate::wasi::tls::types::{ClientConnection, ClientHandshake, InputStream, OutputStream};
45

56
const TIMEOUT_NS: u64 = 1_000_000_000;
67

78
impl ClientHandshake {
8-
pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), ()> {
9+
pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), IoError> {
910
let future = ClientHandshake::finish(self);
1011
let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS * 200);
1112
let pollable = future.subscribe();

crates/wasi-tls/src/lib.rs

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
#![doc(test(attr(deny(warnings))))]
7272
#![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))]
7373

74-
use anyhow::{Context, Result};
74+
use anyhow::Result;
7575
use bytes::Bytes;
7676
use rustls::pki_types::ServerName;
7777
use std::io;
@@ -88,6 +88,7 @@ use wasmtime_wasi::OutputStream;
8888
use wasmtime_wasi::{
8989
async_trait,
9090
bindings::io::{
91+
error::Error as HostIoError,
9192
poll::Pollable as HostPollable,
9293
streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream},
9394
},
@@ -149,6 +150,57 @@ pub fn add_to_linker<T: Send>(
149150
generated::types::add_to_linker_get_host(l, &opts, f)?;
150151
Ok(())
151152
}
153+
154+
enum TlsError {
155+
/// The component should trap. Under normal circumstances, this only occurs
156+
/// when the underlying transport stream returns [`StreamError::Trap`].
157+
Trap(anyhow::Error),
158+
159+
/// A failure indicated by the underlying transport stream as
160+
/// [`StreamError::LastOperationFailed`].
161+
Io(wasmtime_wasi::IoError),
162+
163+
/// A TLS protocol error occurred.
164+
Tls(rustls::Error),
165+
}
166+
167+
impl TlsError {
168+
/// Create a [`TlsError::Tls`] error from a simple message.
169+
fn msg(msg: &str) -> Self {
170+
// (Ab)using rustls' error type to synthesize our own TLS errors:
171+
Self::Tls(rustls::Error::General(msg.to_string()))
172+
}
173+
}
174+
175+
impl From<io::Error> for TlsError {
176+
fn from(error: io::Error) -> Self {
177+
// Report unexpected EOFs as an error to prevent truncation attacks.
178+
// See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
179+
if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() {
180+
return Self::msg("underlying transport closed abruptly");
181+
}
182+
183+
// Errors from underlying transport.
184+
// These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below.
185+
let error = match error.downcast::<StreamError>() {
186+
Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e),
187+
Ok(StreamError::Trap(e)) => return Self::Trap(e),
188+
Ok(StreamError::Closed) => unreachable!("our wasi-to-tokio stream transformer should have translated this to a 0-sized read"),
189+
Err(e) => e,
190+
};
191+
192+
// Errors from `rustls`.
193+
// These have been wrapped inside `io::Error`s by `tokio-rustls`.
194+
let error = match error.downcast::<rustls::Error>() {
195+
Ok(e) => return Self::Tls(e),
196+
Err(e) => e,
197+
};
198+
199+
// All errors should have been handled by the clauses above.
200+
Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error"))
201+
}
202+
}
203+
152204
/// Represents the ClientHandshake which will be used to configure the handshake
153205
pub struct ClientHandShake {
154206
server_name: String,
@@ -180,16 +232,17 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
180232
let handshake = self.table.delete(this)?;
181233
let server_name = handshake.server_name;
182234
let streams = handshake.streams;
183-
let domain = ServerName::try_from(server_name)?;
184235

185236
Ok(self
186237
.table
187238
.push(FutureStreams(StreamState::Pending(Box::pin(async move {
188-
let connector = tokio_rustls::TlsConnector::from(default_client_config());
189-
connector
239+
let domain = ServerName::try_from(server_name)
240+
.map_err(|_| TlsError::msg("invalid server name"))?;
241+
242+
let stream = tokio_rustls::TlsConnector::from(default_client_config())
190243
.connect(domain, streams)
191-
.await
192-
.with_context(|| "connection failed")
244+
.await?;
245+
Ok(stream)
193246
}))))?)
194247
}
195248

@@ -203,7 +256,7 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
203256
}
204257

205258
/// Future streams provides the tls streams after the handshake is completed
206-
pub struct FutureStreams<T>(StreamState<Result<T>>);
259+
pub struct FutureStreams<T>(StreamState<Result<T, TlsError>>);
207260

208261
/// Library specific version of TLS connection after the handshake is completed.
209262
/// This alias allows it to use with wit-bindgen component generator which won't take generic types
@@ -239,30 +292,36 @@ impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> {
239292
Resource<BoxInputStream>,
240293
Resource<BoxOutputStream>,
241294
),
242-
(),
295+
Resource<HostIoError>,
243296
>,
244297
(),
245298
>,
246299
>,
247300
> {
248-
{
249-
let this = self.table.get(&this)?;
250-
match &this.0 {
251-
StreamState::Pending(_) => return Ok(None),
252-
StreamState::Ready(Ok(_)) => (),
253-
StreamState::Ready(Err(_)) => {
254-
return Ok(Some(Ok(Err(()))));
255-
}
256-
StreamState::Closed => return Ok(Some(Err(()))),
257-
}
301+
let this = &mut self.table.get_mut(&this)?.0;
302+
match this {
303+
StreamState::Pending(_) => return Ok(None),
304+
StreamState::Closed => return Ok(Some(Err(()))),
305+
StreamState::Ready(_) => (),
258306
}
259307

260-
let StreamState::Ready(Ok(tls_stream)) =
261-
mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed)
262-
else {
308+
let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else {
263309
unreachable!()
264310
};
265311

312+
let tls_stream = match result {
313+
Ok(s) => s,
314+
Err(TlsError::Trap(e)) => return Err(e),
315+
Err(TlsError::Io(e)) => {
316+
let error = self.table.push(e)?;
317+
return Ok(Some(Ok(Err(error))));
318+
}
319+
Err(TlsError::Tls(e)) => {
320+
let error = self.table.push(wasmtime_wasi::IoError::new(e))?;
321+
return Ok(Some(Ok(Err(error))));
322+
}
323+
};
324+
266325
let (rx, tx) = tokio::io::split(tls_stream);
267326
let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx));
268327
let client = ClientConnection {
@@ -347,15 +406,15 @@ impl AsyncWrite for WasiStreams {
347406
return match output.write(Bytes::copy_from_slice(&buf[..count])) {
348407
Ok(()) => Poll::Ready(Ok(count)),
349408
Err(StreamError::Closed) => Poll::Ready(Ok(0)),
350-
Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
351-
Poll::Ready(Err(std::io::Error::other(e)))
352-
}
409+
Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
353410
};
354411
}
355-
Err(StreamError::Closed) => return Poll::Ready(Ok(0)),
356-
Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
357-
return Poll::Ready(Err(std::io::Error::other(e)))
412+
Err(StreamError::Closed) => {
413+
// Our current version of tokio-rustls does not handle returning `Ok(0)` well.
414+
// See: https://github.com/rustls/tokio-rustls/issues/92
415+
return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
358416
}
417+
Err(e) => return Poll::Ready(Err(std::io::Error::other(e))),
359418
};
360419
}
361420
}
@@ -621,7 +680,8 @@ mod tests {
621680
let (tx1, rx1) = oneshot::channel::<()>();
622681

623682
let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move {
624-
rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled"))
683+
rx1.await
684+
.map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled")))
625685
})));
626686

627687
let mut fut = future_streams.ready();

crates/wasi-tls/wit/deps/tls/types.wit

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ interface types {
44
use wasi:io/streams@0.2.3.{input-stream, output-stream};
55
@unstable(feature = tls)
66
use wasi:io/poll@0.2.3.{pollable};
7+
@unstable(feature = tls)
8+
use wasi:io/error@0.2.3.{error as io-error};
79

810
@unstable(feature = tls)
911
resource client-handshake {
@@ -26,6 +28,6 @@ interface types {
2628
subscribe: func() -> pollable;
2729

2830
@unstable(feature = tls)
29-
get: func() -> option<result<result<tuple<client-connection, input-stream, output-stream>>>>;
31+
get: func() -> option<result<result<tuple<client-connection, input-stream, output-stream>, io-error>>>;
3032
}
3133
}

crates/wasi/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ pub use wasmtime::component::{ResourceTable, ResourceTableError};
280280
// users of this crate depend on them at these names.
281281
pub use wasmtime_wasi_io::poll::{subscribe, DynFuture, DynPollable, MakeFuture, Pollable};
282282
pub use wasmtime_wasi_io::streams::{
283-
DynInputStream, DynOutputStream, InputStream, OutputStream, StreamError, StreamResult,
283+
DynInputStream, DynOutputStream, Error as IoError, InputStream, OutputStream, StreamError,
284+
StreamResult,
284285
};
285286
pub use wasmtime_wasi_io::{IoImpl, IoView};
286287

0 commit comments

Comments
 (0)