Skip to content
Open
64 changes: 57 additions & 7 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use std::future::Future;
use std::mem;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc, time::Duration};
use std::any::Any;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::preview2::{
self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError,
Subscribe,
};
use wasmtime_wasi::preview2::{self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, Subscribe, StreamResult, InputStream};

pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;

Expand Down Expand Up @@ -85,6 +84,16 @@ pub struct HostIncomingBody {
worker: Option<Arc<AbortOnDropJoinHandle<()>>>,
}

impl std::fmt::Debug for HostIncomingBody {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.body {
IncomingBodyState::Start(_) => write!(f, "HostIncomingBody {{ Start }}"),
IncomingBodyState::InBodyStream(_) => write!(f, "HostIncomingBody {{ InBodyStream }}"),
IncomingBodyState::Failing(_) => write!(f, "HostIncomingBody {{ Failing }}"),
}
}
}

enum IncomingBodyState {
/// The body is stored here meaning that within `HostIncomingBody` the
/// `take_stream` method can be called for example.
Expand All @@ -94,6 +103,8 @@ enum IncomingBodyState {
/// currently owned here. The body will be sent back over this channel when
/// it's done, however.
InBodyStream(oneshot::Receiver<StreamEnd>),

Failing(String),
}

/// Message sent when a `HostIncomingBodyStream` is done to the
Expand All @@ -117,26 +128,35 @@ impl HostIncomingBody {
}
}

pub fn failing(error: String) -> HostIncomingBody {
HostIncomingBody {
body: IncomingBodyState::Failing(error),
worker: None,
}
}

pub fn retain_worker(&mut self, worker: &Arc<AbortOnDropJoinHandle<()>>) {
assert!(self.worker.is_none());
self.worker = Some(worker.clone());
}

pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
pub fn take_stream(&mut self) -> Option<InputStream> {
match &mut self.body {
IncomingBodyState::Start(_) => {}
IncomingBodyState::Failing(error) => return Some(InputStream::Host(Box::new(FailingStream { error: error.clone() }))),
IncomingBodyState::InBodyStream(_) => return None,
}
let (tx, rx) = oneshot::channel();
let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
IncomingBodyState::Start(b) => b,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
Some(HostIncomingBodyStream {
Some(InputStream::Host(Box::new(HostIncomingBodyStream {
state: IncomingBodyStreamState::Open { body, tx },
buffer: Bytes::new(),
error: None,
})
})))
}

pub fn into_future_trailers(self) -> HostFutureTrailers {
Expand Down Expand Up @@ -203,6 +223,8 @@ impl HostInputStream for HostIncomingBodyStream {
}
}
}

fn as_any(&self) -> &dyn Any { self }
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -342,9 +364,14 @@ impl Subscribe for HostFutureTrailers {
HostFutureTrailers::Done(_) => return,
HostFutureTrailers::Consumed => return,
};
if let IncomingBodyState::Failing(_) = &mut body.body {
*self = HostFutureTrailers::Done(Err(types::ErrorCode::ConnectionTerminated));
return;
}
let hyper_body = match &mut body.body {
IncomingBodyState::Start(body) => body,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
let result = loop {
match hyper_body.frame().await {
Expand Down Expand Up @@ -567,6 +594,8 @@ impl BodyWriteStream {

#[async_trait::async_trait]
impl HostOutputStream for BodyWriteStream {
fn as_any(&self) -> &dyn Any { self }

fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let len = bytes.len();
match self.writer.try_send(bytes) {
Expand Down Expand Up @@ -634,3 +663,24 @@ impl Subscribe for BodyWriteStream {
let _ = self.writer.reserve().await;
}
}


pub struct FailingStream {
error: String
}

#[async_trait]
impl Subscribe for FailingStream {
async fn ready(&mut self) {
}
}

impl HostInputStream for FailingStream {
fn read(&mut self, _size: usize) -> StreamResult<Bytes> {
Err(StreamError::trap(&self.error))
}

fn as_any(&self) -> &dyn Any {
self
}
}
4 changes: 3 additions & 1 deletion crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use async_trait::async_trait;
use crate::{
bindings::http::{
outgoing_handler,
Expand All @@ -12,8 +13,9 @@ use http_body_util::{BodyExt, Empty};
use hyper::Method;
use wasmtime::component::Resource;

#[async_trait]
impl<T: WasiHttpView> outgoing_handler::Host for T {
fn handle(
async fn handle(
&mut self,
request_id: Resource<HostOutgoingRequest>,
options: Option<Resource<types::RequestOptions>>,
Expand Down
8 changes: 7 additions & 1 deletion crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ pub mod bindings {
import wasi:http/[email protected];
",
tracing: true,
async: false,
async: {
only_imports: [
"handle",
"[method]future-incoming-response.get",
"[method]future-trailers.get",
],
},
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
Expand Down
9 changes: 7 additions & 2 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub fn default_send_request(
Ok(fut)
}

async fn handler(
pub(crate) async fn handler(
authority: String,
use_tls: bool,
connect_timeout: Duration,
Expand Down Expand Up @@ -405,13 +405,18 @@ pub enum HostFutureIncomingResponse {
Pending(FutureIncomingResponseHandle),
Ready(anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>>),
Consumed,
Deferred(OutgoingRequest)
}

impl HostFutureIncomingResponse {
pub fn new(handle: FutureIncomingResponseHandle) -> Self {
Self::Pending(handle)
}

pub fn deferred(request: OutgoingRequest) -> Self {
Self::Deferred(request)
}

pub fn is_ready(&self) -> bool {
matches!(self, Self::Ready(_))
}
Expand All @@ -421,7 +426,7 @@ impl HostFutureIncomingResponse {
) -> anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>> {
match self {
Self::Ready(res) => res,
Self::Pending(_) | Self::Consumed => {
Self::Pending(_) | Self::Consumed | Self::Deferred(_) => {
panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
}
}
Expand Down
50 changes: 43 additions & 7 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::types::OutgoingRequest;
use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody, StreamContext},
Expand All @@ -8,7 +9,8 @@ use crate::{
},
WasiHttpView,
};
use anyhow::Context;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::{Resource, ResourceTable};
Expand Down Expand Up @@ -60,7 +62,7 @@ fn move_fields(table: &mut ResourceTable, id: Resource<HostFields>) -> wasmtime:
}
}

fn get_fields<'a>(
pub fn get_fields<'a>(
table: &'a mut ResourceTable,
id: &Resource<HostFields>,
) -> wasmtime::Result<&'a FieldMap> {
Expand Down Expand Up @@ -631,6 +633,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
fn drop(&mut self, id: Resource<HostFutureTrailers>) -> wasmtime::Result<()> {
let _ = self
Expand All @@ -644,10 +647,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
&mut self,
index: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), index)
wasmtime_wasi::preview2::subscribe(self.table(), index, None)
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Option<Result<Result<Option<Resource<Trailers>>, types::ErrorCode>, ()>>>
Expand Down Expand Up @@ -686,7 +689,6 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingBody for T {
let body = self.table().get_mut(&id)?;

if let Some(stream) = body.take_stream() {
let stream = InputStream::Host(Box::new(stream));
let stream = self.table().push_child(stream, &id)?;
return Ok(Ok(stream));
}
Expand Down Expand Up @@ -798,13 +800,14 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse for T {
fn drop(&mut self, id: Resource<HostFutureIncomingResponse>) -> wasmtime::Result<()> {
let _ = self.table().delete(id)?;
Ok(())
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<
Expand All @@ -816,6 +819,39 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
HostFutureIncomingResponse::Pending(_) => return Ok(None),
HostFutureIncomingResponse::Consumed => return Ok(Some(Err(()))),
HostFutureIncomingResponse::Ready(_) => {}
HostFutureIncomingResponse::Deferred(_) => {
let (tx, rx) = tokio::sync::oneshot::channel();
let handle = wasmtime_wasi::preview2::spawn(async move {
let request = rx.await.map_err(|err| anyhow!(err))?;
let HostFutureIncomingResponse::Deferred(OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}) = request
else {
return Err(anyhow!("unexpected incoming response state".to_string()));
};
let resp = crate::types::handler(
authority,
use_tls,
connect_timeout,
first_byte_timeout,
request,
between_bytes_timeout,
)
.await;
Ok(resp)
});
tx.send(std::mem::replace(
resp,
HostFutureIncomingResponse::Pending(handle),
))
.map_err(|_| anyhow!("failed to send request to handler"))?;
return Ok(None);
}
}

let resp =
Expand Down Expand Up @@ -852,7 +888,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), id)
wasmtime_wasi::preview2::subscribe(self.table(), id, None)
}
}

Expand Down
Loading