Skip to content

Commit da63361

Browse files
authored
support continuing host-to-host streams after store is dropped (#11763)
* add `StreamReader::try_into` and use it in `Response::into_http_with_getter` This allows the embedder to consume a host-created `StreamReader` and extract the `StreamProducer` (or an `Unpin` subset of it) used to create it. This enables "short-circuiting" host-to-host streams, bypassing the guest entirely. Concretely, it is now possible to dispose of a `wasi:http/[email protected]` guest and its store before the request and/or response bodies have finished streaming in the case where both the read and write ends of the body stream are owned by the host. In this case, the inbound body is reused as the outbound body with no additional processing or memory overhead. Signed-off-by: Joel Dice <[email protected]> * spawn tokio task instead of store task for outbound requests Previously, `wasmtime-wasi-http`'s outbound request handler used `Accessor::spawn` to spawn the connection future for a given outbound request, but that meant the request stopped making progress when the store was dropped. That's a problem for a proxy scenario where the guest task just wants to make an outgoing request and return the response (or at least the body of that response) and then exit. At that point, we should be able to drop the instance and its store and expect the body stream will continue to flow. Using `tokio::task::spawn` ensures that will happen. Fixes #11703 Signed-off-by: Joel Dice <[email protected]> * test host-to-host streaming in wasi-http p3 tests In the process of adding these tests, I found that we were leaking subtasks which had exited but not been dropped by their supertask by the time that supertask was dropped, so I've fixed that. Signed-off-by: Joel Dice <[email protected]> * address review feedback Signed-off-by: Joel Dice <[email protected]> --------- Signed-off-by: Joel Dice <[email protected]>
1 parent 89fdfa1 commit da63361

File tree

10 files changed

+263
-124
lines changed

10 files changed

+263
-124
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/test-programs/src/bin/p3_http_echo.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ impl Handler for Component {
1818
let (_, result_rx) = wit_future::new(|| Ok(()));
1919
let (body, trailers) = Request::consume_body(request, result_rx);
2020

21-
let (response, _result) = if false {
21+
let (response, _result) = if headers
22+
.get("x-host-to-host")
23+
.into_iter()
24+
.any(|v| v == b"true")
25+
{
2226
// This is the easy and efficient way to do it...
2327
Response::new(headers, Some(body), trailers)
2428
} else {

crates/test-programs/src/bin/p3_http_middleware.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ impl Handler for Component {
6060
wit_bindgen::spawn(async move {
6161
{
6262
let mut decoder = DeflateDecoder::new(Vec::new());
63+
let mut status = StreamResult::Complete(0);
64+
let mut chunk = Vec::with_capacity(64 * 1024);
6365

64-
let (mut status, mut chunk) = body.read(Vec::with_capacity(64 * 1024)).await;
6566
while let StreamResult::Complete(_) = status {
67+
(status, chunk) = body.read(chunk).await;
6668
decoder.write_all(&chunk).unwrap();
6769
let remaining = pipe_tx.write_all(mem::take(decoder.get_mut())).await;
6870
assert!(remaining.is_empty());
6971
*decoder.get_mut() = remaining;
7072
chunk.clear();
71-
(status, chunk) = body.read(chunk).await;
7273
}
7374

7475
let remaining = pipe_tx.write_all(decoder.finish().unwrap()).await;
@@ -123,15 +124,16 @@ impl Handler for Component {
123124
wit_bindgen::spawn(async move {
124125
{
125126
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
126-
let (mut status, mut chunk) = body.read(Vec::with_capacity(64 * 1024)).await;
127+
let mut status = StreamResult::Complete(0);
128+
let mut chunk = Vec::with_capacity(64 * 1024);
127129

128130
while let StreamResult::Complete(_) = status {
131+
(status, chunk) = body.read(chunk).await;
129132
encoder.write_all(&chunk).unwrap();
130133
let remaining = pipe_tx.write_all(mem::take(encoder.get_mut())).await;
131134
assert!(remaining.is_empty());
132135
*encoder.get_mut() = remaining;
133136
chunk.clear();
134-
(status, chunk) = body.read(chunk).await;
135137
}
136138

137139
let remaining = pipe_tx.write_all(encoder.finish().unwrap()).await;

crates/wasi-http/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ base64 = { workspace = true }
5454
flate2 = { workspace = true }
5555
wasm-compose = { workspace = true }
5656
tempfile = { workspace = true }
57+
env_logger = { workspace = true }

crates/wasi-http/src/p3/body.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use core::task::{Context, Poll, ready};
99
use http::HeaderMap;
1010
use http_body::Body as _;
1111
use http_body_util::combinators::BoxBody;
12+
use std::any::{Any, TypeId};
1213
use std::io::Cursor;
1314
use std::sync::Arc;
1415
use tokio::sync::{mpsc, oneshot};
@@ -445,8 +446,8 @@ where
445446
}
446447

447448
/// [StreamProducer] implementation for bodies originating in the host.
448-
struct HostBodyStreamProducer<T> {
449-
body: BoxBody<Bytes, ErrorCode>,
449+
pub(crate) struct HostBodyStreamProducer<T> {
450+
pub(crate) body: BoxBody<Bytes, ErrorCode>,
450451
trailers: Option<oneshot::Sender<Result<Option<Resource<Trailers>>, ErrorCode>>>,
451452
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
452453
}
@@ -536,4 +537,13 @@ where
536537
self.close(res);
537538
Poll::Ready(Ok(StreamResult::Dropped))
538539
}
540+
541+
fn try_into(me: Pin<Box<Self>>, ty: TypeId) -> Result<Box<dyn Any>, Pin<Box<Self>>> {
542+
if ty == TypeId::of::<Self>() {
543+
let me = Pin::into_inner(me);
544+
Ok(me)
545+
} else {
546+
Err(me)
547+
}
548+
}
539549
}

crates/wasi-http/src/p3/host/handler.rs

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ use http::{HeaderValue, Uri};
1212
use http_body_util::BodyExt as _;
1313
use std::sync::Arc;
1414
use tokio::sync::oneshot;
15+
use tokio::task::{self, JoinHandle};
1516
use tracing::debug;
16-
use wasmtime::component::{Accessor, AccessorTask, JoinHandle, Resource};
17+
use wasmtime::component::{Accessor, Resource};
1718

1819
/// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
1920
/// when dropped
20-
struct AbortOnDropJoinHandle(JoinHandle);
21+
struct AbortOnDropJoinHandle(JoinHandle<()>);
2122

2223
impl Drop for AbortOnDropJoinHandle {
2324
fn drop(&mut self) {
@@ -171,20 +172,6 @@ trait BodyExt {
171172

172173
impl<T> BodyExt for T {}
173174

174-
struct SendRequestTask {
175-
io: Pin<Box<dyn Future<Output = Result<(), ErrorCode>> + Send>>,
176-
result_tx: oneshot::Sender<Result<(), ErrorCode>>,
177-
}
178-
179-
impl<T> AccessorTask<T, WasiHttp, wasmtime::Result<()>> for SendRequestTask {
180-
async fn run(self, _: &Accessor<T, WasiHttp>) -> wasmtime::Result<()> {
181-
let res = self.io.await;
182-
debug!(?res, "`send_request` I/O future finished");
183-
_ = self.result_tx.send(res);
184-
Ok(())
185-
}
186-
}
187-
188175
async fn io_task_result(
189176
rx: oneshot::Receiver<(
190177
Arc<AbortOnDropJoinHandle>,
@@ -336,7 +323,11 @@ impl HostWithStore for WasiHttp {
336323
Poll::Pending => {
337324
// I/O driver still needs to be polled, spawn a task and send handles to it
338325
let (tx, rx) = oneshot::channel();
339-
let io = store.spawn(SendRequestTask { io, result_tx: tx });
326+
let io = task::spawn(async move {
327+
let res = io.await;
328+
debug!(?res, "`send_request` I/O future finished");
329+
_ = tx.send(res);
330+
});
340331
let io = Arc::new(AbortOnDropJoinHandle(io));
341332
_ = io_result_tx.send((Arc::clone(&io), rx));
342333
_ = io_task_tx.send(Arc::clone(&io));

crates/wasi-http/src/p3/host/types.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ use crate::p3::bindings::http::types::{
44
HostRequestOptions, HostRequestWithStore, HostResponse, HostResponseWithStore, Method, Request,
55
RequestOptions, RequestOptionsError, Response, Scheme, StatusCode, Trailers,
66
};
7-
use crate::p3::body::Body;
7+
use crate::p3::body::{Body, HostBodyStreamProducer};
88
use crate::p3::{HeaderResult, HttpError, RequestOptionsResult, WasiHttp, WasiHttpCtxView};
99
use anyhow::Context as _;
10+
use core::mem;
1011
use core::pin::Pin;
1112
use core::task::{Context, Poll, ready};
1213
use http::header::CONTENT_LENGTH;
1314
use std::sync::Arc;
1415
use tokio::sync::oneshot;
15-
use wasmtime::StoreContextMut;
1616
use wasmtime::component::{
1717
Access, Accessor, FutureProducer, FutureReader, Resource, ResourceTable, StreamReader,
1818
};
19+
use wasmtime::{AsContextMut, StoreContextMut};
1920

2021
fn get_fields<'a>(
2122
table: &'a ResourceTable,
@@ -602,13 +603,26 @@ impl HostResponseWithStore for WasiHttp {
602603
let instance = store.instance();
603604
store.with(|mut store| {
604605
let (result_tx, result_rx) = oneshot::channel();
606+
let body = match contents
607+
.map(|rx| rx.try_into::<HostBodyStreamProducer<T>>(store.as_context_mut()))
608+
{
609+
Some(Ok(mut producer)) => Body::Host {
610+
body: mem::take(&mut producer.body),
611+
result_tx,
612+
},
613+
Some(Err(rx)) => Body::Guest {
614+
contents_rx: Some(rx),
615+
trailers_rx: trailers,
616+
result_tx,
617+
},
618+
None => Body::Guest {
619+
contents_rx: None,
620+
trailers_rx: trailers,
621+
result_tx,
622+
},
623+
};
605624
let WasiHttpCtxView { table, .. } = store.get();
606625
let headers = delete_fields(table, headers)?;
607-
let body = Body::Guest {
608-
contents_rx: contents,
609-
trailers_rx: trailers,
610-
result_tx,
611-
};
612626
let res = Response {
613627
status: http::StatusCode::OK,
614628
headers: headers.into(),

crates/wasi-http/tests/all/p3/mod.rs

Lines changed: 88 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -123,25 +123,36 @@ async fn run_http<E: Into<ErrorCode> + 'static>(
123123
let instance = linker.instantiate_async(&mut store, &component).await?;
124124
let proxy = Proxy::new(&mut store, &instance)?;
125125
let (req, io) = Request::from_http(req);
126-
let (res, ()) = instance
127-
.run_concurrent(&mut store, async |store| {
128-
try_join!(
129-
async {
130-
let (res, task) = match proxy.handle(store, req).await? {
131-
Ok(pair) => pair,
132-
Err(err) => return Ok(Err(Some(err))),
133-
};
134-
let res = store.with(|store| res.into_http(store, async { Ok(()) }))?;
135-
let (parts, body) = res.into_parts();
136-
let body = body.collect().await.context("failed to collect body")?;
137-
task.block(store).await;
138-
Ok(Ok(http::Response::from_parts(parts, body)))
139-
},
140-
async { io.await.context("failed to consume request body") }
141-
)
142-
})
143-
.await??;
144-
Ok(res)
126+
let (tx, rx) = tokio::sync::oneshot::channel();
127+
let ((handle_result, ()), res) = try_join!(
128+
async move {
129+
instance
130+
.run_concurrent(&mut store, async |store| {
131+
try_join!(
132+
async {
133+
let (res, task) = match proxy.handle(store, req).await? {
134+
Ok(pair) => pair,
135+
Err(err) => return Ok(Err(Some(err))),
136+
};
137+
_ = tx
138+
.send(store.with(|store| res.into_http(store, async { Ok(()) }))?);
139+
task.block(store).await;
140+
Ok(Ok(()))
141+
},
142+
async { io.await.context("failed to consume request body") }
143+
)
144+
})
145+
.await?
146+
},
147+
async move {
148+
let res = rx.await?;
149+
let (parts, body) = res.into_parts();
150+
let body = body.collect().await.context("failed to collect body")?;
151+
anyhow::Ok(http::Response::from_parts(parts, body))
152+
}
153+
)?;
154+
155+
Ok(handle_result.map(|()| res))
145156
}
146157

147158
#[test_log::test(tokio::test(flavor = "multi_thread"))]
@@ -254,18 +265,32 @@ async fn wasi_http_proxy_tests() -> anyhow::Result<()> {
254265

255266
#[test_log::test(tokio::test(flavor = "multi_thread"))]
256267
async fn p3_http_echo() -> Result<()> {
257-
test_http_echo(P3_HTTP_ECHO_COMPONENT, false).await
268+
test_http_echo(P3_HTTP_ECHO_COMPONENT, false, false).await
269+
}
270+
271+
#[test_log::test(tokio::test(flavor = "multi_thread"))]
272+
async fn p3_http_echo_host_to_host() -> Result<()> {
273+
test_http_echo(P3_HTTP_ECHO_COMPONENT, false, true).await
258274
}
259275

260276
#[test_log::test(tokio::test(flavor = "multi_thread"))]
261277
async fn p3_http_middleware() -> Result<()> {
278+
test_http_middleware(false).await
279+
}
280+
281+
#[test_log::test(tokio::test(flavor = "multi_thread"))]
282+
async fn p3_http_middleware_host_to_host() -> Result<()> {
283+
test_http_middleware(true).await
284+
}
285+
286+
async fn test_http_middleware(host_to_host: bool) -> Result<()> {
262287
let tempdir = tempfile::tempdir()?;
263288
let echo = &fs::read(P3_HTTP_ECHO_COMPONENT).await?;
264289
let middleware = &fs::read(P3_HTTP_MIDDLEWARE_COMPONENT).await?;
265290

266291
let path = tempdir.path().join("temp.wasm");
267292
fs::write(&path, compose(middleware, echo).await?).await?;
268-
test_http_echo(&path.to_str().unwrap(), true).await
293+
test_http_echo(&path.to_str().unwrap(), true, host_to_host).await
269294
}
270295

271296
async fn compose(a: &[u8], b: &[u8]) -> Result<Vec<u8>> {
@@ -290,6 +315,15 @@ async fn compose(a: &[u8], b: &[u8]) -> Result<Vec<u8>> {
290315

291316
#[test_log::test(tokio::test(flavor = "multi_thread"))]
292317
async fn p3_http_middleware_with_chain() -> Result<()> {
318+
test_http_middleware_with_chain(false).await
319+
}
320+
321+
#[test_log::test(tokio::test(flavor = "multi_thread"))]
322+
async fn p3_http_middleware_with_chain_host_to_host() -> Result<()> {
323+
test_http_middleware_with_chain(true).await
324+
}
325+
326+
async fn test_http_middleware_with_chain(host_to_host: bool) -> Result<()> {
293327
let dir = tempfile::tempdir()?;
294328
let path = dir.path().join("temp.wasm");
295329

@@ -334,10 +368,12 @@ async fn p3_http_middleware_with_chain() -> Result<()> {
334368
.compose()?;
335369
fs::write(&path, &bytes).await?;
336370

337-
test_http_echo(&path.to_str().unwrap(), true).await
371+
test_http_echo(&path.to_str().unwrap(), true, host_to_host).await
338372
}
339373

340-
async fn test_http_echo(component: &str, use_compression: bool) -> Result<()> {
374+
async fn test_http_echo(component: &str, use_compression: bool, host_to_host: bool) -> Result<()> {
375+
_ = env_logger::try_init();
376+
341377
let body = b"And the mome raths outgrabe";
342378

343379
// Prepare the raw body, optionally compressed if that's what we're
@@ -353,23 +389,7 @@ async fn test_http_echo(component: &str, use_compression: bool) -> Result<()> {
353389
// Prepare the http_body body, modeled here as a channel with the body
354390
// chunk above buffered up followed by some trailers. Note that trailers
355391
// are always here to test that code paths throughout the components.
356-
let (mut body_tx, body_rx) = futures::channel::mpsc::channel::<Result<_, ErrorCode>>(2);
357-
body_tx
358-
.send(Ok(http_body::Frame::data(raw_body)))
359-
.await
360-
.unwrap();
361-
body_tx
362-
.send(Ok(http_body::Frame::trailers({
363-
let mut trailers = http::HeaderMap::new();
364-
assert!(
365-
trailers
366-
.insert("fizz", http::HeaderValue::from_static("buzz"))
367-
.is_none()
368-
);
369-
trailers
370-
})))
371-
.await
372-
.unwrap();
392+
let (mut body_tx, body_rx) = futures::channel::mpsc::channel::<Result<_, ErrorCode>>(1);
373393

374394
// Build the `http::Request`, optionally specifying compression-related
375395
// headers.
@@ -382,16 +402,40 @@ async fn test_http_echo(component: &str, use_compression: bool) -> Result<()> {
382402
.header("content-encoding", "deflate")
383403
.header("accept-encoding", "nonexistent-encoding, deflate");
384404
}
405+
if host_to_host {
406+
request = request.header("x-host-to-host", "true");
407+
}
385408

386409
// Send this request to wasm and assert that success comes back.
387410
//
388411
// Note that this will read the entire body internally and wait for
389412
// everything to get collected before proceeding to below.
390-
let response = run_http(
391-
component,
392-
request.body(http_body_util::StreamBody::new(body_rx))?,
413+
let response = futures::join!(
414+
run_http(
415+
component,
416+
request.body(http_body_util::StreamBody::new(body_rx))?,
417+
),
418+
async {
419+
body_tx
420+
.send(Ok(http_body::Frame::data(raw_body)))
421+
.await
422+
.unwrap();
423+
body_tx
424+
.send(Ok(http_body::Frame::trailers({
425+
let mut trailers = http::HeaderMap::new();
426+
assert!(
427+
trailers
428+
.insert("fizz", http::HeaderValue::from_static("buzz"))
429+
.is_none()
430+
);
431+
trailers
432+
})))
433+
.await
434+
.unwrap();
435+
drop(body_tx);
436+
}
393437
)
394-
.await?
438+
.0?
395439
.unwrap();
396440
assert!(response.status().as_u16() == 200);
397441

0 commit comments

Comments
 (0)