64 changes: 59 additions & 5 deletions crates/transport/src/frame/conn/mod.rs
@@ -304,6 +304,12 @@ pin_project! {
path: Arc<[usize]>,
index: Arc<std::sync::Mutex<IndexTrie>>,
io: Arc<JoinSet<()>>,
// Track whether we've successfully read any data from this stream.
// Used to detect when we've consumed all available data and should return EOF.
has_read_data: bool,
// Track consecutive Pending results. If we get Pending multiple times after having
// attempted to read, it indicates we're waiting for data that will never arrive.
pending_count: u32,
}
}

@@ -330,6 +336,8 @@ impl Index<Self> for Incoming {
path,
index: Arc::clone(&self.index),
io: Arc::clone(&self.io),
has_read_data: false,
pending_count: 0,
})
}
}
@@ -350,12 +358,56 @@ impl AsyncRead for Incoming {
trace!("reader is closed");
return Poll::Ready(Ok(()));
};
ready!(rx.poll_read(cx, buf))?;
trace!(buf = ?buf.filled(), "read buffer");
if buf.filled().is_empty() {
self.rx.take();

// Save the initial filled length to detect if we read any data
let initial_len = buf.filled().len();

// Try to read from the stream
match rx.poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let new_len = buf.filled().len();
let diff_len = new_len - initial_len;

// Track that we've successfully read at least one frame
if !*this.has_read_data {
*this.has_read_data = true;
trace!("marked has_read_data=true (consumed first frame from channel)");
}

// Reset pending count since we got Ready
*this.pending_count = 0;

trace!(buf = ?buf.filled(), diff_len, "read buffer");

// If we got an empty read (no data was added to the buffer),
// we consumed an empty frame, which is a clear signal that we've reached EOF,
// so close the receiver immediately to signal end of stream
if diff_len == 0 {
trace!("consumed empty frame, closing receiver");
self.as_mut().get_mut().rx.take();
}
Comment on lines +385 to +388 (Contributor Author):

note: the old code did

if buf.filled().is_empty() {
    self.rx.take();
}

but theoretically it's possible for `buf` to already be non-empty when this function is called, in which case an empty read would go undetected. I think this is a subtle bug, so I updated it to check the diff (i.e. whether this poll added any new data) instead; a minimal sketch of the distinction follows.
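A minimal sketch of that distinction, assuming a plain tokio AsyncRead source; the helper name poll_detect_empty_read is illustrative and not part of the crate:

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Illustrative helper: decide "empty read" from the number of bytes *this*
// poll added, not from the total filled length, which may already be
// non-empty when poll_read is entered.
fn poll_detect_empty_read<R: AsyncRead + Unpin>(
    rx: &mut R,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<bool>> {
    let before = buf.filled().len();
    match Pin::new(rx).poll_read(cx, buf) {
        Poll::Ready(Ok(())) => {
            // `true` means this poll produced no new bytes (an empty frame),
            // even if earlier polls had already filled part of `buf`.
            Poll::Ready(Ok(buf.filled().len() == before))
        }
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Pending => Poll::Pending,
    }
}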

Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
*this.pending_count += 1;
trace!(pending_count = *this.pending_count, "got Pending");

// Check two cases for our heuristic:
// 1. `this.has_read_data`: we received data at some point, but have no data left to read;
//    this could be a slow connection, but more likely there's just no more data coming
// 2. `this.pending_count > 1`: no matter how often we poll, we never get any data;
//    this could be a slow connection where no data has been received yet, but that's unlikely
if *this.has_read_data || *this.pending_count > 1 {
trace!("pending after consuming frames or multiple pending, returning UnexpectedEof");
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"end of parameter data: insufficient data for expected parameters",
)));
}
Poll::Pending
}
}
Poll::Ready(Ok(()))
}
}

@@ -589,6 +641,8 @@ impl Conn {
path: Arc::from([]),
index: Arc::clone(&index),
io: Arc::new(rx_io),
has_read_data: false,
pending_count: 0,
},
}
}
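For reference, the EOF heuristic that the Incoming::poll_read change above introduces can be read in isolation as an AsyncRead wrapper. The sketch below is illustrative only, assuming a plain tokio reader: the type name EofHeuristic and the error message are made up, and the receiver-closing behaviour from the real change is omitted.

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Illustrative wrapper: turn "Pending after we already consumed at least one
// frame" or "Pending more than once without ever receiving data" into an
// UnexpectedEof error instead of waiting forever.
struct EofHeuristic<R> {
    inner: R,
    has_read_data: bool,
    pending_count: u32,
}

impl<R: AsyncRead + Unpin> AsyncRead for EofHeuristic<R> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_read(cx, buf) {
            Poll::Ready(Ok(())) => {
                // Any successful poll counts as having read from the stream
                // and resets the consecutive-Pending counter.
                this.has_read_data = true;
                this.pending_count = 0;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => {
                this.pending_count += 1;
                if this.has_read_data || this.pending_count > 1 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "no further data expected on this stream",
                    )));
                }
                Poll::Pending
            }
        }
    }
}

The two conditions mirror the inline comments in the hunk: a stream that has already produced frames and now stalls, or one that stalls repeatedly before producing anything, is treated as finished rather than awaited indefinitely.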
219 changes: 219 additions & 0 deletions tests/deadlock.rs
Contributor Author:
note: I wasn't sure where to put this test, so I just put it in a standalone file at the root of tests so it's easy to move in case you want it in a specific place

@@ -0,0 +1,219 @@
#[cfg(test)]
mod parameter_validation_tests {
use core::pin::pin;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context as _;
use futures::StreamExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::join;
use tokio::sync::Mutex;
use tracing::instrument;
use wasmtime::component::{Component, Linker, ResourceTable};
use wasmtime::{Engine, Store};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
use wrpc_runtime_wasmtime::{ServeExt, SharedResourceTable, WrpcCtxView, WrpcView};
use wrpc_transport::frame::Oneshot;
use wrpc_transport::Invoke;

pub struct WrpcCtx<C: Invoke> {
pub wrpc: C,
pub cx: C::Context,
pub shared_resources: SharedResourceTable,
pub timeout: Duration,
}

pub struct Ctx<C: Invoke> {
pub table: ResourceTable,
pub wasi: WasiCtx,
pub wrpc: WrpcCtx<C>,
}

impl<C> wrpc_runtime_wasmtime::WrpcCtx<C> for WrpcCtx<C>
where
C: Invoke,
C::Context: Clone,
{
fn context(&self) -> C::Context {
self.cx.clone()
}

fn client(&self) -> &C {
&self.wrpc
}

fn shared_resources(&mut self) -> &mut SharedResourceTable {
&mut self.shared_resources
}

fn timeout(&self) -> Option<Duration> {
Some(self.timeout)
}
}

impl<C> WrpcView for Ctx<C>
where
C: Invoke,
C::Context: Clone,
{
type Invoke = C;

fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
WrpcCtxView {
ctx: &mut self.wrpc,
table: &mut self.table,
}
}
}

impl<C> WasiView for Ctx<C>
where
C: Invoke,
{
fn ctx(&mut self) -> WasiCtxView<'_> {
WasiCtxView {
ctx: &mut self.wasi,
table: &mut self.table,
}
}
}

pub fn gen_ctx<C: Invoke>(wrpc: C, cx: C::Context) -> Ctx<C> {
Ctx {
table: ResourceTable::new(),
wasi: WasiCtxBuilder::new().build(),
wrpc: WrpcCtx {
wrpc,
cx,
shared_resources: SharedResourceTable::default(),
timeout: Duration::from_secs(10),
},
}
}

/// Test that insufficient parameters cause an error instead of deadlock
///
/// This test verifies the fix for the DoS vulnerability where malformed
/// parameter data causes the server to hang indefinitely.
///
/// IMPORTANT: This test uses serve_function_shared (runtime-wasmtime layer)
/// with raw invoke(), which is where the bug manifests. Using serve_values
/// (transport layer) with invoke_values_blocking will NOT reproduce the bug.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
#[instrument(ret)]
async fn test_parameter_validation_deadlock() -> anyhow::Result<()> {
let params = wit_bindgen_wrpc::bytes::Bytes::new(); // Empty params

let (oneshot_clt, oneshot_srv) = Oneshot::duplex(1024);
let srv = Arc::new(wrpc_transport::frame::Server::default());

let mut config = wasmtime::Config::default();
config.async_support(true);
let engine = Engine::new(&config).context("failed to create engine with async support")?;

let component_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("simple-args.wasm");
let component_bytes = std::fs::read(&component_path)
.with_context(|| format!("failed to read component from {:?}", component_path))?;

let component =
Component::new(&engine, component_bytes).context("failed to parse component")?;

let linker = Linker::new(&engine);
let mut store = Store::new(&engine, gen_ctx(oneshot_clt, ()));

let instance = linker
.instantiate_async(&mut store, &component)
.await
.context("failed to instantiate component")?;


let function = "get-value";
let func = instance
.get_func(&mut store, &function)
.ok_or_else(|| anyhow::anyhow!("function `{function}` not found in component"))?;

let fun_ty = func.ty(&store);

let guest_resources_vec = Vec::new();
let host_resources = HashMap::new();

let instance_name = "".to_string();

let store_shared = Arc::new(Mutex::new(store));
let store_shared_clone = store_shared.clone();
let invocations_stream = srv
.serve_function_shared(
store_shared,
instance,
Arc::from(guest_resources_vec.into_boxed_slice()),
Arc::from(host_resources),
fun_ty,
&instance_name,
&function,
)
.await
.with_context(|| {
format!("failed to register handler for function `{function}`")
})?;
let (result, invocation_handle) = join!(
// client side
async move {
let paths: &[&[Option<usize>]] = &[];
// Lock the store only to get the wrpc client and invoke
// Release the lock immediately after getting the streams
let (mut outgoing, mut incoming) = {
let store = store_shared_clone.lock().await;
store.data().wrpc.wrpc
.invoke((), &instance_name, &function, params, paths)
.await
.expect(&format!("failed to invoke {}", function))
};
// Lock is now released, allowing server to process the invocation
outgoing.flush().await?;
let mut buf = vec![];
incoming
.read_to_end(&mut buf)
.await
.with_context(|| format!("failed to read result for root function `{function}`"))?;
Ok::<Vec<u8>, anyhow::Error>(buf)
},
// server side
async move {
srv.accept(oneshot_srv)
.await
.expect("failed to accept connection");

tokio::spawn(async move {
let mut invocations = pin!(invocations_stream);
while let Some(invocation) = invocations.as_mut().next().await {
match invocation {
Ok((_, fut)) => {
if let Err(err) = fut.await {
eprintln!("failed to serve invocation for root function `{function}`: {err:?}");
}
}
Err(err) => {
eprintln!("failed to accept invocation for root function `{function}`: {err:?}");
}
}
}
})
}
);
// Clean up the invocation handle since the oneshot connection is complete
// The stream should naturally end, but we abort to ensure cleanup happens immediately
invocation_handle.abort();
println!("result: {:?}", result);
Ok(())
}
}
Binary file added tests/simple-args.wasm