Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ tower = { version = "0.4", default-features = false }
tracing = { version = "0.1", default-features = false }
tracing-subscriber = { version = "0.3", default-features = false }
url = { version = "2", default-features = false }
wasm-tokio = { version = "0.1", default-features = false }
wasm-tokio = { version = "0.4", default-features = false }
wasmcloud-component-adapters = { version = "0.9", default-features = false }
wasmparser = { version = "0.208", default-features = false }
wasmtime = { version = "21", default-features = false }
Expand Down
167 changes: 44 additions & 123 deletions crates/runtime-wasmtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,96 +1,38 @@
#![allow(clippy::type_complexity)] // TODO: https://github.com/wrpc/wrpc/issues/2

use core::fmt::{self, Display};
use core::future::Future;
use core::iter::zip;
use core::ops::{BitOrAssign, Shl};
use core::pin::{pin, Pin};
use core::task::{Context, Poll};

use std::collections::HashSet;
use std::error::Error;
use std::sync::Arc;

use anyhow::{anyhow, bail, Context as _};
use bytes::{BufMut as _, Bytes, BytesMut};
use bytes::{BufMut as _, BytesMut};
use futures::stream::FuturesUnordered;
use futures::{Stream, TryStreamExt as _};
use futures::TryStreamExt as _;
use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
use tokio::try_join;
use tokio_util::codec::Encoder;
use tracing::{error, trace};
use tracing::{instrument, warn};
use wasm_tokio::cm::{AsyncReadValue as _, CharEncoder};
use wasm_tokio::{AsyncReadCore as _, CoreStringEncoder, Leb128Encoder};
use wasm_tokio::cm::AsyncReadValue as _;
use wasm_tokio::{
AsyncReadCore as _, AsyncReadLeb128 as _, AsyncReadUtf8 as _, CoreStringEncoder, Leb128Encoder,
Utf8Encoder,
};
use wasmtime::component::types::{self, Case, Field};
use wasmtime::component::{Linker, ResourceType, Type, Val};
use wasmtime::{AsContextMut, StoreContextMut};
use wasmtime_wasi::pipe::AsyncReadStream;
use wasmtime_wasi::{FileInputStream, HostInputStream, InputStream, StreamError, WasiView};
use wasmtime_wasi::{InputStream, StreamError, WasiView};
use wit_parser::FunctionKind;
use wrpc_introspect::rpc_func_name;
use wrpc_transport::{Index as _, Invocation, Invoke, Session};

pub struct RemoteResource(pub String);

pub struct OutgoingHostInputStream(Box<dyn HostInputStream>);

#[derive(Debug)]
pub enum OutgoingStreamError {
Failed(anyhow::Error),
Trap(anyhow::Error),
}

impl Display for OutgoingStreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Failed(err) => write!(f, "last operation failed: {err:#}"),
Self::Trap(err) => write!(f, "trap: {err:#}"),
}
}
}

impl Error for OutgoingStreamError {}

impl Stream for OutgoingHostInputStream {
type Item = Result<Bytes, OutgoingStreamError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match pin!(self.0.ready()).poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(()) => {}
}
match self.0.read(8096) {
Ok(buf) => Poll::Ready(Some(Ok(buf))),
Err(StreamError::LastOperationFailed(err)) => {
Poll::Ready(Some(Err(OutgoingStreamError::Failed(err))))
}
Err(StreamError::Trap(err)) => Poll::Ready(Some(Err(OutgoingStreamError::Trap(err)))),
Err(StreamError::Closed) => Poll::Ready(None),
}
}
}

pub struct OutgoingFileInputStream(FileInputStream);

impl Stream for OutgoingFileInputStream {
type Item = Result<Bytes, OutgoingStreamError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match pin!(self.0.read(8096)).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(buf)) => Poll::Ready(Some(Ok(buf))),
Poll::Ready(Err(StreamError::LastOperationFailed(err))) => {
Poll::Ready(Some(Err(OutgoingStreamError::Failed(err))))
}
Poll::Ready(Err(StreamError::Trap(err))) => {
Poll::Ready(Some(Err(OutgoingStreamError::Trap(err))))
}
Poll::Ready(Err(StreamError::Closed)) => Poll::Ready(None),
}
}
}

pub struct ValEncoder<'a, T, W> {
pub store: StoreContextMut<'a, T>,
pub ty: &'a Type,
Expand Down Expand Up @@ -228,7 +170,7 @@ where
Ok(())
}
(Val::Char(v), Type::Char) => {
CharEncoder.encode(*v, dst).context("failed to encode char")
Utf8Encoder.encode(*v, dst).context("failed to encode char")
}
(Val::String(v), Type::String) => CoreStringEncoder
.encode(v.as_str(), dst)
Expand Down Expand Up @@ -285,30 +227,30 @@ where
let ty = match cases.len() {
..=0x0000_00ff => {
let (discriminant, ty) =
find_variant_discriminant(0.., cases, discriminant)?;
dst.reserve(1 + usize::from(v.is_some()));
dst.put_u8(discriminant);
find_variant_discriminant(0u8.., cases, discriminant)?;
dst.reserve(2 + usize::from(v.is_some()));
Leb128Encoder.encode(discriminant, dst)?;
ty
}
0x0000_0100..=0x0000_ffff => {
let (discriminant, ty) =
find_variant_discriminant(0.., cases, discriminant)?;
dst.reserve(2 + usize::from(v.is_some()));
dst.put_u16_le(discriminant);
find_variant_discriminant(0u16.., cases, discriminant)?;
dst.reserve(3 + usize::from(v.is_some()));
Leb128Encoder.encode(discriminant, dst)?;
ty
}
0x0001_0000..=0x00ff_ffff => {
let (discriminant, ty) =
find_variant_discriminant(0.., cases, discriminant)?;
dst.reserve(3 + usize::from(v.is_some()));
dst.put_slice(&u32::to_le_bytes(discriminant)[..3]);
find_variant_discriminant(0u32.., cases, discriminant)?;
dst.reserve(4 + usize::from(v.is_some()));
Leb128Encoder.encode(discriminant, dst)?;
ty
}
0x0100_0000..=0xffff_ffff => {
let (discriminant, ty) =
find_variant_discriminant(0.., cases, discriminant)?;
dst.reserve(4 + usize::from(v.is_some()));
dst.put_u32_le(discriminant);
find_variant_discriminant(0u32.., cases, discriminant)?;
dst.reserve(5 + usize::from(v.is_some()));
Leb128Encoder.encode(discriminant, dst)?;
ty
}
0x1_0000_0000.. => bail!("case count does not fit in u32"),
Expand All @@ -328,24 +270,24 @@ where
let names = ty.names();
match names.len() {
..=0x0000_00ff => {
let discriminant = find_enum_discriminant(0.., names, discriminant)?;
dst.reserve(1);
dst.put_u8(discriminant);
let discriminant = find_enum_discriminant(0u8.., names, discriminant)?;
dst.reserve(2);
Leb128Encoder.encode(discriminant, dst)?
}
0x0000_0100..=0x0000_ffff => {
let discriminant = find_enum_discriminant(0.., names, discriminant)?;
dst.reserve(2);
dst.put_u16_le(discriminant);
let discriminant = find_enum_discriminant(0u16.., names, discriminant)?;
dst.reserve(3);
Leb128Encoder.encode(discriminant, dst)?
}
0x0001_0000..=0x00ff_ffff => {
let discriminant = find_enum_discriminant(0.., names, discriminant)?;
dst.reserve(3);
dst.put_slice(&u32::to_le_bytes(discriminant)[..3]);
let discriminant = find_enum_discriminant(0u32.., names, discriminant)?;
dst.reserve(4);
Leb128Encoder.encode(discriminant, dst)?
}
0x0100_0000..=0xffff_ffff => {
let discriminant = find_enum_discriminant(0.., names, discriminant)?;
dst.reserve(4);
dst.put_u32_le(discriminant);
let discriminant = find_enum_discriminant(0u32.., names, discriminant)?;
dst.reserve(5);
Leb128Encoder.encode(discriminant, dst)?
}
0x1_0000_0000.. => bail!("name count does not fit in u32"),
}
Expand Down Expand Up @@ -566,31 +508,6 @@ where
}
}

async fn read_discriminant(
cases: usize,
r: &mut (impl AsyncRead + Unpin),
) -> std::io::Result<usize> {
match cases {
..=0x0000_00ff => r.read_u8().await.map(Into::into),
0x0000_0100..=0x0000_ffff => r.read_u16_le().await.map(Into::into),
0x0001_0000..=0x00ff_ffff => {
let mut buf = 0usize.to_le_bytes();
r.read_exact(&mut buf[..3]).await?;
Ok(usize::from_le_bytes(buf))
}
0x0100_0000..=0xffff_ffff => {
let discriminant = r.read_u32_le().await?;
discriminant
.try_into()
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}
0x1_0000_0000.. => Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"case count does not fit in u32".to_string(),
)),
}
}

#[inline]
async fn read_flags(n: usize, r: &mut (impl AsyncRead + Unpin)) -> std::io::Result<u128> {
let mut buf = 0u128.to_le_bytes();
Expand Down Expand Up @@ -667,7 +584,7 @@ where
Ok(())
}
Type::Char => {
let v = r.read_char().await?;
let v = r.read_char_utf8().await?;
*val = Val::Char(v);
Ok(())
}
Expand Down Expand Up @@ -722,9 +639,11 @@ where
Ok(())
}
Type::Variant(ty) => {
let mut cases = ty.cases();
let discriminant = read_discriminant(cases.len(), r).await?;
let Case { name, ty } = cases.nth(discriminant).ok_or_else(|| {
let discriminant = r.read_u32_leb128().await?;
let discriminant = discriminant
.try_into()
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
let Case { name, ty } = ty.cases().nth(discriminant).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("unknown variant discriminant `{discriminant}`"),
Expand All @@ -741,9 +660,11 @@ where
Ok(())
}
Type::Enum(ty) => {
let mut names = ty.names();
let discriminant = read_discriminant(names.len(), r).await?;
let name = names.nth(discriminant).ok_or_else(|| {
let discriminant = r.read_u32_leb128().await?;
let discriminant = discriminant
.try_into()
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
let name = ty.names().nth(discriminant).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("unknown enum discriminant `{discriminant}`"),
Expand Down