71
71
#![ doc( test( attr( deny( warnings) ) ) ) ]
72
72
#![ doc( test( attr( allow( dead_code, unused_variables, unused_mut) ) ) ) ]
73
73
74
- use anyhow:: { Context , Result } ;
74
+ use anyhow:: Result ;
75
75
use bytes:: Bytes ;
76
76
use rustls:: pki_types:: ServerName ;
77
77
use std:: io;
@@ -88,6 +88,7 @@ use wasmtime_wasi::OutputStream;
88
88
use wasmtime_wasi:: {
89
89
async_trait,
90
90
bindings:: io:: {
91
+ error:: Error as HostIoError ,
91
92
poll:: Pollable as HostPollable ,
92
93
streams:: { InputStream as BoxInputStream , OutputStream as BoxOutputStream } ,
93
94
} ,
@@ -149,6 +150,57 @@ pub fn add_to_linker<T: Send>(
149
150
generated:: types:: add_to_linker_get_host ( l, & opts, f) ?;
150
151
Ok ( ( ) )
151
152
}
153
+
154
+ enum TlsError {
155
+ /// The component should trap. Under normal circumstances, this only occurs
156
+ /// when the underlying transport stream returns [`StreamError::Trap`].
157
+ Trap ( anyhow:: Error ) ,
158
+
159
+ /// A failure indicated by the underlying transport stream as
160
+ /// [`StreamError::LastOperationFailed`].
161
+ Io ( wasmtime_wasi:: IoError ) ,
162
+
163
+ /// A TLS protocol error occurred.
164
+ Tls ( rustls:: Error ) ,
165
+ }
166
+
167
+ impl TlsError {
168
+ /// Create a [`TlsError::Tls`] error from a simple message.
169
+ fn msg ( msg : & str ) -> Self {
170
+ // (Ab)using rustls' error type to synthesize our own TLS errors:
171
+ Self :: Tls ( rustls:: Error :: General ( msg. to_string ( ) ) )
172
+ }
173
+ }
174
+
175
+ impl From < io:: Error > for TlsError {
176
+ fn from ( error : io:: Error ) -> Self {
177
+ // Report unexpected EOFs as an error to prevent truncation attacks.
178
+ // See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
179
+ if let io:: ErrorKind :: WriteZero | io:: ErrorKind :: UnexpectedEof = error. kind ( ) {
180
+ return Self :: msg ( "underlying transport closed abruptly" ) ;
181
+ }
182
+
183
+ // Errors from underlying transport.
184
+ // These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below.
185
+ let error = match error. downcast :: < StreamError > ( ) {
186
+ Ok ( StreamError :: LastOperationFailed ( e) ) => return Self :: Io ( e) ,
187
+ Ok ( StreamError :: Trap ( e) ) => return Self :: Trap ( e) ,
188
+ Ok ( StreamError :: Closed ) => unreachable ! ( "our wasi-to-tokio stream transformer should have translated this to a 0-sized read" ) ,
189
+ Err ( e) => e,
190
+ } ;
191
+
192
+ // Errors from `rustls`.
193
+ // These have been wrapped inside `io::Error`s by `tokio-rustls`.
194
+ let error = match error. downcast :: < rustls:: Error > ( ) {
195
+ Ok ( e) => return Self :: Tls ( e) ,
196
+ Err ( e) => e,
197
+ } ;
198
+
199
+ // All errors should have been handled by the clauses above.
200
+ Self :: Trap ( anyhow:: Error :: new ( error) . context ( "unknown wasi-tls error" ) )
201
+ }
202
+ }
203
+
152
204
/// Represents the ClientHandshake which will be used to configure the handshake
153
205
pub struct ClientHandShake {
154
206
server_name : String ,
@@ -180,16 +232,17 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
180
232
let handshake = self . table . delete ( this) ?;
181
233
let server_name = handshake. server_name ;
182
234
let streams = handshake. streams ;
183
- let domain = ServerName :: try_from ( server_name) ?;
184
235
185
236
Ok ( self
186
237
. table
187
238
. push ( FutureStreams ( StreamState :: Pending ( Box :: pin ( async move {
188
- let connector = tokio_rustls:: TlsConnector :: from ( default_client_config ( ) ) ;
189
- connector
239
+ let domain = ServerName :: try_from ( server_name)
240
+ . map_err ( |_| TlsError :: msg ( "invalid server name" ) ) ?;
241
+
242
+ let stream = tokio_rustls:: TlsConnector :: from ( default_client_config ( ) )
190
243
. connect ( domain, streams)
191
- . await
192
- . with_context ( || "connection failed" )
244
+ . await ? ;
245
+ Ok ( stream )
193
246
} ) ) ) ) ?)
194
247
}
195
248
@@ -203,7 +256,7 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
203
256
}
204
257
205
258
/// Future streams provides the tls streams after the handshake is completed
206
- pub struct FutureStreams < T > ( StreamState < Result < T > > ) ;
259
+ pub struct FutureStreams < T > ( StreamState < Result < T , TlsError > > ) ;
207
260
208
261
/// Library specific version of TLS connection after the handshake is completed.
209
262
/// This alias allows it to use with wit-bindgen component generator which won't take generic types
@@ -239,30 +292,36 @@ impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> {
239
292
Resource < BoxInputStream > ,
240
293
Resource < BoxOutputStream > ,
241
294
) ,
242
- ( ) ,
295
+ Resource < HostIoError > ,
243
296
> ,
244
297
( ) ,
245
298
> ,
246
299
> ,
247
300
> {
248
- {
249
- let this = self . table . get ( & this) ?;
250
- match & this. 0 {
251
- StreamState :: Pending ( _) => return Ok ( None ) ,
252
- StreamState :: Ready ( Ok ( _) ) => ( ) ,
253
- StreamState :: Ready ( Err ( _) ) => {
254
- return Ok ( Some ( Ok ( Err ( ( ) ) ) ) ) ;
255
- }
256
- StreamState :: Closed => return Ok ( Some ( Err ( ( ) ) ) ) ,
257
- }
301
+ let this = & mut self . table . get_mut ( & this) ?. 0 ;
302
+ match this {
303
+ StreamState :: Pending ( _) => return Ok ( None ) ,
304
+ StreamState :: Closed => return Ok ( Some ( Err ( ( ) ) ) ) ,
305
+ StreamState :: Ready ( _) => ( ) ,
258
306
}
259
307
260
- let StreamState :: Ready ( Ok ( tls_stream) ) =
261
- mem:: replace ( & mut self . table . get_mut ( & this) ?. 0 , StreamState :: Closed )
262
- else {
308
+ let StreamState :: Ready ( result) = mem:: replace ( this, StreamState :: Closed ) else {
263
309
unreachable ! ( )
264
310
} ;
265
311
312
+ let tls_stream = match result {
313
+ Ok ( s) => s,
314
+ Err ( TlsError :: Trap ( e) ) => return Err ( e) ,
315
+ Err ( TlsError :: Io ( e) ) => {
316
+ let error = self . table . push ( e) ?;
317
+ return Ok ( Some ( Ok ( Err ( error) ) ) ) ;
318
+ }
319
+ Err ( TlsError :: Tls ( e) ) => {
320
+ let error = self . table . push ( wasmtime_wasi:: IoError :: new ( e) ) ?;
321
+ return Ok ( Some ( Ok ( Err ( error) ) ) ) ;
322
+ }
323
+ } ;
324
+
266
325
let ( rx, tx) = tokio:: io:: split ( tls_stream) ;
267
326
let write_stream = AsyncTlsWriteStream :: new ( TlsWriter :: new ( tx) ) ;
268
327
let client = ClientConnection {
@@ -347,15 +406,15 @@ impl AsyncWrite for WasiStreams {
347
406
return match output. write ( Bytes :: copy_from_slice ( & buf[ ..count] ) ) {
348
407
Ok ( ( ) ) => Poll :: Ready ( Ok ( count) ) ,
349
408
Err ( StreamError :: Closed ) => Poll :: Ready ( Ok ( 0 ) ) ,
350
- Err ( StreamError :: LastOperationFailed ( e) | StreamError :: Trap ( e) ) => {
351
- Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) )
352
- }
409
+ Err ( e) => Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) ) ,
353
410
} ;
354
411
}
355
- Err ( StreamError :: Closed ) => return Poll :: Ready ( Ok ( 0 ) ) ,
356
- Err ( StreamError :: LastOperationFailed ( e) | StreamError :: Trap ( e) ) => {
357
- return Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) )
412
+ Err ( StreamError :: Closed ) => {
413
+ // Our current version of tokio-rustls does not handle returning `Ok(0)` well.
414
+ // See: https://github.com/rustls/tokio-rustls/issues/92
415
+ return Poll :: Ready ( Err ( std:: io:: ErrorKind :: WriteZero . into ( ) ) ) ;
358
416
}
417
+ Err ( e) => return Poll :: Ready ( Err ( std:: io:: Error :: other ( e) ) ) ,
359
418
} ;
360
419
}
361
420
}
@@ -621,7 +680,8 @@ mod tests {
621
680
let ( tx1, rx1) = oneshot:: channel :: < ( ) > ( ) ;
622
681
623
682
let mut future_streams = FutureStreams ( StreamState :: Pending ( Box :: pin ( async move {
624
- rx1. await . map_err ( |_| anyhow:: anyhow!( "oneshot canceled" ) )
683
+ rx1. await
684
+ . map_err ( |_| TlsError :: Trap ( anyhow:: anyhow!( "oneshot canceled" ) ) )
625
685
} ) ) ) ;
626
686
627
687
let mut fut = future_streams. ready ( ) ;
0 commit comments