diff --git a/Cargo.lock b/Cargo.lock index c80f6e2f2787..22237e318301 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4382,6 +4382,7 @@ dependencies = [ "wasmtime-wasi-keyvalue", "wasmtime-wasi-nn", "wasmtime-wasi-threads", + "wasmtime-wasi-tls", "wasmtime-wast", "wasmtime-wast-util", "wast 227.0.0", @@ -4760,6 +4761,22 @@ dependencies = [ "wasmtime-wasi", ] +[[package]] +name = "wasmtime-wasi-tls" +version = "32.0.0" +dependencies = [ + "anyhow", + "bytes", + "futures", + "rustls 0.22.4", + "test-programs-artifacts", + "tokio", + "tokio-rustls", + "wasmtime", + "wasmtime-wasi", + "webpki-roots", +] + [[package]] name = "wasmtime-wast" version = "32.0.0" diff --git a/Cargo.toml b/Cargo.toml index b01c4c427b20..753756462332 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ wasi-common = { workspace = true, default-features = true, features = ["exit", " wasmtime-wasi = { workspace = true, default-features = true, optional = true } wasmtime-wasi-nn = { workspace = true, optional = true } wasmtime-wasi-config = { workspace = true, optional = true } +wasmtime-wasi-tls = { workspace = true, optional = true } wasmtime-wasi-keyvalue = { workspace = true, optional = true } wasmtime-wasi-threads = { workspace = true, optional = true } wasmtime-wasi-http = { workspace = true, optional = true } @@ -245,6 +246,7 @@ wasmtime-component-macro = { path = "crates/component-macro", version = "=32.0.0 wasmtime-asm-macros = { path = "crates/asm-macros", version = "=32.0.0" } wasmtime-versioned-export-macros = { path = "crates/versioned-export-macros", version = "=32.0.0" } wasmtime-slab = { path = "crates/slab", version = "=32.0.0" } +wasmtime-wasi-tls = { path = "crates/wasi-tls", version = "32.0.0" } component-test-util = { path = "crates/misc/component-test-util" } component-fuzz-util = { path = "crates/misc/component-fuzz-util" } wiggle = { path = "crates/wiggle", version = "=32.0.0", default-features = false } @@ -377,6 +379,9 @@ libtest-mimic = "0.7.0" semver = { version = "1.0.17", default-features = false } ittapi = "0.4.0" libm = "0.2.7" +tokio-rustls = "0.25.0" +rustls = "0.22.0" +webpki-roots = "0.26.0" # ============================================================================= # @@ -409,6 +414,7 @@ default = [ "wasi-http", "wasi-config", "wasi-keyvalue", + "wasi-tls", # Most features of Wasmtime are enabled by default. "wat", @@ -459,6 +465,7 @@ disable-logging = ["log/max_level_off", "tracing/max_level_off"] # These features are all included in the `default` set above and this is # the internal mapping for what they enable in Wasmtime itself. wasi-nn = ["dep:wasmtime-wasi-nn"] +wasi-tls = ["dep:wasmtime-wasi-tls"] wasi-threads = ["dep:wasmtime-wasi-threads", "threads"] wasi-http = ["component-model", "dep:wasmtime-wasi-http", "dep:tokio", "dep:hyper"] wasi-config = ["dep:wasmtime-wasi-config"] @@ -567,3 +574,4 @@ opt-level = 's' inherits = "release" codegen-units = 1 lto = true + diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index bafdf377d940..9079d7f882bb 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -414,6 +414,8 @@ wasmtime_option_group! { /// Grant access to the given TCP listen socket #[serde(default)] pub tcplisten: Vec, + /// Enable support for WASI TLS (Transport Layer Security) imports (experimental) + pub tls: Option, /// Implement WASI Preview1 using new Preview2 implementation (true, default) or legacy /// implementation (false) pub preview2: Option, diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index f545ff0d8240..258325ae31ab 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -78,6 +78,7 @@ fn build_and_generate_tests() { s if s.starts_with("dwarf_") => "dwarf", s if s.starts_with("config_") => "config", s if s.starts_with("keyvalue_") => "keyvalue", + s if s.starts_with("tls_") => "tls", // If you're reading this because you hit this panic, either add it // to a test suite above or add a new "suite". The purpose of the // categorization above is to have a static assertion that tests diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs new file mode 100644 index 000000000000..ac427df01527 --- /dev/null +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -0,0 +1,48 @@ +use core::str; + +use test_programs::wasi::sockets::network::{IpSocketAddress, Network}; +use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket}; +use test_programs::wasi::tls::types::ClientHandshake; + +fn test_tls_sample_application() { + const PORT: u16 = 443; + const DOMAIN: &'static str = "example.com"; + + let request = format!("GET / HTTP/1.1\r\nHost: {DOMAIN}\r\n\r\n"); + + let net = Network::default(); + + let Some(ip) = net + .permissive_blocking_resolve_addresses(DOMAIN) + .unwrap() + .first() + .map(|a| a.to_owned()) + else { + eprintln!("DNS lookup failed."); + return; + }; + + let socket = TcpSocket::new(ip.family()).unwrap(); + let (tcp_input, tcp_output) = socket + .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) + .unwrap(); + + let (client_connection, tls_input, tls_output) = + ClientHandshake::new(DOMAIN, tcp_input, tcp_output) + .blocking_finish() + .unwrap(); + + tls_output.blocking_write_util(request.as_bytes()).unwrap(); + client_connection + .blocking_close_output(&tls_output) + .unwrap(); + socket.shutdown(ShutdownType::Send).unwrap(); + let response = tls_input.blocking_read_to_end().unwrap(); + let response = String::from_utf8(response).unwrap(); + + assert!(response.contains("HTTP/1.1 200 OK")); +} + +fn main() { + test_tls_sample_application(); +} diff --git a/crates/test-programs/src/lib.rs b/crates/test-programs/src/lib.rs index 49301621dbd0..92b2a0fca128 100644 --- a/crates/test-programs/src/lib.rs +++ b/crates/test-programs/src/lib.rs @@ -2,6 +2,7 @@ pub mod http; pub mod nn; pub mod preview1; pub mod sockets; +pub mod tls; wit_bindgen::generate!({ inline: " @@ -12,15 +13,17 @@ wit_bindgen::generate!({ include wasi:http/imports@0.2.3; include wasi:config/imports@0.2.0-draft; include wasi:keyvalue/imports@0.2.0-draft; + include wasi:tls/imports@0.2.0-draft; } ", path: [ "../wasi-http/wit", "../wasi-config/wit", "../wasi-keyvalue/wit", + "../wasi-tls/wit/world.wit", ], world: "wasmtime:test/test", - features: ["cli-exit-with-code"], + features: ["cli-exit-with-code", "tls"], generate_all, }); diff --git a/crates/test-programs/src/tls.rs b/crates/test-programs/src/tls.rs new file mode 100644 index 000000000000..b59cb29948eb --- /dev/null +++ b/crates/test-programs/src/tls.rs @@ -0,0 +1,45 @@ +use crate::wasi::clocks::monotonic_clock; +use crate::wasi::io::streams::StreamError; +use crate::wasi::tls::types::{ClientConnection, ClientHandshake, InputStream, OutputStream}; + +const TIMEOUT_NS: u64 = 1_000_000_000; + +impl ClientHandshake { + pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), ()> { + let future = ClientHandshake::finish(self); + let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS * 200); + let pollable = future.subscribe(); + + loop { + match future.get() { + None => pollable.block_until(&timeout).expect("timed out"), + Some(Ok(r)) => return r, + Some(Err(e)) => { + eprintln!("{e:?}"); + unimplemented!() + } + } + } + } +} + +impl ClientConnection { + pub fn blocking_close_output( + &self, + output: &OutputStream, + ) -> Result<(), crate::wasi::io::error::Error> { + let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS); + let pollable = output.subscribe(); + + self.close_output(); + + loop { + match output.check_write() { + Ok(0) => pollable.block_until(&timeout).expect("timed out"), + Ok(_) => unreachable!("After calling close_output, the output stream should never accept new writes again."), + Err(StreamError::Closed) => return Ok(()), + Err(StreamError::LastOperationFailed(e)) => return Err(e), + } + } + } +} diff --git a/crates/wasi-http/Cargo.toml b/crates/wasi-http/Cargo.toml index 3339fc1b5743..278dd8f994ea 100644 --- a/crates/wasi-http/Cargo.toml +++ b/crates/wasi-http/Cargo.toml @@ -28,12 +28,9 @@ http-body-util = { workspace = true } tracing = { workspace = true } wasmtime-wasi = { workspace = true } wasmtime = { workspace = true, features = ['component-model'] } - -# The `ring` crate, used to implement TLS, does not build on riscv64 or s390x -[target.'cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))'.dependencies] -tokio-rustls = { version = "0.25.0" } -rustls = { version = "0.22.0" } -webpki-roots = { version = "0.26.0" } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +webpki-roots = { workspace = true } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 7a7b518db5ee..f510a065c291 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -373,58 +373,48 @@ pub async fn default_send_request_handler( })?; let (mut sender, worker) = if use_tls { - #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] - { - return Err(crate::bindings::http::types::ErrorCode::InternalError( - Some("unsupported architecture for SSL".to_string()), - )); - } + use rustls::pki_types::ServerName; - #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] - { - use rustls::pki_types::ServerName; - - // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs - let root_cert_store = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - let config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = ServerName::try_from(host) - .map_err(|e| { - tracing::warn!("dns lookup error: {e:?}"); - dns_error("invalid dns name".to_string(), 0) - })? - .to_owned(); - let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { - tracing::warn!("tls protocol error: {e:?}"); - types::ErrorCode::TlsProtocolError - })?; - let stream = TokioIo::new(stream); - - let (sender, conn) = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(stream), - ) - .await - .map_err(|_| types::ErrorCode::ConnectionTimeout)? - .map_err(hyper_request_error)?; - - let worker = wasmtime_wasi::runtime::spawn(async move { - match conn.await { - Ok(()) => {} - // TODO: shouldn't throw away this error and ideally should - // surface somewhere. - Err(e) => tracing::warn!("dropping error {e}"), - } - }); + // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs + let root_cert_store = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); + let mut parts = authority.split(":"); + let host = parts.next().unwrap_or(&authority); + let domain = ServerName::try_from(host) + .map_err(|e| { + tracing::warn!("dns lookup error: {e:?}"); + dns_error("invalid dns name".to_string(), 0) + })? + .to_owned(); + let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { + tracing::warn!("tls protocol error: {e:?}"); + types::ErrorCode::TlsProtocolError + })?; + let stream = TokioIo::new(stream); - (sender, worker) - } + let (sender, conn) = timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .map_err(|_| types::ErrorCode::ConnectionTimeout)? + .map_err(hyper_request_error)?; + + let worker = wasmtime_wasi::runtime::spawn(async move { + match conn.await { + Ok(()) => {} + // TODO: shouldn't throw away this error and ideally should + // surface somewhere. + Err(e) => tracing::warn!("dropping error {e}"), + } + }); + + (sender, worker) } else { let tcp_stream = TokioIo::new(tcp_stream); let (sender, conn) = timeout( diff --git a/crates/wasi-http/tests/all/main.rs b/crates/wasi-http/tests/all/main.rs index 0f6c1d7cf432..395d67ab73ee 100644 --- a/crates/wasi-http/tests/all/main.rs +++ b/crates/wasi-http/tests/all/main.rs @@ -528,8 +528,6 @@ async fn do_wasi_http_echo(uri: &str, url_header: Option<&str>) -> Result<()> { } #[test_log::test(tokio::test)] -// test uses TLS but riscv/s390x don't support that yet -#[cfg_attr(any(target_arch = "riscv64", target_arch = "s390x"), ignore)] async fn wasi_http_without_port() -> Result<()> { let req = hyper::Request::builder() .method(http::Method::GET) diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml new file mode 100644 index 000000000000..7bc7a29c26dc --- /dev/null +++ b/crates/wasi-tls/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "wasmtime-wasi-tls" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +repository = "https://github.com/bytecodealliance/wasmtime" +license = "Apache-2.0 WITH LLVM-exception" +description = "Wasmtime implementation of the wasi-tls API" + +[lints] +workspace = true + +[dependencies] +anyhow = { workspace = true } +bytes = { workspace = true } +tokio = { workspace = true, features = [ + "net", + "rt-multi-thread", + "time", +] } +wasmtime = { workspace = true, features = ["runtime", "component-model"] } +wasmtime-wasi = { workspace = true } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +webpki-roots = { workspace = true } + + +[dev-dependencies] +test-programs-artifacts = { workspace = true } +wasmtime-wasi = { workspace = true } +tokio = { workspace = true, features = ["macros"] } +futures = { workspace = true } diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs new file mode 100644 index 000000000000..89065fe53129 --- /dev/null +++ b/crates/wasi-tls/src/lib.rs @@ -0,0 +1,649 @@ +//! # Wasmtime's [wasi-tls] (Transport Layer Security) Implementation +//! +//! This crate provides the Wasmtime host implementation for the [wasi-tls] API. +//! The [wasi-tls] world allows WebAssembly modules to perform SSL/TLS operations, +//! such as establishing secure connections to servers. TLS often relies on other wasi networking systems +//! to provide the stream so it will be common to enable the [wasi:cli] world as well with the networking features enabled. +//! +//! # An example of how to configure [wasi-tls] is the following: +//! +//! ```rust +//! use wasmtime_wasi::{IoView, WasiCtx, WasiCtxBuilder, WasiView}; +//! use wasmtime::{ +//! component::{Linker, ResourceTable}, +//! Store, Engine, Result, Config +//! }; +//! use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; +//! +//! struct Ctx { +//! table: ResourceTable, +//! wasi_ctx: WasiCtx, +//! } +//! +//! impl IoView for Ctx { +//! fn table(&mut self) -> &mut ResourceTable { +//! &mut self.table +//! } +//! } +//! +//! impl WasiView for Ctx { +//! fn ctx(&mut self) -> &mut WasiCtx { +//! &mut self.wasi_ctx +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() -> Result<()> { +//! let ctx = Ctx { +//! table: ResourceTable::new(), +//! wasi_ctx: WasiCtxBuilder::new() +//! .inherit_stderr() +//! .inherit_network() +//! .allow_ip_name_lookup(true) +//! .build(), +//! }; +//! +//! let mut config = Config::new(); +//! config.async_support(true); +//! let engine = Engine::new(&config)?; +//! +//! // Set up wasi-cli +//! let mut store = Store::new(&engine, ctx); +//! let mut linker = Linker::new(&engine); +//! wasmtime_wasi::add_to_linker_async(&mut linker)?; +//! +//! // Add wasi-tls types and turn on the feature in linker +//! let mut opts = LinkOptions::default(); +//! opts.tls(true); +//! wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { +//! WasiTlsCtx::new(&mut h.table) +//! })?; +//! +//! // ... use `linker` to instantiate within `store` ... +//! Ok(()) +//! } +//! +//! ``` +//! [wasi-tls]: https://github.com/WebAssembly/wasi-tls +//! [wasi:cli]: https://docs.rs/wasmtime-wasi/latest + +#![deny(missing_docs)] +#![doc(test(attr(deny(warnings))))] +#![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] + +use anyhow::{Context, Result}; +use bytes::Bytes; +use rustls::pki_types::ServerName; +use std::io; +use std::sync::Arc; +use std::task::{ready, Poll}; +use std::{future::Future, mem, pin::Pin, sync::LazyLock}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::sync::Mutex; +use tokio_rustls::client::TlsStream; +use wasmtime::component::{Resource, ResourceTable}; +use wasmtime_wasi::pipe::AsyncReadStream; +use wasmtime_wasi::runtime::AbortOnDropJoinHandle; +use wasmtime_wasi::OutputStream; +use wasmtime_wasi::{ + async_trait, + bindings::io::{ + poll::Pollable as HostPollable, + streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream}, + }, + Pollable, StreamError, +}; + +mod gen_ { + wasmtime::component::bindgen!({ + path: "wit/", + world: "imports", + with: { + "wasi:io": wasmtime_wasi::bindings::io, + "wasi:tls/types/client-connection": super::ClientConnection, + "wasi:tls/types/client-handshake": super::ClientHandShake, + "wasi:tls/types/future-client-streams": super::FutureClientStreams, + }, + trappable_imports: true, + async: { + only_imports: [], + } + }); +} +pub use gen_::wasi::tls::types::LinkOptions; +use gen_::wasi::tls::{self as generated}; + +fn default_client_config() -> Arc { + static CONFIG: LazyLock> = LazyLock::new(|| { + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + Arc::new(config) + }); + Arc::clone(&CONFIG) +} + +/// Wasi TLS context needed fro internal `wasi-tls`` state +pub struct WasiTlsCtx<'a> { + table: &'a mut ResourceTable, +} + +impl<'a> WasiTlsCtx<'a> { + /// Create a new Wasi TLS context + pub fn new(table: &'a mut ResourceTable) -> Self { + Self { table } + } +} + +impl<'a> generated::types::Host for WasiTlsCtx<'a> {} + +/// Add the `wasi-tls` world's types to a [`wasmtime::component::Linker`]. +pub fn add_to_linker( + l: &mut wasmtime::component::Linker, + opts: &mut LinkOptions, + f: impl Fn(&mut T) -> WasiTlsCtx + Send + Sync + Copy + 'static, +) -> Result<()> { + generated::types::add_to_linker_get_host(l, &opts, f)?; + Ok(()) +} +/// Represents the ClientHandshake which will be used to configure the handshake +pub struct ClientHandShake { + server_name: String, + streams: WasiStreams, +} + +impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { + fn new( + &mut self, + server_name: String, + input: Resource, + output: Resource, + ) -> wasmtime::Result> { + let input = self.table.delete(input)?; + let output = self.table.delete(output)?; + Ok(self.table.push(ClientHandShake { + server_name, + streams: WasiStreams { + input: StreamState::Ready(input), + output: StreamState::Ready(output), + }, + })?) + } + + fn finish( + &mut self, + this: wasmtime::component::Resource, + ) -> wasmtime::Result> { + let handshake = self.table.delete(this)?; + let server_name = handshake.server_name; + let streams = handshake.streams; + let domain = ServerName::try_from(server_name)?; + + Ok(self + .table + .push(FutureStreams(StreamState::Pending(Box::pin(async move { + let connector = tokio_rustls::TlsConnector::from(default_client_config()); + connector + .connect(domain, streams) + .await + .with_context(|| "connection failed") + }))))?) + } + + fn drop( + &mut self, + this: wasmtime::component::Resource, + ) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Future streams provides the tls streams after the handshake is completed +pub struct FutureStreams(StreamState>); + +/// Library specific version of TLS connection after the handshake is completed. +/// This alias allows it to use with wit-bindgen component generator which won't take generic types +pub type FutureClientStreams = FutureStreams>; + +#[async_trait] +impl Pollable for FutureStreams { + async fn ready(&mut self) { + match &mut self.0 { + StreamState::Ready(_) | StreamState::Closed => return, + StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await), + } + } +} + +impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> { + fn subscribe( + &mut self, + this: wasmtime::component::Resource, + ) -> wasmtime::Result> { + wasmtime_wasi::subscribe(self.table, this) + } + + fn get( + &mut self, + this: wasmtime::component::Resource, + ) -> wasmtime::Result< + Option< + Result< + Result< + ( + Resource, + Resource, + Resource, + ), + (), + >, + (), + >, + >, + > { + { + let this = self.table.get(&this)?; + match &this.0 { + StreamState::Pending(_) => return Ok(None), + StreamState::Ready(Ok(_)) => (), + StreamState::Ready(Err(_)) => { + return Ok(Some(Ok(Err(())))); + } + StreamState::Closed => return Ok(Some(Err(()))), + } + } + + let StreamState::Ready(Ok(tls_stream)) = + mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed) + else { + unreachable!() + }; + + let (rx, tx) = tokio::io::split(tls_stream); + let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx)); + let client = ClientConnection { + writer: write_stream.clone(), + }; + + let input = Box::new(AsyncReadStream::new(rx)) as BoxInputStream; + let output = Box::new(write_stream) as BoxOutputStream; + + let client = self.table.push(client)?; + let input = self.table.push_child(input, &client)?; + let output = self.table.push_child(output, &client)?; + + Ok(Some(Ok(Ok((client, input, output))))) + } + + fn drop( + &mut self, + this: wasmtime::component::Resource, + ) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +/// Represents the client connection and used to shut down the tls stream +pub struct ClientConnection { + writer: AsyncTlsWriteStream, +} + +impl<'a> generated::types::HostClientConnection for WasiTlsCtx<'a> { + fn close_output(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.get_mut(&this)?.writer.close() + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table.delete(this)?; + Ok(()) + } +} + +enum StreamState { + Ready(T), + Pending(Pin + Send>>), + Closed, +} + +/// Wrapper around Input and Output wasi IO Stream that provides Async Read/Write +pub struct WasiStreams { + input: StreamState, + output: StreamState, +} + +impl AsyncWrite for WasiStreams { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + loop { + match &mut self.as_mut().output { + StreamState::Closed => unreachable!(), + StreamState::Pending(future) => { + let value = ready!(future.as_mut().poll(cx)); + self.as_mut().output = StreamState::Ready(value); + } + StreamState::Ready(output) => { + match output.check_write() { + Ok(0) => { + let StreamState::Ready(mut output) = + mem::replace(&mut self.as_mut().output, StreamState::Closed) + else { + unreachable!() + }; + self.as_mut().output = StreamState::Pending(Box::pin(async move { + output.ready().await; + output + })); + } + Ok(count) => { + let count = count.min(buf.len()); + return match output.write(Bytes::copy_from_slice(&buf[..count])) { + Ok(()) => Poll::Ready(Ok(count)), + Err(StreamError::Closed) => Poll::Ready(Ok(0)), + Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { + Poll::Ready(Err(std::io::Error::other(e))) + } + }; + } + Err(StreamError::Closed) => return Poll::Ready(Ok(0)), + Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { + return Poll::Ready(Err(std::io::Error::other(e))) + } + }; + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_write(cx, &[]).map(|v| v.map(drop)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_flush(cx) + } +} + +impl AsyncRead for WasiStreams { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + loop { + let stream = match &mut self.input { + StreamState::Ready(stream) => stream, + StreamState::Pending(fut) => { + let stream = ready!(fut.as_mut().poll(cx)); + self.input = StreamState::Ready(stream); + if let StreamState::Ready(stream) = &mut self.input { + stream + } else { + unreachable!() + } + } + StreamState::Closed => { + return Poll::Ready(Ok(())); + } + }; + match stream.read(buf.remaining()) { + Ok(bytes) if bytes.is_empty() => { + let StreamState::Ready(mut stream) = + std::mem::replace(&mut self.input, StreamState::Closed) + else { + unreachable!() + }; + + self.input = StreamState::Pending(Box::pin(async move { + stream.ready().await; + stream + })); + } + Ok(bytes) => { + buf.put_slice(&bytes); + + return Poll::Ready(Ok(())); + } + Err(StreamError::Closed) => { + self.input = StreamState::Closed; + return Poll::Ready(Ok(())); + } + Err(e) => { + self.input = StreamState::Closed; + return Poll::Ready(Err(std::io::Error::other(e))); + } + } + } + } +} + +type TlsWriteHalf = tokio::io::WriteHalf>; + +struct TlsWriter { + state: WriteState, +} + +enum WriteState { + Ready(TlsWriteHalf), + Writing(AbortOnDropJoinHandle>), + Closing(AbortOnDropJoinHandle>), + Closed, + Error(io::Error), +} +const READY_SIZE: usize = 1024 * 1024 * 1024; + +impl TlsWriter { + fn new(stream: TlsWriteHalf) -> Self { + Self { + state: WriteState::Ready(stream), + } + } + + fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> { + let WriteState::Ready(_) = self.state else { + return Err(StreamError::Trap(anyhow::anyhow!( + "unpermitted: must call check_write first" + ))); + }; + + if bytes.is_empty() { + return Ok(()); + } + + let WriteState::Ready(mut stream) = std::mem::replace(&mut self.state, WriteState::Closed) + else { + unreachable!() + }; + + self.state = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move { + while !bytes.is_empty() { + match stream.write(&bytes).await { + Ok(n) => { + let _ = bytes.split_to(n); + } + Err(e) => return Err(e.into()), + } + } + + Ok(stream) + })); + + Ok(()) + } + + fn flush(&mut self) -> Result<(), StreamError> { + // `flush` is a no-op here, as we're not managing any internal buffer. + match self.state { + WriteState::Ready(_) + | WriteState::Writing(_) + | WriteState::Closing(_) + | WriteState::Error(_) => Ok(()), + WriteState::Closed => Err(StreamError::Closed), + } + } + + fn check_write(&mut self) -> Result { + match &mut self.state { + WriteState::Ready(_) => Ok(READY_SIZE), + WriteState::Writing(_) => Ok(0), + WriteState::Closing(_) => Ok(0), + WriteState::Closed => Err(StreamError::Closed), + WriteState::Error(_) => { + let WriteState::Error(e) = std::mem::replace(&mut self.state, WriteState::Closed) + else { + unreachable!() + }; + + Err(StreamError::LastOperationFailed(e.into())) + } + } + } + + fn close(&mut self) { + match std::mem::replace(&mut self.state, WriteState::Closed) { + // No write in progress, immediately shut down: + WriteState::Ready(mut stream) => { + self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + stream.shutdown().await + })); + } + + // Schedule the shutdown after the current write has finished: + WriteState::Writing(write) => { + self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move { + let mut stream = write.await?; + stream.shutdown().await + })); + } + + WriteState::Closing(t) => { + self.state = WriteState::Closing(t); + } + WriteState::Closed | WriteState::Error(_) => {} + } + } + + async fn cancel(&mut self) { + match std::mem::replace(&mut self.state, WriteState::Closed) { + WriteState::Writing(task) => _ = task.cancel().await, + WriteState::Closing(task) => _ = task.cancel().await, + _ => {} + } + } + + async fn ready(&mut self) { + match &mut self.state { + WriteState::Writing(task) => { + self.state = match task.await { + Ok(s) => WriteState::Ready(s), + Err(e) => WriteState::Error(e), + } + } + WriteState::Closing(task) => { + self.state = match task.await { + Ok(()) => WriteState::Closed, + Err(e) => WriteState::Error(e), + } + } + _ => {} + } + } +} + +#[derive(Clone)] +struct AsyncTlsWriteStream(Arc>); + +impl AsyncTlsWriteStream { + fn new(writer: TlsWriter) -> Self { + AsyncTlsWriteStream(Arc::new(Mutex::new(writer))) + } + + fn close(&mut self) -> wasmtime::Result<()> { + try_lock_for_stream(&self.0)?.close(); + Ok(()) + } +} + +#[async_trait] +impl OutputStream for AsyncTlsWriteStream { + fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> { + try_lock_for_stream(&self.0)?.write(bytes) + } + + fn flush(&mut self) -> Result<(), StreamError> { + try_lock_for_stream(&self.0)?.flush() + } + + fn check_write(&mut self) -> Result { + try_lock_for_stream(&self.0)?.check_write() + } + + async fn cancel(&mut self) { + self.0.lock().await.cancel().await + } +} + +#[async_trait] +impl Pollable for AsyncTlsWriteStream { + async fn ready(&mut self) { + self.0.lock().await.ready().await + } +} + +fn try_lock_for_stream( + mutex: &Mutex, +) -> Result, StreamError> { + mutex + .try_lock() + .map_err(|_| StreamError::trap("concurrent access to resource not supported")) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::oneshot; + + #[tokio::test] + async fn test_future_client_streams_ready_can_be_canceled() { + let (tx1, rx1) = oneshot::channel::<()>(); + + let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move { + rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled")) + }))); + + let mut fut = future_streams.ready(); + + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + assert!(fut.as_mut().poll(&mut cx).is_pending()); + + //cancel the readiness check + drop(fut); + + match future_streams.0 { + StreamState::Closed => panic!("First future should be in Pending/ready state"), + _ => (), + } + + // make it ready and wait for it to progress + tx1.send(()).unwrap(); + future_streams.ready().await; + + match future_streams.0 { + StreamState::Ready(Ok(())) => (), + _ => panic!("First future should be in Ready(Err) state"), + } + } +} diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/main.rs new file mode 100644 index 000000000000..6919a73ead4e --- /dev/null +++ b/crates/wasi-tls/tests/main.rs @@ -0,0 +1,72 @@ +use anyhow::{anyhow, Result}; +use test_programs_artifacts::{foreach_tls, TLS_SAMPLE_APPLICATION_COMPONENT}; +use wasmtime::{ + component::{Component, Linker, ResourceTable}, + Store, +}; +use wasmtime_wasi::{bindings::Command, IoView, WasiCtx, WasiCtxBuilder, WasiView}; +use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx}; + +struct Ctx { + table: ResourceTable, + wasi_ctx: WasiCtx, +} + +impl IoView for Ctx { + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } +} +impl WasiView for Ctx { + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.wasi_ctx + } +} + +async fn run_wasi(path: &str, ctx: Ctx) -> Result<()> { + let engine = test_programs_artifacts::engine(|config| { + config.async_support(true); + }); + let mut store = Store::new(&engine, ctx); + let component = Component::from_file(&engine, path)?; + + let mut linker = Linker::new(&engine); + wasmtime_wasi::add_to_linker_async(&mut linker)?; + let mut opts = LinkOptions::default(); + opts.tls(true); + wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| { + WasiTlsCtx::new(&mut h.table) + })?; + + let command = Command::instantiate_async(&mut store, &component, &linker).await?; + command + .wasi_cli_run() + .call_run(&mut store) + .await? + .map_err(|()| anyhow!("command returned with failing exit status")) +} + +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; + }; +} + +foreach_tls!(assert_test_exists); + +#[tokio::test(flavor = "multi_thread")] +async fn tls_sample_application() -> Result<()> { + run_wasi( + TLS_SAMPLE_APPLICATION_COMPONENT, + Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + }, + ) + .await +} diff --git a/crates/wasi-tls/wit/deps/io/error.wit b/crates/wasi-tls/wit/deps/io/error.wit new file mode 100644 index 000000000000..5bc13f99e3b1 --- /dev/null +++ b/crates/wasi-tls/wit/deps/io/error.wit @@ -0,0 +1,34 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +interface error { + /// A resource which represents some error information. + /// + /// The only method provided by this resource is `to-debug-string`, + /// which provides some human-readable information about the error. + /// + /// In the `wasi:io` package, this resource is returned through the + /// `wasi:io/streams/stream-error` type. + /// + /// To provide more specific error information, other interfaces may + /// offer functions to "downcast" this error into more specific types. For example, + /// errors returned from streams derived from filesystem types can be described using + /// the filesystem's own error-code type. This is done using the function + /// `wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` + /// parameter and returns an `option`. + /// + /// The set of functions which can "downcast" an `error` into a more + /// concrete type is open. + @since(version = 0.2.0) + resource error { + /// Returns a string that is suitable to assist humans in debugging + /// this error. + /// + /// WARNING: The returned string should not be consumed mechanically! + /// It may change across platforms, hosts, or other implementation + /// details. Parsing this string is a major platform-compatibility + /// hazard. + @since(version = 0.2.0) + to-debug-string: func() -> string; + } +} \ No newline at end of file diff --git a/crates/wasi-tls/wit/deps/io/poll.wit b/crates/wasi-tls/wit/deps/io/poll.wit new file mode 100644 index 000000000000..9eb452566d83 --- /dev/null +++ b/crates/wasi-tls/wit/deps/io/poll.wit @@ -0,0 +1,47 @@ +package wasi:io@0.2.3; + +/// A poll API intended to let users wait for I/O events on multiple handles +/// at once. +@since(version = 0.2.0) +interface poll { + /// `pollable` represents a single I/O event which may be ready, or not. + @since(version = 0.2.0) + resource pollable { + + /// Return the readiness of a pollable. This function never blocks. + /// + /// Returns `true` when the pollable is ready, and `false` otherwise. + @since(version = 0.2.0) + ready: func() -> bool; + + /// `block` returns immediately if the pollable is ready, and otherwise + /// blocks until ready. + /// + /// This function is equivalent to calling `poll.poll` on a list + /// containing only this pollable. + @since(version = 0.2.0) + block: func(); + } + + /// Poll for completion on a set of pollables. + /// + /// This function takes a list of pollables, which identify I/O sources of + /// interest, and waits until one or more of the events is ready for I/O. + /// + /// The result `list` contains one or more indices of handles in the + /// argument list that is ready for I/O. + /// + /// This function traps if either: + /// - the list is empty, or: + /// - the list contains more elements than can be indexed with a `u32` value. + /// + /// A timeout can be implemented by adding a pollable from the + /// wasi-clocks API to the list. + /// + /// This function does not return a `result`; polling in itself does not + /// do any I/O so it doesn't fail. If any of the I/O sources identified by + /// the pollables has an error, it is indicated by marking the source as + /// being ready for I/O. + @since(version = 0.2.0) + poll: func(in: list>) -> list; +} \ No newline at end of file diff --git a/crates/wasi-tls/wit/deps/io/streams.wit b/crates/wasi-tls/wit/deps/io/streams.wit new file mode 100644 index 000000000000..68fd80c5a446 --- /dev/null +++ b/crates/wasi-tls/wit/deps/io/streams.wit @@ -0,0 +1,290 @@ +package wasi:io@0.2.3; + +/// WASI I/O is an I/O abstraction API which is currently focused on providing +/// stream types. +/// +/// In the future, the component model is expected to add built-in stream types; +/// when it does, they are expected to subsume this API. +@since(version = 0.2.0) +interface streams { + @since(version = 0.2.0) + use error.{error}; + @since(version = 0.2.0) + use poll.{pollable}; + + /// An error for input-stream and output-stream operations. + @since(version = 0.2.0) + variant stream-error { + /// The last operation (a write or flush) failed before completion. + /// + /// More information is available in the `error` payload. + /// + /// After this, the stream will be closed. All future operations return + /// `stream-error::closed`. + last-operation-failed(error), + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed + } + + /// An input bytestream. + /// + /// `input-stream`s are *non-blocking* to the extent practical on underlying + /// platforms. I/O operations always return promptly; if fewer bytes are + /// promptly available than requested, they return the number of bytes promptly + /// available, which could even be zero. To wait for data to be available, + /// use the `subscribe` function to obtain a `pollable` which can be polled + /// for using `wasi:io/poll`. + @since(version = 0.2.0) + resource input-stream { + /// Perform a non-blocking read from the stream. + /// + /// When the source of a `read` is binary data, the bytes from the source + /// are returned verbatim. When the source of a `read` is known to the + /// implementation to be text, bytes containing the UTF-8 encoding of the + /// text are returned. + /// + /// This function returns a list of bytes containing the read data, + /// when successful. The returned list will contain up to `len` bytes; + /// it may return fewer than requested, but not more. The list is + /// empty when no bytes are available for reading at this time. The + /// pollable given by `subscribe` will be ready when more bytes are + /// available. + /// + /// This function fails with a `stream-error` when the operation + /// encounters an error, giving `last-operation-failed`, or when the + /// stream is closed, giving `closed`. + /// + /// When the caller gives a `len` of 0, it represents a request to + /// read 0 bytes. If the stream is still open, this call should + /// succeed and return an empty list, or otherwise fail with `closed`. + /// + /// The `len` parameter is a `u64`, which could represent a list of u8 which + /// is not possible to allocate in wasm32, or not desirable to allocate as + /// as a return value by the callee. The callee may return a list of bytes + /// less than `len` in size while more bytes are available for reading. + @since(version = 0.2.0) + read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Read bytes from a stream, after blocking until at least one byte can + /// be read. Except for blocking, behavior is identical to `read`. + @since(version = 0.2.0) + blocking-read: func( + /// The maximum number of bytes to read + len: u64 + ) -> result, stream-error>; + + /// Skip bytes from a stream. Returns number of bytes skipped. + /// + /// Behaves identical to `read`, except instead of returning a list + /// of bytes, returns the number of bytes consumed from the stream. + @since(version = 0.2.0) + skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Skip bytes from a stream, after blocking until at least one byte + /// can be skipped. Except for blocking behavior, identical to `skip`. + @since(version = 0.2.0) + blocking-skip: func( + /// The maximum number of bytes to skip. + len: u64, + ) -> result; + + /// Create a `pollable` which will resolve once either the specified stream + /// has bytes available to read or the other end of the stream has been + /// closed. + /// The created `pollable` is a child resource of the `input-stream`. + /// Implementations may trap if the `input-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + } + + + /// An output bytestream. + /// + /// `output-stream`s are *non-blocking* to the extent practical on + /// underlying platforms. Except where specified otherwise, I/O operations also + /// always return promptly, after the number of bytes that can be written + /// promptly, which could even be zero. To wait for the stream to be ready to + /// accept data, the `subscribe` function to obtain a `pollable` which can be + /// polled for using `wasi:io/poll`. + /// + /// Dropping an `output-stream` while there's still an active write in + /// progress may result in the data being lost. Before dropping the stream, + /// be sure to fully flush your writes. + @since(version = 0.2.0) + resource output-stream { + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe` pollable will + /// become ready when this function will report at least 1 byte, or an + /// error. + @since(version = 0.2.0) + check-write: func() -> result; + + /// Perform a write. This function never blocks. + /// + /// When the destination of a `write` is binary data, the bytes from + /// `contents` are written verbatim. When the destination of a `write` is + /// known to the implementation to be text, the bytes of `contents` are + /// transcoded from UTF-8 into the encoding of the destination and then + /// written. + /// + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. + /// + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. + @since(version = 0.2.0) + write: func( + contents: list + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write`, and `flush`, and is implemented with the + /// following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// this.write(chunk ); // eliding error handling + /// contents = rest; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-and-flush: func( + contents: list + ) -> result<_, stream-error>; + + /// Request to flush buffered output. This function never blocks. + /// + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe` pollable will become ready when the + /// flush has completed and the stream can accept more writes. + @since(version = 0.2.0) + flush: func() -> result<_, stream-error>; + + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + @since(version = 0.2.0) + blocking-flush: func() -> result<_, stream-error>; + + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occurred. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + + /// Write zeroes to a stream. + /// + /// This should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + @since(version = 0.2.0) + write-zeroes: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Perform a write of up to 4096 zeroes, and then flush the stream. + /// Block until all of these operations are complete, or an error + /// occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write-zeroes`, and `flush`, and is implemented with + /// the following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while num_zeroes != 0 { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, num_zeroes); + /// this.write-zeroes(len); // eliding error handling + /// num_zeroes -= len; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-zeroes-and-flush: func( + /// The number of zero-bytes to write + len: u64 + ) -> result<_, stream-error>; + + /// Read from one stream and write to another. + /// + /// The behavior of splice is equivalent to: + /// 1. calling `check-write` on the `output-stream` + /// 2. calling `read` on the `input-stream` with the smaller of the + /// `check-write` permitted length and the `len` provided to `splice` + /// 3. calling `write` on the `output-stream` with that read data. + /// + /// Any error reported by the call to `check-write`, `read`, or + /// `write` ends the splice and reports that error. + /// + /// This function returns the number of bytes transferred; it may be less + /// than `len`. + @since(version = 0.2.0) + splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + + /// Read from one stream and write to another, with blocking. + /// + /// This is similar to `splice`, except that it blocks until the + /// `output-stream` is ready for writing, and the `input-stream` + /// is ready for reading, before performing the `splice`. + @since(version = 0.2.0) + blocking-splice: func( + /// The stream to read from + src: borrow, + /// The number of bytes to splice + len: u64, + ) -> result; + } +} \ No newline at end of file diff --git a/crates/wasi-tls/wit/deps/io/world.wit b/crates/wasi-tls/wit/deps/io/world.wit new file mode 100644 index 000000000000..135dbb5baf6d --- /dev/null +++ b/crates/wasi-tls/wit/deps/io/world.wit @@ -0,0 +1,10 @@ +package wasi:io@0.2.3; + +@since(version = 0.2.0) +world imports { + @since(version = 0.2.0) + import streams; + + @since(version = 0.2.0) + import poll; +} \ No newline at end of file diff --git a/crates/wasi-tls/wit/world.wit b/crates/wasi-tls/wit/world.wit new file mode 100644 index 000000000000..411603ec1637 --- /dev/null +++ b/crates/wasi-tls/wit/world.wit @@ -0,0 +1,39 @@ +package wasi:tls@0.2.0-draft; + +@unstable(feature = tls) +world imports { + @unstable(feature = tls) + import types; +} + +@unstable(feature = tls) +interface types { + @unstable(feature = tls) + use wasi:io/streams@0.2.3.{input-stream, output-stream}; + @unstable(feature = tls) + use wasi:io/poll@0.2.3.{pollable}; + + @unstable(feature = tls) + resource client-handshake { + @unstable(feature = tls) + constructor(server-name: string, input: input-stream, output: output-stream); + + @unstable(feature = tls) + finish: static func(this: client-handshake) -> future-client-streams; + } + + @unstable(feature = tls) + resource client-connection { + @unstable(feature = tls) + close-output: func(); + } + + @unstable(feature = tls) + resource future-client-streams { + @unstable(feature = tls) + subscribe: func() -> pollable; + + @unstable(feature = tls) + get: func() -> option>>>; + } +} \ No newline at end of file diff --git a/crates/wasi/src/runtime.rs b/crates/wasi/src/runtime.rs index 762e4045e32b..2c23bd7885a0 100644 --- a/crates/wasi/src/runtime.rs +++ b/crates/wasi/src/runtime.rs @@ -42,7 +42,7 @@ pub struct AbortOnDropJoinHandle(tokio::task::JoinHandle); impl AbortOnDropJoinHandle { /// Abort the task and wait for it to finish. Optionally returns the result /// of the task if it ran to completion prior to being aborted. - pub(crate) async fn cancel(mut self) -> Option { + pub async fn cancel(mut self) -> Option { self.0.abort(); match (&mut self.0).await { diff --git a/scripts/publish.rs b/scripts/publish.rs index 81557649f93a..402446a61be5 100644 --- a/scripts/publish.rs +++ b/scripts/publish.rs @@ -75,6 +75,7 @@ const CRATES_TO_PUBLISH: &[&str] = &[ "wasmtime-wasi-config", "wasmtime-wasi-keyvalue", "wasmtime-wasi-threads", + "wasmtime-wasi-tls", "wasmtime-wast", "wasmtime-c-api-macros", "wasmtime-c-api-impl", @@ -93,6 +94,7 @@ const PUBLIC_CRATES: &[&str] = &[ "wasmtime", "wasmtime-wasi-io", "wasmtime-wasi", + "wasmtime-wasi-tls", "wasmtime-wasi-http", "wasmtime-wasi-nn", "wasmtime-wasi-config", diff --git a/src/commands/run.rs b/src/commands/run.rs index 7baa4a37f812..690edb5ad855 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -32,6 +32,9 @@ use wasmtime_wasi_http::{ #[cfg(feature = "wasi-keyvalue")] use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder}; +#[cfg(feature = "wasi-tls")] +use wasmtime_wasi_tls::WasiTlsCtx; + fn parse_preloads(s: &str) -> Result<(String, PathBuf)> { let parts: Vec<&str> = s.splitn(2, '=').collect(); if parts.len() != 2 { @@ -825,6 +828,32 @@ impl RunCommand { } } + if self.run.common.wasi.tls == Some(true) { + #[cfg(all(not(all(feature = "wasi-tls", feature = "component-model"))))] + { + bail!("Cannot enable wasi-tls when the binary is not compiled with this feature."); + } + #[cfg(all(feature = "wasi-tls", feature = "component-model",))] + { + match linker { + CliLinker::Core(_) => { + bail!("Cannot enable wasi-tls for core wasm modules"); + } + CliLinker::Component(linker) => { + let mut opts = wasmtime_wasi_tls::LinkOptions::default(); + opts.tls(true); + wasmtime_wasi_tls::add_to_linker(linker, &mut opts, |h| { + let preview2_ctx = + h.preview2_ctx.as_mut().expect("wasip2 is not configured"); + let preview2_ctx = + Arc::get_mut(preview2_ctx).unwrap().get_mut().unwrap(); + WasiTlsCtx::new(preview2_ctx.table()) + })?; + } + } + } + } + Ok(()) }