Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ use std::future::Future;
use std::mem;
use std::task::{Context, Poll};
use std::{pin::Pin, sync::Arc, time::Duration};
use std::any::Any;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::preview2::{
self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError,
Subscribe,
};
use wasmtime_wasi::preview2::{self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, Subscribe, StreamResult, InputStream};

pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;

Expand Down Expand Up @@ -86,6 +85,16 @@ pub struct HostIncomingBody {
worker: Option<Arc<AbortOnDropJoinHandle<()>>>,
}

impl std::fmt::Debug for HostIncomingBody {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.body {
IncomingBodyState::Start(_) => write!(f, "HostIncomingBody {{ Start }}"),
IncomingBodyState::InBodyStream(_) => write!(f, "HostIncomingBody {{ InBodyStream }}"),
IncomingBodyState::Failing(_) => write!(f, "HostIncomingBody {{ Failing }}"),
}
}
}

enum IncomingBodyState {
/// The body is stored here meaning that within `HostIncomingBody` the
/// `take_stream` method can be called for example.
Expand All @@ -95,6 +104,8 @@ enum IncomingBodyState {
/// currently owned here. The body will be sent back over this channel when
/// it's done, however.
InBodyStream(oneshot::Receiver<StreamEnd>),

Failing(String),
}

/// Message sent when a `HostIncomingBodyStream` is done to the
Expand All @@ -118,26 +129,35 @@ impl HostIncomingBody {
}
}

pub fn failing(error: String) -> HostIncomingBody {
HostIncomingBody {
body: IncomingBodyState::Failing(error),
worker: None,
}
}

pub fn retain_worker(&mut self, worker: &Arc<AbortOnDropJoinHandle<()>>) {
assert!(self.worker.is_none());
self.worker = Some(worker.clone());
}

pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
pub fn take_stream(&mut self) -> Option<InputStream> {
match &mut self.body {
IncomingBodyState::Start(_) => {}
IncomingBodyState::Failing(error) => return Some(InputStream::Host(Box::new(FailingStream { error: error.clone() }))),
IncomingBodyState::InBodyStream(_) => return None,
}
let (tx, rx) = oneshot::channel();
let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
IncomingBodyState::Start(b) => b,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
Some(HostIncomingBodyStream {
Some(InputStream::Host(Box::new(HostIncomingBodyStream {
state: IncomingBodyStreamState::Open { body, tx },
buffer: Bytes::new(),
error: None,
})
})))
}

pub fn into_future_trailers(self) -> HostFutureTrailers {
Expand Down Expand Up @@ -204,6 +224,8 @@ impl HostInputStream for HostIncomingBodyStream {
}
}
}

fn as_any(&self) -> &dyn Any { self }
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -338,9 +360,14 @@ impl Subscribe for HostFutureTrailers {
HostFutureTrailers::Waiting(body) => body,
HostFutureTrailers::Done(_) => return,
};
if let IncomingBodyState::Failing(_) = &mut body.body {
*self = HostFutureTrailers::Done(Err(types::ErrorCode::ConnectionTerminated));
return;
}
let hyper_body = match &mut body.body {
IncomingBodyState::Start(body) => body,
IncomingBodyState::InBodyStream(_) => unreachable!(),
IncomingBodyState::Failing(_) => unreachable!(),
};
let result = loop {
match hyper_body.frame().await {
Expand Down Expand Up @@ -549,6 +576,8 @@ impl BodyWriteStream {

#[async_trait::async_trait]
impl HostOutputStream for BodyWriteStream {
fn as_any(&self) -> &dyn Any { self }

fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let len = bytes.len();
match self.writer.try_send(bytes) {
Expand Down Expand Up @@ -615,3 +644,24 @@ impl Subscribe for BodyWriteStream {
let _ = self.writer.reserve().await;
}
}


pub struct FailingStream {
error: String
}

#[async_trait]
impl Subscribe for FailingStream {
async fn ready(&mut self) {
}
}

impl HostInputStream for FailingStream {
fn read(&mut self, _size: usize) -> StreamResult<Bytes> {
Err(StreamError::trap(&self.error))
}

fn as_any(&self) -> &dyn Any {
self
}
}
7 changes: 6 additions & 1 deletion crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ pub mod bindings {
import wasi:http/[email protected];
",
tracing: true,
async: false,
async: {
only_imports: [
"[method]future-incoming-response.get",
"[method]future-trailers.get",
],
},
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
Expand Down
9 changes: 7 additions & 2 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ pub fn default_send_request(
Ok(fut)
}

async fn handler(
pub(crate) async fn handler(
authority: String,
use_tls: bool,
connect_timeout: Duration,
Expand Down Expand Up @@ -401,13 +401,18 @@ pub enum HostFutureIncomingResponse {
Pending(FutureIncomingResponseHandle),
Ready(anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>>),
Consumed,
Deferred(OutgoingRequest)
}

impl HostFutureIncomingResponse {
pub fn new(handle: FutureIncomingResponseHandle) -> Self {
Self::Pending(handle)
}

pub fn deferred(request: OutgoingRequest) -> Self {
Self::Deferred(request)
}

pub fn is_ready(&self) -> bool {
matches!(self, Self::Ready(_))
}
Expand All @@ -417,7 +422,7 @@ impl HostFutureIncomingResponse {
) -> anyhow::Result<Result<IncomingResponseInternal, types::ErrorCode>> {
match self {
Self::Ready(res) => res,
Self::Pending(_) | Self::Consumed => {
Self::Pending(_) | Self::Consumed | Self::Deferred(_) => {
panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
}
}
Expand Down
51 changes: 43 additions & 8 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::types::OutgoingRequest;
use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
Expand All @@ -8,7 +9,8 @@ use crate::{
},
WasiHttpView,
};
use anyhow::Context;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::Resource;
Expand Down Expand Up @@ -60,7 +62,7 @@ fn move_fields(table: &mut Table, id: Resource<HostFields>) -> wasmtime::Result<
}
}

fn get_fields<'a>(
pub fn get_fields<'a>(
table: &'a mut Table,
id: &Resource<HostFields>,
) -> wasmtime::Result<&'a FieldMap> {
Expand Down Expand Up @@ -490,7 +492,6 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingRequest for T {
let req = self.table().get_mut(&request)?;

if let Some(s) = authority.as_ref() {
println!("checking authority {s}");
let auth = match http::uri::Authority::from_str(s.as_str()) {
Ok(auth) => auth,
Err(_) => return Ok(Err(())),
Expand Down Expand Up @@ -618,6 +619,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
fn drop(&mut self, id: Resource<HostFutureTrailers>) -> wasmtime::Result<()> {
let _ = self
Expand All @@ -631,10 +633,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureTrailers for T {
&mut self,
index: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), index)
wasmtime_wasi::preview2::subscribe(self.table(), index, None)
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureTrailers>,
) -> wasmtime::Result<Option<Result<Option<Resource<Trailers>>, types::ErrorCode>>> {
Expand Down Expand Up @@ -674,7 +676,6 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostIncomingBody for T {
let body = self.table().get_mut(&id)?;

if let Some(stream) = body.take_stream() {
let stream = InputStream::Host(Box::new(stream));
let stream = self.table().push_child(stream, &id)?;
return Ok(Ok(stream));
}
Expand Down Expand Up @@ -786,13 +787,14 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostOutgoingResponse for T {
}
}

#[async_trait]
impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse for T {
fn drop(&mut self, id: Resource<HostFutureIncomingResponse>) -> wasmtime::Result<()> {
let _ = self.table().delete(id)?;
Ok(())
}

fn get(
async fn get(
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<
Expand All @@ -804,6 +806,39 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
HostFutureIncomingResponse::Pending(_) => return Ok(None),
HostFutureIncomingResponse::Consumed => return Ok(Some(Err(()))),
HostFutureIncomingResponse::Ready(_) => {}
HostFutureIncomingResponse::Deferred(_) => {
let (tx, rx) = tokio::sync::oneshot::channel();
let handle = wasmtime_wasi::preview2::spawn(async move {
let request = rx.await.map_err(|err| anyhow!(err))?;
let HostFutureIncomingResponse::Deferred(OutgoingRequest {
use_tls,
authority,
request,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}) = request
else {
return Err(anyhow!("unexpected incoming response state".to_string()));
};
let resp = crate::types::handler(
authority,
use_tls,
connect_timeout,
first_byte_timeout,
request,
between_bytes_timeout,
)
.await;
Ok(resp)
});
tx.send(std::mem::replace(
resp,
HostFutureIncomingResponse::Pending(handle),
))
.map_err(|_| anyhow!("failed to send request to handler"))?;
return Ok(None);
}
}

let resp =
Expand Down Expand Up @@ -840,7 +875,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
&mut self,
id: Resource<HostFutureIncomingResponse>,
) -> wasmtime::Result<Resource<Pollable>> {
wasmtime_wasi::preview2::subscribe(self.table(), id)
wasmtime_wasi::preview2::subscribe(self.table(), id, None)
}
}

Expand Down
25 changes: 24 additions & 1 deletion crates/wasi/src/preview2/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use cap_std::net::Pool;
use cap_std::{ambient_authority, AmbientAuthority};
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::path::PathBuf;
use std::time::Duration;

pub struct WasiCtxBuilder {
stdin: Box<dyn StdinStream>,
Expand All @@ -28,6 +30,8 @@ pub struct WasiCtxBuilder {
wall_clock: Box<dyn HostWallClock + Send + Sync>,
monotonic_clock: Box<dyn HostMonotonicClock + Send + Sync>,
allow_ip_name_lookup: bool,
suspend_threshold: Duration,
suspend_signal: Box<dyn Fn(Duration) -> anyhow::Error + Send + Sync + 'static>,
built: bool,
}

Expand Down Expand Up @@ -78,6 +82,8 @@ impl WasiCtxBuilder {
wall_clock: wall_clock(),
monotonic_clock: monotonic_clock(),
allow_ip_name_lookup: false,
suspend_threshold: Duration::MAX,
suspend_signal: Box::new(|_| unreachable!("suspend_signal not set")),
built: false,
}
}
Expand Down Expand Up @@ -143,9 +149,10 @@ impl WasiCtxBuilder {
perms: DirPerms,
file_perms: FilePerms,
path: impl AsRef<str>,
host_path: PathBuf,
) -> &mut Self {
self.preopens
.push((Dir::new(dir, perms, file_perms), path.as_ref().to_owned()));
.push((Dir::new(dir, perms, file_perms, host_path), path.as_ref().to_owned()));
self
}

Expand Down Expand Up @@ -254,6 +261,16 @@ impl WasiCtxBuilder {
self
}

pub fn set_suspend(
&mut self,
suspend_threshold: Duration,
suspend_signal: impl Fn(Duration) -> anyhow::Error + Send + Sync + 'static,
) -> &mut Self {
self.suspend_threshold = suspend_threshold;
self.suspend_signal = Box::new(suspend_signal);
self
}

/// Uses the configured context so far to construct the final `WasiCtx`.
///
/// Note that each `WasiCtxBuilder` can only be used to "build" once, and
Expand All @@ -279,6 +296,8 @@ impl WasiCtxBuilder {
wall_clock,
monotonic_clock,
allow_ip_name_lookup,
suspend_threshold,
suspend_signal,
built: _,
} = mem::replace(self, Self::new());
self.built = true;
Expand All @@ -297,6 +316,8 @@ impl WasiCtxBuilder {
wall_clock,
monotonic_clock,
allow_ip_name_lookup,
suspend_signal,
suspend_threshold,
}
}
}
Expand All @@ -322,4 +343,6 @@ pub struct WasiCtx {
pub(crate) stderr: Box<dyn StdoutStream>,
pub(crate) pool: Pool,
pub(crate) allow_ip_name_lookup: bool,
pub(crate) suspend_threshold: Duration,
pub(crate) suspend_signal: Box<dyn Fn(Duration) -> anyhow::Error + Send + Sync + 'static>,
}
Loading