Skip to content
Open
56 changes: 48 additions & 8 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ use std::future::Future;
use std::mem;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc, time::Duration};
use std::any::Any;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::{
runtime::{poll_noop, AbortOnDropJoinHandle},
HostInputStream, HostOutputStream, StreamError, Subscribe,
};
use wasmtime_wasi::{runtime::{poll_noop, AbortOnDropJoinHandle}, HostInputStream, HostOutputStream, StreamError, Subscribe, StreamResult, InputStream};

/// Common type for incoming bodies.
pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;
Expand All @@ -41,28 +40,38 @@ impl HostIncomingBody {
}
}

/// Create a new `HostIncomingBody` that's immediately failing with the given `error`.
pub fn failing(error: String) -> HostIncomingBody {
HostIncomingBody {
body: IncomingBodyState::Failing(error),
worker: None,
}
}

/// Retain a worker task that needs to be kept alive while this body is being read.
pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
assert!(self.worker.is_none());
self.worker = Some(worker);
}

/// Try taking the stream of this body, if it's available.
pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
pub fn take_stream(&mut self) -> Option<InputStream> {
match &mut self.body {
IncomingBodyState::Start(_) => {}
IncomingBodyState::Failing(error) => return Some(InputStream::Host(Box::new(FailingStream { error: error.clone() }))),
IncomingBodyState::InBodyStream(_) => return None,
}
let (tx, rx) = oneshot::channel();
let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
IncomingBodyState::Start(b) => b,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
Some(HostIncomingBodyStream {
Some(InputStream::Host(Box::new(HostIncomingBodyStream {
state: IncomingBodyStreamState::Open { body, tx },
buffer: Bytes::new(),
error: None,
})
})))
}

/// Convert this body into a `HostFutureTrailers` resource.
Expand All @@ -81,6 +90,8 @@ enum IncomingBodyState {
/// currently owned here. The body will be sent back over this channel when
/// it's done, however.
InBodyStream(oneshot::Receiver<StreamEnd>),

Failing(String),
}

/// Small wrapper around [`HyperIncomingBody`] which adds a timeout to every frame.
Expand Down Expand Up @@ -262,6 +273,8 @@ impl HostInputStream for HostIncomingBodyStream {
}
}
}

fn as_any(&self) -> &dyn Any { self }
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -357,9 +370,14 @@ impl Subscribe for HostFutureTrailers {
HostFutureTrailers::Done(_) => return,
HostFutureTrailers::Consumed => return,
};
if let IncomingBodyState::Failing(_) = &mut body.body {
*self = HostFutureTrailers::Done(Err(types::ErrorCode::ConnectionTerminated));
return;
}
let hyper_body = match &mut body.body {
IncomingBodyState::Start(body) => body,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
let result = loop {
match hyper_body.frame().await {
Expand Down Expand Up @@ -471,7 +489,7 @@ impl HostOutgoingBody {
body_receiver,
finish_receiver: Some(finish_receiver),
}
.boxed();
.boxed();

// TODO: this capacity constant is arbitrary, and should be configurable
let output_stream =
Expand Down Expand Up @@ -594,6 +612,8 @@ impl BodyWriteStream {

#[async_trait::async_trait]
impl HostOutputStream for BodyWriteStream {
fn as_any(&self) -> &dyn Any { self }

fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let len = bytes.len();
match self.writer.try_send(bytes) {
Expand Down Expand Up @@ -661,3 +681,23 @@ impl Subscribe for BodyWriteStream {
let _ = self.writer.reserve().await;
}
}

/// A stream that fails on every read.
pub struct FailingStream {
error: String,
}

#[async_trait]
impl Subscribe for FailingStream {
async fn ready(&mut self) {}
}

impl HostInputStream for FailingStream {
fn read(&mut self, _size: usize) -> StreamResult<Bytes> {
Err(StreamError::LastOperationFailed(anyhow!(self.error.clone())))
}

fn as_any(&self) -> &dyn Any {
self
}
}
4 changes: 3 additions & 1 deletion crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Implementation of the `wasi:http/outgoing-handler` interface.

use async_trait::async_trait;
use crate::{
bindings::http::{
outgoing_handler,
Expand All @@ -15,8 +16,9 @@ use http_body_util::{BodyExt, Empty};
use hyper::Method;
use wasmtime::component::Resource;

#[async_trait]
impl<T: WasiHttpView> outgoing_handler::Host for T {
fn handle(
async fn handle(
&mut self,
request_id: Resource<HostOutgoingRequest>,
options: Option<Resource<types::RequestOptions>>,
Expand Down
10 changes: 9 additions & 1 deletion crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ pub mod bindings {
import wasi:http/[email protected];
",
tracing: true,
async: false,
async: {
only_imports: [
"handle",
"[method]future-incoming-response.get",
"[method]future-trailers.get",
],
},
trappable_imports: true,
with: {
// Upstream package dependencies
Expand Down Expand Up @@ -89,3 +95,5 @@ pub use crate::error::{
};
#[doc(inline)]
pub use crate::types::{WasiHttpCtx, WasiHttpView};

pub use crate::types_impl::get_fields;
40 changes: 28 additions & 12 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl WasiHttpCtx {
}

/// A trait which provides internal WASI HTTP state.
pub trait WasiHttpView {
pub trait WasiHttpView: Send {
/// Returns a mutable reference to the WASI HTTP context.
fn ctx(&mut self) -> &mut WasiHttpCtx;

Expand All @@ -42,8 +42,8 @@ pub trait WasiHttpView {
&mut self,
req: hyper::Request<HyperIncomingBody>,
) -> wasmtime::Result<Resource<HostIncomingRequest>>
where
Self: Sized,
where
Self: Sized,
{
let (parts, body) = req.into_parts();
let body = HostIncomingBody::new(
Expand Down Expand Up @@ -257,9 +257,9 @@ pub async fn default_send_request_handler(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;

let worker = wasmtime_wasi::runtime::spawn(async move {
match conn.await {
Expand All @@ -279,9 +279,9 @@ pub async fn default_send_request_handler(
// TODO: we should plumb the builder through the http context, and use it here
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;

let worker = wasmtime_wasi::runtime::spawn(async move {
match conn.await {
Expand Down Expand Up @@ -389,7 +389,7 @@ impl HostIncomingRequest {
pub struct HostResponseOutparam {
/// The sender for sending a response.
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}

/// The concrete type behind a `wasi:http/types/outgoing-response` resource.
Expand Down Expand Up @@ -488,7 +488,7 @@ pub type FieldMap = hyper::HeaderMap;

/// A handle to a future incoming response.
pub type FutureIncomingResponseHandle =
AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;

/// A response that is in the process of being received.
pub struct IncomingResponse {
Expand All @@ -510,6 +510,13 @@ pub enum HostFutureIncomingResponse {
Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
/// The response has been consumed.
Consumed,
/// The request is deferred, to be executed the first time the future is polled
Deferred {
/// Outgoing request
request: hyper::Request<HyperOutgoingBody>,
/// Outgoing request configuration
config: OutgoingRequestConfig,
},
}

impl HostFutureIncomingResponse {
Expand All @@ -523,6 +530,15 @@ impl HostFutureIncomingResponse {
Self::Ready(result)
}

/// Returns `true` if the response is ready.
pub fn deferred(request: hyper::Request<HyperOutgoingBody>,
config: OutgoingRequestConfig) -> Self {
Self::Deferred {
request,
config,
}
}

/// Returns `true` if the response is ready.
pub fn is_ready(&self) -> bool {
matches!(self, Self::Ready(_))
Expand All @@ -532,7 +548,7 @@ impl HostFutureIncomingResponse {
pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
match self {
Self::Ready(res) => res,
Self::Pending(_) | Self::Consumed => {
Self::Pending(_) | Self::Consumed | Self::Deferred { .. } => {
panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
}
}
Expand Down
38 changes: 31 additions & 7 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use crate::{
},
WasiHttpView,
};
use anyhow::Context;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::{Resource, ResourceTable};
Expand Down Expand Up @@ -69,7 +70,8 @@ fn move_fields(
}
}

fn get_fields<'a>(
#[allow(missing_docs)]
pub fn get_fields<'a>(
table: &'a mut ResourceTable,
id: &Resource<HostFields>,
) -> wasmtime::Result<&'a FieldMap> {
Expand Down Expand Up @@ -640,6 +642,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
fn drop(&mut self, id: Resource<HostFutureTrailers>) -> wasmtime::Result<()> {
let _ = self
Expand All @@ -653,10 +656,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
&mut self,
index: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::subscribe(self.table(), index)
wasmtime_wasi::subscribe(self.table(), index, None)
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Option<Result<Result<Option<Resource<Trailers>>, types::ErrorCode>, ()>>>
Expand Down Expand Up @@ -695,7 +698,6 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingBody for T {
let body = self.table().get_mut(&id)?;

if let Some(stream) = body.take_stream() {
let stream = InputStream::Host(Box::new(stream));
let stream = self.table().push_child(stream, &id)?;
return Ok(Ok(stream));
}
Expand Down Expand Up @@ -807,13 +809,14 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse for T {
fn drop(&mut self, id: Resource<HostFutureIncomingResponse>) -> wasmtime::Result<()> {
let _ = self.table().delete(id)?;
Ok(())
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<
Expand All @@ -825,6 +828,27 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
HostFutureIncomingResponse::Pending(_) => return Ok(None),
HostFutureIncomingResponse::Consumed => return Ok(Some(Err(()))),
HostFutureIncomingResponse::Ready(_) => {}
HostFutureIncomingResponse::Deferred { .. } => {
let (tx, rx) = tokio::sync::oneshot::channel();
let handle = wasmtime_wasi::runtime::spawn(async move {
let request = rx.await.map_err(|err| anyhow!(err))?;
let HostFutureIncomingResponse::Deferred { request, config } = request
else {
return Err(anyhow!("unexpected incoming response state".to_string()));
};
let resp = crate::types::default_send_request_handler(
request, config,
)
.await;
Ok(resp)
});
tx.send(std::mem::replace(
resp,
HostFutureIncomingResponse::Pending(handle),
))
.map_err(|_| anyhow!("failed to send request to handler"))?;
return Ok(None);
}
}

let resp =
Expand Down Expand Up @@ -862,7 +886,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::subscribe(self.table(), id)
wasmtime_wasi::subscribe(self.table(), id, None)
}
}

Expand Down
Loading