diff --git a/Cargo.lock b/Cargo.lock index c3c22b601..c01be1b87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1202,6 +1202,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" +[[package]] +name = "leb128-tokio" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47a57548d0fcd8bf0c9c601af2a3fdd05dac4829b82ae872e893df31f523e8ae" +dependencies = [ + "tokio", + "tokio-util", +] + [[package]] name = "libc" version = "0.2.154" @@ -2522,6 +2532,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8-tokio" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f6ede684629efe08c9538958d0f80a4a7475dd49f41fb102b799882098e55a" +dependencies = [ + "tokio", + "tokio-util", +] + [[package]] name = "utf8parse" version = "0.2.1" @@ -2678,12 +2698,14 @@ dependencies = [ [[package]] name = "wasm-tokio" -version = "0.1.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fccf07b746992fc4b8922dfe0feff98ff3aed34f4627cb96ccc3a279381dec1" +checksum = "15db3d053fd102fa738502799dd19d0c3e7645b290004dabbdc2f86830d1344c" dependencies = [ + "leb128-tokio", "tokio", "tokio-util", + "utf8-tokio", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b55bb0457..ea7c84995 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ tower = { version = "0.4", default-features = false } tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.3", default-features = false } url = { version = "2", default-features = false } -wasm-tokio = { version = "0.1", default-features = false } +wasm-tokio = { version = "0.4", default-features = false } wasmcloud-component-adapters = { version = "0.9", default-features = false } wasmparser = { version = "0.208", default-features = false } wasmtime = { version = "21", default-features = false } diff --git a/crates/runtime-wasmtime/src/lib.rs b/crates/runtime-wasmtime/src/lib.rs index 29084729f..b64094fe1 100644 --- a/crates/runtime-wasmtime/src/lib.rs +++ b/crates/runtime-wasmtime/src/lib.rs @@ -1,96 +1,38 @@ #![allow(clippy::type_complexity)] // TODO: https://github.com/wrpc/wrpc/issues/2 -use core::fmt::{self, Display}; use core::future::Future; use core::iter::zip; use core::ops::{BitOrAssign, Shl}; use core::pin::{pin, Pin}; -use core::task::{Context, Poll}; use std::collections::HashSet; -use std::error::Error; use std::sync::Arc; use anyhow::{anyhow, bail, Context as _}; -use bytes::{BufMut as _, Bytes, BytesMut}; +use bytes::{BufMut as _, BytesMut}; use futures::stream::FuturesUnordered; -use futures::{Stream, TryStreamExt as _}; +use futures::TryStreamExt as _; use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; use tokio::try_join; use tokio_util::codec::Encoder; use tracing::{error, trace}; use tracing::{instrument, warn}; -use wasm_tokio::cm::{AsyncReadValue as _, CharEncoder}; -use wasm_tokio::{AsyncReadCore as _, CoreStringEncoder, Leb128Encoder}; +use wasm_tokio::cm::AsyncReadValue as _; +use wasm_tokio::{ + AsyncReadCore as _, AsyncReadLeb128 as _, AsyncReadUtf8 as _, CoreStringEncoder, Leb128Encoder, + Utf8Encoder, +}; use wasmtime::component::types::{self, Case, Field}; use wasmtime::component::{Linker, ResourceType, Type, Val}; use wasmtime::{AsContextMut, StoreContextMut}; use wasmtime_wasi::pipe::AsyncReadStream; -use wasmtime_wasi::{FileInputStream, HostInputStream, InputStream, StreamError, WasiView}; +use wasmtime_wasi::{InputStream, StreamError, WasiView}; use wit_parser::FunctionKind; use wrpc_introspect::rpc_func_name; use wrpc_transport::{Index as _, Invocation, Invoke, Session}; pub struct RemoteResource(pub String); -pub struct OutgoingHostInputStream(Box); - -#[derive(Debug)] -pub enum OutgoingStreamError { - Failed(anyhow::Error), - Trap(anyhow::Error), -} - -impl Display for OutgoingStreamError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Failed(err) => write!(f, "last operation failed: {err:#}"), - Self::Trap(err) => write!(f, "trap: {err:#}"), - } - } -} - -impl Error for OutgoingStreamError {} - -impl Stream for OutgoingHostInputStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match pin!(self.0.ready()).poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(()) => {} - } - match self.0.read(8096) { - Ok(buf) => Poll::Ready(Some(Ok(buf))), - Err(StreamError::LastOperationFailed(err)) => { - Poll::Ready(Some(Err(OutgoingStreamError::Failed(err)))) - } - Err(StreamError::Trap(err)) => Poll::Ready(Some(Err(OutgoingStreamError::Trap(err)))), - Err(StreamError::Closed) => Poll::Ready(None), - } - } -} - -pub struct OutgoingFileInputStream(FileInputStream); - -impl Stream for OutgoingFileInputStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match pin!(self.0.read(8096)).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(buf)) => Poll::Ready(Some(Ok(buf))), - Poll::Ready(Err(StreamError::LastOperationFailed(err))) => { - Poll::Ready(Some(Err(OutgoingStreamError::Failed(err)))) - } - Poll::Ready(Err(StreamError::Trap(err))) => { - Poll::Ready(Some(Err(OutgoingStreamError::Trap(err)))) - } - Poll::Ready(Err(StreamError::Closed)) => Poll::Ready(None), - } - } -} - pub struct ValEncoder<'a, T, W> { pub store: StoreContextMut<'a, T>, pub ty: &'a Type, @@ -228,7 +170,7 @@ where Ok(()) } (Val::Char(v), Type::Char) => { - CharEncoder.encode(*v, dst).context("failed to encode char") + Utf8Encoder.encode(*v, dst).context("failed to encode char") } (Val::String(v), Type::String) => CoreStringEncoder .encode(v.as_str(), dst) @@ -285,30 +227,30 @@ where let ty = match cases.len() { ..=0x0000_00ff => { let (discriminant, ty) = - find_variant_discriminant(0.., cases, discriminant)?; - dst.reserve(1 + usize::from(v.is_some())); - dst.put_u8(discriminant); + find_variant_discriminant(0u8.., cases, discriminant)?; + dst.reserve(2 + usize::from(v.is_some())); + Leb128Encoder.encode(discriminant, dst)?; ty } 0x0000_0100..=0x0000_ffff => { let (discriminant, ty) = - find_variant_discriminant(0.., cases, discriminant)?; - dst.reserve(2 + usize::from(v.is_some())); - dst.put_u16_le(discriminant); + find_variant_discriminant(0u16.., cases, discriminant)?; + dst.reserve(3 + usize::from(v.is_some())); + Leb128Encoder.encode(discriminant, dst)?; ty } 0x0001_0000..=0x00ff_ffff => { let (discriminant, ty) = - find_variant_discriminant(0.., cases, discriminant)?; - dst.reserve(3 + usize::from(v.is_some())); - dst.put_slice(&u32::to_le_bytes(discriminant)[..3]); + find_variant_discriminant(0u32.., cases, discriminant)?; + dst.reserve(4 + usize::from(v.is_some())); + Leb128Encoder.encode(discriminant, dst)?; ty } 0x0100_0000..=0xffff_ffff => { let (discriminant, ty) = - find_variant_discriminant(0.., cases, discriminant)?; - dst.reserve(4 + usize::from(v.is_some())); - dst.put_u32_le(discriminant); + find_variant_discriminant(0u32.., cases, discriminant)?; + dst.reserve(5 + usize::from(v.is_some())); + Leb128Encoder.encode(discriminant, dst)?; ty } 0x1_0000_0000.. => bail!("case count does not fit in u32"), @@ -328,24 +270,24 @@ where let names = ty.names(); match names.len() { ..=0x0000_00ff => { - let discriminant = find_enum_discriminant(0.., names, discriminant)?; - dst.reserve(1); - dst.put_u8(discriminant); + let discriminant = find_enum_discriminant(0u8.., names, discriminant)?; + dst.reserve(2); + Leb128Encoder.encode(discriminant, dst)? } 0x0000_0100..=0x0000_ffff => { - let discriminant = find_enum_discriminant(0.., names, discriminant)?; - dst.reserve(2); - dst.put_u16_le(discriminant); + let discriminant = find_enum_discriminant(0u16.., names, discriminant)?; + dst.reserve(3); + Leb128Encoder.encode(discriminant, dst)? } 0x0001_0000..=0x00ff_ffff => { - let discriminant = find_enum_discriminant(0.., names, discriminant)?; - dst.reserve(3); - dst.put_slice(&u32::to_le_bytes(discriminant)[..3]); + let discriminant = find_enum_discriminant(0u32.., names, discriminant)?; + dst.reserve(4); + Leb128Encoder.encode(discriminant, dst)? } 0x0100_0000..=0xffff_ffff => { - let discriminant = find_enum_discriminant(0.., names, discriminant)?; - dst.reserve(4); - dst.put_u32_le(discriminant); + let discriminant = find_enum_discriminant(0u32.., names, discriminant)?; + dst.reserve(5); + Leb128Encoder.encode(discriminant, dst)? } 0x1_0000_0000.. => bail!("name count does not fit in u32"), } @@ -566,31 +508,6 @@ where } } -async fn read_discriminant( - cases: usize, - r: &mut (impl AsyncRead + Unpin), -) -> std::io::Result { - match cases { - ..=0x0000_00ff => r.read_u8().await.map(Into::into), - 0x0000_0100..=0x0000_ffff => r.read_u16_le().await.map(Into::into), - 0x0001_0000..=0x00ff_ffff => { - let mut buf = 0usize.to_le_bytes(); - r.read_exact(&mut buf[..3]).await?; - Ok(usize::from_le_bytes(buf)) - } - 0x0100_0000..=0xffff_ffff => { - let discriminant = r.read_u32_le().await?; - discriminant - .try_into() - .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) - } - 0x1_0000_0000.. => Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "case count does not fit in u32".to_string(), - )), - } -} - #[inline] async fn read_flags(n: usize, r: &mut (impl AsyncRead + Unpin)) -> std::io::Result { let mut buf = 0u128.to_le_bytes(); @@ -667,7 +584,7 @@ where Ok(()) } Type::Char => { - let v = r.read_char().await?; + let v = r.read_char_utf8().await?; *val = Val::Char(v); Ok(()) } @@ -722,9 +639,11 @@ where Ok(()) } Type::Variant(ty) => { - let mut cases = ty.cases(); - let discriminant = read_discriminant(cases.len(), r).await?; - let Case { name, ty } = cases.nth(discriminant).ok_or_else(|| { + let discriminant = r.read_u32_leb128().await?; + let discriminant = discriminant + .try_into() + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; + let Case { name, ty } = ty.cases().nth(discriminant).ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("unknown variant discriminant `{discriminant}`"), @@ -741,9 +660,11 @@ where Ok(()) } Type::Enum(ty) => { - let mut names = ty.names(); - let discriminant = read_discriminant(names.len(), r).await?; - let name = names.nth(discriminant).ok_or_else(|| { + let discriminant = r.read_u32_leb128().await?; + let discriminant = discriminant + .try_into() + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; + let name = ty.names().nth(discriminant).ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("unknown enum discriminant `{discriminant}`"),