Skip to content
Open
64 changes: 57 additions & 7 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use std::future::Future;
use std::mem;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc, time::Duration};
use std::any::Any;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::preview2::{
self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError,
Subscribe,
};
use wasmtime_wasi::preview2::{self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, Subscribe, StreamResult, InputStream};

pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;

Expand Down Expand Up @@ -85,6 +84,16 @@ pub struct HostIncomingBody {
worker: Option<Arc<AbortOnDropJoinHandle<()>>>,
}

impl std::fmt::Debug for HostIncomingBody {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.body {
IncomingBodyState::Start(_) => write!(f, "HostIncomingBody {{ Start }}"),
IncomingBodyState::InBodyStream(_) => write!(f, "HostIncomingBody {{ InBodyStream }}"),
IncomingBodyState::Failing(_) => write!(f, "HostIncomingBody {{ Failing }}"),
}
}
}

enum IncomingBodyState {
/// The body is stored here meaning that within `HostIncomingBody` the
/// `take_stream` method can be called for example.
Expand All @@ -94,6 +103,8 @@ enum IncomingBodyState {
/// currently owned here. The body will be sent back over this channel when
/// it's done, however.
InBodyStream(oneshot::Receiver<StreamEnd>),

Failing(String),
}

/// Message sent when a `HostIncomingBodyStream` is done to the
Expand All @@ -117,26 +128,35 @@ impl HostIncomingBody {
}
}

pub fn failing(error: String) -> HostIncomingBody {
HostIncomingBody {
body: IncomingBodyState::Failing(error),
worker: None,
}
}

pub fn retain_worker(&mut self, worker: &Arc<AbortOnDropJoinHandle<()>>) {
assert!(self.worker.is_none());
self.worker = Some(worker.clone());
}

pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
pub fn take_stream(&mut self) -> Option<InputStream> {
match &mut self.body {
IncomingBodyState::Start(_) => {}
IncomingBodyState::Failing(error) => return Some(InputStream::Host(Box::new(FailingStream { error: error.clone() }))),
IncomingBodyState::InBodyStream(_) => return None,
}
let (tx, rx) = oneshot::channel();
let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
IncomingBodyState::Start(b) => b,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
Some(HostIncomingBodyStream {
Some(InputStream::Host(Box::new(HostIncomingBodyStream {
state: IncomingBodyStreamState::Open { body, tx },
buffer: Bytes::new(),
error: None,
})
})))
}

pub fn into_future_trailers(self) -> HostFutureTrailers {
Expand Down Expand Up @@ -203,6 +223,8 @@ impl HostInputStream for HostIncomingBodyStream {
}
}
}

fn as_any(&self) -> &dyn Any { self }
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -342,9 +364,14 @@ impl Subscribe for HostFutureTrailers {
HostFutureTrailers::Done(_) => return,
HostFutureTrailers::Consumed => return,
};
if let IncomingBodyState::Failing(_) = &mut body.body {
*self = HostFutureTrailers::Done(Err(types::ErrorCode::ConnectionTerminated));
return;
}
let hyper_body = match &mut body.body {
IncomingBodyState::Start(body) => body,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
let result = loop {
match hyper_body.frame().await {
Expand Down Expand Up @@ -567,6 +594,8 @@ impl BodyWriteStream {

#[async_trait::async_trait]
impl HostOutputStream for BodyWriteStream {
fn as_any(&self) -> &dyn Any { self }

fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let len = bytes.len();
match self.writer.try_send(bytes) {
Expand Down Expand Up @@ -634,3 +663,24 @@ impl Subscribe for BodyWriteStream {
let _ = self.writer.reserve().await;
}
}


pub struct FailingStream {
error: String
}

#[async_trait]
impl Subscribe for FailingStream {
async fn ready(&mut self) {
}
}

impl HostInputStream for FailingStream {
fn read(&mut self, _size: usize) -> StreamResult<Bytes> {
Err(StreamError::trap(&self.error))
}

fn as_any(&self) -> &dyn Any {
self
}
}
7 changes: 6 additions & 1 deletion crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ pub mod bindings {
import wasi:http/[email protected];
",
tracing: true,
async: false,
async: {
only_imports: [
"[method]future-incoming-response.get",
"[method]future-trailers.get",
],
},
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
Expand Down
9 changes: 7 additions & 2 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub fn default_send_request(
Ok(fut)
}

async fn handler(
pub(crate) async fn handler(
authority: String,
use_tls: bool,
connect_timeout: Duration,
Expand Down Expand Up @@ -405,13 +405,18 @@ pub enum HostFutureIncomingResponse {
Pending(FutureIncomingResponseHandle),
Ready(anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>>),
Consumed,
Deferred(OutgoingRequest)
}

impl HostFutureIncomingResponse {
pub fn new(handle: FutureIncomingResponseHandle) -> Self {
Self::Pending(handle)
}

pub fn deferred(request: OutgoingRequest) -> Self {
Self::Deferred(request)
}

pub fn is_ready(&self) -> bool {
matches!(self, Self::Ready(_))
}
Expand All @@ -421,7 +426,7 @@ impl HostFutureIncomingResponse {
) -> anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>> {
match self {
Self::Ready(res) => res,
Self::Pending(_) | Self::Consumed => {
Self::Pending(_) | Self::Consumed | Self::Deferred(_) => {
panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
}
}
Expand Down
50 changes: 43 additions & 7 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::types::OutgoingRequest;
use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody, StreamContext},
Expand All @@ -8,7 +9,8 @@ use crate::{
},
WasiHttpView,
};
use anyhow::Context;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::{Resource, ResourceTable};
Expand Down Expand Up @@ -60,7 +62,7 @@ fn move_fields(table: &mut ResourceTable, id: Resource<HostFields>) -> wasmtime:
}
}

fn get_fields<'a>(
pub fn get_fields<'a>(
table: &'a mut ResourceTable,
id: &Resource<HostFields>,
) -> wasmtime::Result<&'a FieldMap> {
Expand Down Expand Up @@ -631,6 +633,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
fn drop(&mut self, id: Resource<HostFutureTrailers>) -> wasmtime::Result<()> {
let _ = self
Expand All @@ -644,10 +647,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
&mut self,
index: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), index)
wasmtime_wasi::preview2::subscribe(self.table(), index, None)
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Option<Result<Result<Option<Resource<Trailers>>, types::ErrorCode>, ()>>>
Expand Down Expand Up @@ -686,7 +689,6 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingBody for T {
let body = self.table().get_mut(&id)?;

if let Some(stream) = body.take_stream() {
let stream = InputStream::Host(Box::new(stream));
let stream = self.table().push_child(stream, &id)?;
return Ok(Ok(stream));
}
Expand Down Expand Up @@ -798,13 +800,14 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse for T {
fn drop(&mut self, id: Resource<HostFutureIncomingResponse>) -> wasmtime::Result<()> {
let _ = self.table().delete(id)?;
Ok(())
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<
Expand All @@ -816,6 +819,39 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
HostFutureIncomingResponse::Pending(_) => return Ok(None),
HostFutureIncomingResponse::Consumed => return Ok(Some(Err(()))),
HostFutureIncomingResponse::Ready(_) => {}
HostFutureIncomingResponse::Deferred(_) => {
let (tx, rx) = tokio::sync::oneshot::channel();
let handle = wasmtime_wasi::preview2::spawn(async move {
let request = rx.await.map_err(|err| anyhow!(err))?;
let HostFutureIncomingResponse::Deferred(OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}) = request
else {
return Err(anyhow!("unexpected incoming response state".to_string()));
};
let resp = crate::types::handler(
authority,
use_tls,
connect_timeout,
first_byte_timeout,
request,
between_bytes_timeout,
)
.await;
Ok(resp)
});
tx.send(std::mem::replace(
resp,
HostFutureIncomingResponse::Pending(handle),
))
.map_err(|_| anyhow!("failed to send request to handler"))?;
return Ok(None);
}
}

let resp =
Expand Down Expand Up @@ -852,7 +888,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), id)
wasmtime_wasi::preview2::subscribe(self.table(), id, None)
}
}

Expand Down
25 changes: 24 additions & 1 deletion crates/wasi/src/preview2/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::preview2::{
use cap_rand::{Rng, RngCore, SeedableRng};
use std::sync::Arc;
use std::{mem, net::SocketAddr};
use std::path::PathBuf;
use std::time::Duration;
use wasmtime::component::ResourceTable;

pub struct WasiCtxBuilder {
Expand All @@ -28,6 +30,8 @@ pub struct WasiCtxBuilder {
wall_clock: Box<dyn HostWallClock + Send + Sync>,
monotonic_clock: Box<dyn HostMonotonicClock + Send + Sync>,
allowed_network_uses: AllowedNetworkUses,
suspend_threshold: Duration,
suspend_signal: Box<dyn Fn(Duration) -> anyhow::Error + Send + Sync + 'static>,
built: bool,
}

Expand Down Expand Up @@ -78,6 +82,8 @@ impl WasiCtxBuilder {
wall_clock: wall_clock(),
monotonic_clock: monotonic_clock(),
allowed_network_uses: AllowedNetworkUses::default(),
suspend_threshold: Duration::MAX,
suspend_signal: Box::new(|_| unreachable!("suspend_signal not set")),
built: false,
}
}
Expand Down Expand Up @@ -143,9 +149,10 @@ impl WasiCtxBuilder {
perms: DirPerms,
file_perms: FilePerms,
path: impl AsRef<str>,
host_path: PathBuf,
) -> &mut Self {
self.preopens
.push((Dir::new(dir, perms, file_perms), path.as_ref().to_owned()));
.push((Dir::new(dir, perms, file_perms, host_path), path.as_ref().to_owned()));
self
}

Expand Down Expand Up @@ -222,6 +229,16 @@ impl WasiCtxBuilder {
self
}

pub fn set_suspend(
&mut self,
suspend_threshold: Duration,
suspend_signal: impl Fn(Duration) -> anyhow::Error + Send + Sync + 'static,
) -> &mut Self {
self.suspend_threshold = suspend_threshold;
self.suspend_signal = Box::new(suspend_signal);
self
}

/// Uses the configured context so far to construct the final `WasiCtx`.
///
/// Note that each `WasiCtxBuilder` can only be used to "build" once, and
Expand All @@ -247,6 +264,8 @@ impl WasiCtxBuilder {
wall_clock,
monotonic_clock,
allowed_network_uses,
suspend_threshold,
suspend_signal,
built: _,
} = mem::replace(self, Self::new());
self.built = true;
Expand All @@ -265,6 +284,8 @@ impl WasiCtxBuilder {
wall_clock,
monotonic_clock,
allowed_network_uses,
suspend_signal,
suspend_threshold,
}
}
}
Expand All @@ -290,6 +311,8 @@ pub struct WasiCtx {
pub(crate) stderr: Box<dyn StdoutStream>,
pub(crate) socket_addr_check: SocketAddrCheck,
pub(crate) allowed_network_uses: AllowedNetworkUses,
pub(crate) suspend_threshold: Duration,
pub(crate) suspend_signal: Box<dyn Fn(Duration) -> anyhow::Error + Send + Sync + 'static>,
}

pub struct AllowedNetworkUses {
Expand Down
Loading