diff --git a/src/http/response.rs b/src/http/response.rs index 0dcfe30..98a22fc 100644 --- a/src/http/response.rs +++ b/src/http/response.rs @@ -111,19 +111,24 @@ impl AsyncRead for IncomingBody { let buf = match &mut self.buf { Some(ref mut buf) => buf, None => { - // Wait for an event to be ready - let pollable = self.body_stream.subscribe(); - Reactor::current().wait_for(pollable).await; - - // Read the bytes from the body stream - let buf = match self.body_stream.read(CHUNK_SIZE) { - Ok(buf) => buf, - Err(StreamError::Closed) => return Ok(0), - Err(StreamError::LastOperationFailed(err)) => { - return Err(std::io::Error::other(format!( - "last operation failed: {}", - err.to_debug_string() - ))) + // workaround for unexpected stream break. https://github.com/bytecodealliance/wasmtime/issues/9667 + let reactor = Reactor::current(); + let buf = loop { + reactor.wait_for(self.body_stream.subscribe()).await; + match self.body_stream.read(CHUNK_SIZE) { + Ok(buf) => { + if buf.is_empty() { + continue; + } + break buf; + } + Err(StreamError::Closed) => return Ok(0), + Err(StreamError::LastOperationFailed(err)) => { + return Err(std::io::Error::other(format!( + "last operation failed: {}", + err.to_debug_string() + ))) + } } }; self.buf.insert(buf) diff --git a/src/net/tcp_stream.rs b/src/net/tcp_stream.rs index 5c6ab8a..a35ec48 100644 --- a/src/net/tcp_stream.rs +++ b/src/net/tcp_stream.rs @@ -30,11 +30,20 @@ impl TcpStream { impl AsyncRead for TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; - let slice = match self.input.read(buf.len() as u64) { - Ok(slice) => slice, - Err(StreamError::Closed) => return Ok(0), - Err(e) => return Err(to_io_err(e)), + // workaround for unexpected stream break. https://github.com/bytecodealliance/wasmtime/issues/9667 + let reactor = Reactor::current(); + let slice = loop { + reactor.wait_for(self.input.subscribe()).await; + match self.input.read(buf.len() as u64) { + Ok(slice) => { + if slice.is_empty() { + continue; + } + break slice; + } + Err(StreamError::Closed) => return Ok(0), + Err(e) => return Err(to_io_err(e)), + }; }; let bytes_read = slice.len(); buf[..bytes_read].clone_from_slice(&slice); @@ -44,11 +53,20 @@ impl AsyncRead for TcpStream { impl AsyncRead for &TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; - let slice = match self.input.read(buf.len() as u64) { - Ok(slice) => slice, - Err(StreamError::Closed) => return Ok(0), - Err(e) => return Err(to_io_err(e)), + // workaround for unexpected stream break. https://github.com/bytecodealliance/wasmtime/issues/9667 + let reactor = Reactor::current(); + let slice = loop { + reactor.wait_for(self.input.subscribe()).await; + match self.input.read(buf.len() as u64) { + Ok(slice) => { + if slice.is_empty() { + continue; + } + break slice; + } + Err(StreamError::Closed) => return Ok(0), + Err(e) => return Err(to_io_err(e)), + }; }; let bytes_read = slice.len(); buf[..bytes_read].clone_from_slice(&slice); diff --git a/test-programs/artifacts/tests/tcp_echo_server.rs b/test-programs/artifacts/tests/tcp_echo_server.rs index 92a16f4..c07125a 100644 --- a/test-programs/artifacts/tests/tcp_echo_server.rs +++ b/test-programs/artifacts/tests/tcp_echo_server.rs @@ -103,16 +103,18 @@ fn tcp_echo_server() -> Result<()> { const MESSAGE: &[u8] = b"hello, echoserver!\n"; - tcpstream.write_all(MESSAGE).context("write to socket")?; - println!("wrote to echo server"); + let n = 2; + for _ in 0..n { + tcpstream.write_all(MESSAGE).context("write to socket")?; + println!("wrote to echo server"); - let mut readback = Vec::new(); - tcpstream - .read_to_end(&mut readback) - .context("read from socket")?; + let mut buf = [0; 1024]; + let n = tcpstream.read(&mut buf).context("read from socket")?; + let readback = &buf[..n]; - println!("read from wasm server"); - assert_eq!(MESSAGE, readback); + println!("read from wasm server"); + assert_eq!(MESSAGE, readback); + } if wasmtime_thread.is_finished() { wasmtime_thread.join().expect("wasmtime panicked")?;