diff --git a/packages/fullstack/src/payloads/stream.rs b/packages/fullstack/src/payloads/stream.rs index 7ab59ffeb6..bd9fed3ca8 100644 --- a/packages/fullstack/src/payloads/stream.rs +++ b/packages/fullstack/src/payloads/stream.rs @@ -6,8 +6,9 @@ use crate::{ }; use axum::extract::{FromRequest, Request}; use axum_core::response::IntoResponse; -use bytes::Bytes; +use bytes::{Buf as _, Bytes}; use dioxus_fullstack_core::{HttpError, RequestError}; +use futures::stream::iter as iter_stream; use futures::{Stream, StreamExt}; #[cfg(feature = "server")] use futures_channel::mpsc::UnboundedSender; @@ -277,21 +278,13 @@ impl FromResponse for Streaming { } } -impl FromResponse +impl FromResponse for Streaming { fn from_response(res: ClientResponse) -> impl Future> { - SendWrapper::new(async move { - let client_stream = Box::pin(SendWrapper::new(res.bytes_stream().map( - |byte| match byte { - Ok(bytes) => match decode_stream_frame::(bytes) { - Some(res) => Ok(res), - None => Err(StreamingError::Decoding), - }, - Err(_) => Err(StreamingError::Failed), - }, - ))); + let client_stream = byte_stream_to_client_stream::(res.bytes_stream()); + SendWrapper::new(async move { Ok(Self { output_stream: client_stream, input_stream: Box::pin(futures::stream::empty()), @@ -370,7 +363,7 @@ impl FromRequest for ByteStream { } } -impl FromRequest +impl FromRequest for Streaming { type Rejection = ServerFnError; @@ -393,15 +386,10 @@ impl FromReque let stream = body.into_data_stream(); + let client_stream = byte_stream_to_client_stream::(stream); Ok(Self { input_stream: Box::pin(futures::stream::empty()), - output_stream: Box::pin(stream.map(|byte| match byte { - Ok(bytes) => match decode_stream_frame::(bytes) { - Some(res) => Ok(res), - None => Err(StreamingError::Decoding), - }, - Err(_) => Err(StreamingError::Failed), - })), + output_stream: client_stream, encoding: PhantomData, }) } @@ -514,21 +502,117 @@ pub fn encode_stream_frame(data: T) -> Option Some(Bytes::from(bytes).slice(offset..)) } +fn byte_stream_to_client_stream( + stream: S, +) -> Pin> + Send>> +where + S: Stream> + 'static, + E: Encoding + 'static, + T: DeserializeOwned + 'static, +{ + Box::pin(SendWrapper::new(stream.flat_map(|bytes| match bytes { + Ok(bytes) => iter_stream(DecodeIterator::::new(bytes)), + Err(_) => iter_stream(DecodeIterator::::failed()), + }))) +} + +enum DecodeIteratorState { + Empty, + Failed, + Checked(Bytes), + UnChecked(Bytes), +} + +/// An iterator of T decoded from bytes +/// that return an error if it is created empty +struct DecodeIterator(DecodeIteratorState, PhantomData<*const (T, E)>); + +impl DecodeIterator { + fn new(bytes: Bytes) -> Self { + DecodeIterator(DecodeIteratorState::UnChecked(bytes), PhantomData) + } + fn failed() -> Self { + DecodeIterator(DecodeIteratorState::Failed, PhantomData) + } +} + +impl Iterator for DecodeIterator +where + E: Encoding, + T: DeserializeOwned, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match std::mem::replace(&mut self.0, DecodeIteratorState::Empty) { + DecodeIteratorState::Empty => None, + DecodeIteratorState::Failed => Some(Err(StreamingError::Failed)), + DecodeIteratorState::Checked(mut bytes) => { + let r = decode_stream_frame_multi::(&mut bytes); + if r.is_some() { + self.0 = DecodeIteratorState::Checked(bytes) + } + r + } + DecodeIteratorState::UnChecked(mut bytes) => { + let r = decode_stream_frame_multi::(&mut bytes); + if r.is_some() { + self.0 = DecodeIteratorState::Checked(bytes); + r + } else { + Some(Err(StreamingError::Decoding)) + } + } + } + } +} + /// Decode a websocket-framed streaming payload produced by [`encode_stream_frame`]. /// /// This function returns `None` if the frame is invalid or cannot be decoded. /// /// It cannot handle masked frames, as those are not produced by our encoding function. -pub fn decode_stream_frame(frame: Bytes) -> Option +pub fn decode_stream_frame(mut frame: Bytes) -> Option +where + E: Encoding, + T: DeserializeOwned, +{ + decode_stream_frame_multi::(&mut frame).and_then(|r| r.ok()) +} + +/// Decode one value and advance the bytes pointer +/// +/// If the frame is empty return None. +/// +/// Otherwise, if the initial opcode is not the one expected for binary stream +/// or the frame is not large enough return error StreamingError::Decoding +fn decode_stream_frame_multi(frame: &mut Bytes) -> Option> where E: Encoding, T: DeserializeOwned, { + let (offset, payload_len) = match offset_payload_len(frame)? { + Ok(r) => r, + Err(e) => return Some(Err(e)), + }; + + let r = E::decode(frame.slice(offset..offset + payload_len)); + + frame.advance(offset + payload_len); + + r.map(|r| Ok(r)) +} + +/// Compute (offset,len) for decoding data +fn offset_payload_len(frame: &Bytes) -> Option> { let data = frame.as_ref(); - if data.len() < 2 { + if data.is_empty() { return None; } + if data.len() < 2 { + return Some(Err(StreamingError::Decoding)); + } let first = data[0]; let second = data[1]; @@ -538,12 +622,12 @@ where let opcode = first & 0x0F; let rsv = first & 0x70; if !fin || opcode != 0x02 || rsv != 0 { - return None; + return Some(Err(StreamingError::Decoding)); } // Mask bit must be zero for our framing if second & 0x80 != 0 { - return None; + return Some(Err(StreamingError::Decoding)); } let mut offset = 2usize; @@ -551,14 +635,14 @@ where if payload_len == 126 { if data.len() < offset + 2 { - return None; + return Some(Err(StreamingError::Decoding)); } payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize; offset += 2; } else if payload_len == 127 { if data.len() < offset + 8 { - return None; + return Some(Err(StreamingError::Decoding)); } let mut len_bytes = [0u8; 8]; @@ -566,7 +650,7 @@ where let len_u64 = u64::from_be_bytes(len_bytes); if len_u64 > usize::MAX as u64 { - return None; + return Some(Err(StreamingError::Decoding)); } payload_len = len_u64 as usize; @@ -574,8 +658,7 @@ where } if data.len() < offset + payload_len { - return None; + return Some(Err(StreamingError::Decoding)); } - - E::decode(frame.slice(offset..offset + payload_len)) + Some(Ok((offset, payload_len))) }