64 changes: 59 additions & 5 deletions crates/transport/src/frame/conn/mod.rs
@@ -304,6 +304,12 @@ pin_project! {
path: Arc<[usize]>,
index: Arc<std::sync::Mutex<IndexTrie>>,
io: Arc<JoinSet<()>>,
// Track whether we've successfully read any data from this stream.
// Used to detect when we've consumed all available data and should return EOF.
has_read_data: bool,
// Track consecutive Pending results. If we get Pending multiple times after having
// attempted to read, it indicates we're waiting for data that will never arrive.
Comment on lines +310 to +311 (Member):

I don't think we can make this assumption and there are a few reasons for it:

  • Various async executors are free to poll futures as many times as they want until they return Ready. That can be done, in fact, as a valid optimization - polling a future once with a noop task context to short-circuit the case when it's ready, then constructing a "real" task context and polling again to register the task for wake up.
  • The underlying I/O stream is allowed to wake up spuriously. For example, see https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.poll_recv . Returning Poll::Pending 5 times and then returning Poll::Ready would be a perfectly valid behavior, which would be disallowed by this PR.

I don't think counting the number of times Pending was returned is a valid approach in the general case.
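
For illustration, here is a minimal standalone sketch (hypothetical type, not part of this PR) of a future that legally returns Pending several times before resolving; a reader that errors out after a fixed number of Pending results would reject it:

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

// Hypothetical future that returns Pending a few times before resolving.
// This is valid under the Future contract: callers must keep polling until
// Ready and may not infer EOF from repeated Pending results.
struct SpuriouslyPending {
    remaining: u32,
}

impl Future for SpuriouslyPending {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.remaining > 0 {
            self.remaining -= 1;
            // Request another poll; a real stream would wake us when data
            // arrives, an immediate wake here models a spurious wakeup.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}

#[tokio::main]
async fn main() {
    // Completes despite five Pending results in a row; a pending counter
    // in the reader would instead surface a spurious UnexpectedEof here.
    SpuriouslyPending { remaining: 5 }.await;
}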

pending_count: u32,
}
}

@@ -330,6 +336,8 @@ impl Index<Self> for Incoming {
path,
index: Arc::clone(&self.index),
io: Arc::clone(&self.io),
has_read_data: false,
pending_count: 0,
})
}
}
@@ -350,12 +358,56 @@ impl AsyncRead for Incoming {
trace!("reader is closed");
return Poll::Ready(Ok(()));
};
ready!(rx.poll_read(cx, buf))?;
trace!(buf = ?buf.filled(), "read buffer");
if buf.filled().is_empty() {
self.rx.take();

// Save the initial filled length to detect if we read any data
let initial_len = buf.filled().len();

// Try to read from the stream
match rx.poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let new_len = buf.filled().len();
let diff_len = new_len - initial_len;

// Track that we've successfully read at least one frame
if !*this.has_read_data {
*this.has_read_data = true;
trace!("marked has_read_data=true (consumed first frame from channel)");
}

// Reset pending count since we got Ready
*this.pending_count = 0;

trace!(buf = ?buf.filled(), diff_len, "read buffer");

// If we got an empty read (no data added to buffer),
// it means we consumed an empty frame (clear signal that we've reached EOF)
// close the receiver immediately to signal end of stream
if diff_len == 0 {
trace!("consumed empty frame, closing receiver");
self.as_mut().get_mut().rx.take();
}
Comment on lines +385 to +388 (Contributor Author):

note: the old code did

if buf.filled().is_empty() {
    self.rx.take();
}

but theoretically it's possible for `buf` not to be empty when this function is called. I think this is a subtle bug, so I updated it to check the diff (i.e. whether we got new data) instead.
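
As a small standalone illustration of that case (hypothetical, exercising tokio's ReadBuf directly rather than this transport), consider a buffer that already holds data from an earlier read when poll_read is called and the inner stream adds nothing:

use tokio::io::ReadBuf;

fn main() {
    let mut storage = [0u8; 16];
    let mut buf = ReadBuf::new(&mut storage);

    // Data left over from a previous read into the same buffer.
    buf.put_slice(b"earlier");
    let before = buf.filled().len();

    // Suppose the inner poll_read adds nothing on this call.
    let after = buf.filled().len();

    // We read zero bytes this time, but the old `buf.filled().is_empty()`
    // check would not fire because the buffer is not empty overall.
    assert_eq!(after - before, 0);
    assert!(!buf.filled().is_empty());
}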

Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
*this.pending_count += 1;
trace!(pending_count = *this.pending_count, "got Pending");

// Check two cases for our heuristic:
// 1. `this.has_read_data`: we received data at some point, but have no data left to read
// this could be a slow connection - but also likely that there's just no more data coming
// 2. `this.pending_count > 1`: no matter how much we poll, we never get data
// this could be a slow connection where no data has received yet - but likely not
if *this.has_read_data || *this.pending_count > 1 {
trace!("pending after consuming frames or multiple pending, returning UnexpectedEof");
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"end of parameter data: insufficient data for expected parameters",
)));
}
Poll::Pending
}
}
Poll::Ready(Ok(()))
}
}

@@ -589,6 +641,8 @@ impl Conn {
path: Arc::from([]),
index: Arc::clone(&index),
io: Arc::new(rx_io),
has_read_data: false,
pending_count: 0,
},
}
}
219 changes: 219 additions & 0 deletions tests/deadlock.rs
Comment (Contributor Author):

note: I wasn't sure where to put this test, so I just put it in a standalone file at the root of tests so it's easy to move in case you want it in a specific place

@@ -0,0 +1,219 @@
#[cfg(test)]
mod parameter_validation_tests {
use core::pin::pin;
use std::collections::HashMap;
use anyhow::Context as _;
use futures::StreamExt;
use tokio::join;
use tracing::instrument;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use wasmtime::component::ResourceTable;
use wasmtime::component::{Component, Linker};
use wasmtime::Engine;
use wasmtime::Store;
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
use wrpc_runtime_wasmtime::{
ServeExt,
SharedResourceTable, WrpcCtxView, WrpcView,
};
use wrpc_transport::frame::Oneshot;
use wrpc_transport::Invoke;
use tokio::io::AsyncReadExt;

pub struct WrpcCtx<C: Invoke> {
pub wrpc: C,
pub cx: C::Context,
pub shared_resources: SharedResourceTable,
pub timeout: Duration,
}

pub struct Ctx<C: Invoke> {
pub table: ResourceTable,
pub wasi: WasiCtx,
pub wrpc: WrpcCtx<C>,
}

impl<C> wrpc_runtime_wasmtime::WrpcCtx<C> for WrpcCtx<C>
where
C: Invoke,
C::Context: Clone,
{
fn context(&self) -> C::Context {
self.cx.clone()
}

fn client(&self) -> &C {
&self.wrpc
}

fn shared_resources(&mut self) -> &mut SharedResourceTable {
&mut self.shared_resources
}

fn timeout(&self) -> Option<Duration> {
Some(self.timeout)
}
}

impl<C> WrpcView for Ctx<C>
where
C: Invoke,
C::Context: Clone,
{
type Invoke = C;

fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
WrpcCtxView {
ctx: &mut self.wrpc,
table: &mut self.table,
}
}
}

impl<C> WasiView for Ctx<C>
where
C: Invoke,
{
fn ctx(&mut self) -> WasiCtxView<'_> {
WasiCtxView {
ctx: &mut self.wasi,
table: &mut self.table,
}
}
}

pub fn gen_ctx<C: Invoke>(wrpc: C, cx: C::Context) -> Ctx<C> {
Ctx {
table: ResourceTable::new(),
wasi: WasiCtxBuilder::new().build(),
wrpc: WrpcCtx {
wrpc,
cx,
shared_resources: SharedResourceTable::default(),
timeout: Duration::from_secs(10),
},
}
}

/// Test that insufficient parameters cause an error instead of deadlock
///
/// This test verifies the fix for the DoS vulnerability where malformed
/// parameter data causes the server to hang indefinitely.
///
/// IMPORTANT: This test uses serve_function_shared (runtime-wasmtime layer)
/// with raw invoke(), which is where the bug manifests. Using serve_values
/// (transport layer) with invoke_values_blocking will NOT reproduce the bug.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
#[instrument(ret)]
async fn test_parameter_validation_deadlock() -> anyhow::Result<()> {
let params = wit_bindgen_wrpc::bytes::Bytes::new(); // Empty params

let (oneshot_clt, oneshot_srv) = Oneshot::duplex(1024);
let srv = Arc::new(wrpc_transport::frame::Server::default());

let mut config = wasmtime::Config::default();
config.async_support(true);
let engine = Engine::new(&config).context("failed to create engine with async support")?;

let component_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("simple-args.wasm");
let component_bytes = std::fs::read(&component_path)
.with_context(|| format!("failed to read component from {:?}", component_path))?;

let component =
Component::new(&engine, component_bytes).context("failed to parse component")?;

let linker = Linker::new(&engine);
let mut store = Store::new(&engine, gen_ctx(oneshot_clt, ()));

let instance = linker
.instantiate_async(&mut store, &component)
.await
.context("failed to instantiate component")?;


let function = "get-value";
let func = instance
.get_func(&mut store, &function)
.ok_or_else(|| anyhow::anyhow!("function `{function}` not found in component"))?;

let fun_ty = func.ty(&store);

let guest_resources_vec = Vec::new();
let host_resources = HashMap::new();

let instance_name = "".to_string();

let store_shared = Arc::new(Mutex::new(store));
let store_shared_clone = store_shared.clone();
let invocations_stream = srv
.serve_function_shared(
store_shared,
instance,
Arc::from(guest_resources_vec.into_boxed_slice()),
Arc::from(host_resources),
fun_ty,
&instance_name,
&function,
)
.await
.with_context(|| {
format!("failed to register handler for function `{function}`")
})?;
let (result, invocation_handle) = join!(
// client side
async move {
let paths: &[&[Option<usize>]] = &[];
// Lock the store only to get the wrpc client and invoke
// Release the lock immediately after getting the streams
let (mut outgoing, mut incoming) = {
let store = store_shared_clone.lock().await;
store.data().wrpc.wrpc
.invoke((), &instance_name, &function, params, paths)
.await
.expect(&format!("failed to invoke {}", function))
};
// Lock is now released, allowing server to process the invocation
outgoing.flush().await?;
let mut buf = vec![];
Comment on lines +182 to +183 (Member):

Suggested change (replace):
    outgoing.flush().await?;
    let mut buf = vec![];
with:
    outgoing.flush().await?;
    drop(outgoing);
    let mut buf = vec![];

I don't think this usage is valid. By not sending the parameters in the initial request and keeping the stream open, the client signals that it will send parameters asynchronously. This is a fundamental part of the design, which allows us to implement e.g. future and stream types.

Closing the stream as expected (by dropping it, for example) does not cause a deadlock on current main.
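
As a rough sketch of the two client behaviors this design allows (generic over any AsyncWrite writer, not tied to the wrpc types used in this test): dropping the parameter stream signals that no more parameter data will follow, while keeping it open signals that more will arrive asynchronously:

use tokio::io::{AsyncWrite, AsyncWriteExt};

// Send everything up front and close the parameter stream; the server sees
// EOF on the stream and knows no more parameter data is coming.
async fn send_all_params<W: AsyncWrite + Unpin>(
    mut outgoing: W,
    params: &[u8],
) -> std::io::Result<()> {
    outgoing.write_all(params).await?;
    outgoing.flush().await?;
    drop(outgoing);
    Ok(())
}

// Send an initial chunk but keep the writer alive; holding the stream open
// is the signal that more parameter data (e.g. stream/future payloads) will
// follow asynchronously, so the server must keep waiting rather than error.
async fn stream_params_later<W: AsyncWrite + Unpin>(
    mut outgoing: W,
    first_chunk: &[u8],
) -> std::io::Result<W> {
    outgoing.write_all(first_chunk).await?;
    outgoing.flush().await?;
    Ok(outgoing)
}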

incoming
.read_to_end(&mut buf)
.await
.with_context(|| format!("failed to read result for root function `{function}`"))?;
Ok::<Vec<u8>, anyhow::Error>(buf)
},
// server side
async move {
srv.accept(oneshot_srv)
.await
.expect("failed to accept connection");

tokio::spawn(async move {
let mut invocations = pin!(invocations_stream);
while let Some(invocation) = invocations.as_mut().next().await {
match invocation {
Ok((_, fut)) => {
if let Err(err) = fut.await {
eprintln!("failed to serve invocation for root function `{function}`: {err:?}");
}
}
Err(err) => {
eprintln!("failed to accept invocation for root function `{function}`: {err:?}");
}
}
}
})
}
);
// Clean up the invocation handle since the oneshot connection is complete
// The stream should naturally end, but we abort to ensure cleanup happens immediately
invocation_handle.abort();
println!("result: {:?}", result);
Ok(())
}
}
Binary file added tests/simple-args.wasm