Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 113 additions & 30 deletions packages/fullstack/src/payloads/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use crate::{
};
use axum::extract::{FromRequest, Request};
use axum_core::response::IntoResponse;
use bytes::Bytes;
use bytes::{Buf as _, Bytes};
use dioxus_fullstack_core::{HttpError, RequestError};
use futures::stream::iter as iter_stream;
use futures::{Stream, StreamExt};
#[cfg(feature = "server")]
use futures_channel::mpsc::UnboundedSender;
Expand Down Expand Up @@ -277,21 +278,13 @@ impl FromResponse for Streaming<Bytes> {
}
}

impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> FromResponse
impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding + 'static> FromResponse
for Streaming<T, E>
{
fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
SendWrapper::new(async move {
let client_stream = Box::pin(SendWrapper::new(res.bytes_stream().map(
|byte| match byte {
Ok(bytes) => match decode_stream_frame::<T, E>(bytes) {
Some(res) => Ok(res),
None => Err(StreamingError::Decoding),
},
Err(_) => Err(StreamingError::Failed),
},
)));
let client_stream = byte_stream_to_client_stream::<E, _, _, _>(res.bytes_stream());

SendWrapper::new(async move {
Ok(Self {
output_stream: client_stream,
input_stream: Box::pin(futures::stream::empty()),
Expand Down Expand Up @@ -370,7 +363,7 @@ impl<S> FromRequest<S> for ByteStream {
}
}

impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding, S> FromRequest<S>
impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding + 'static, S> FromRequest<S>
for Streaming<T, E>
{
type Rejection = ServerFnError;
Expand All @@ -393,15 +386,10 @@ impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding, S> FromReque

let stream = body.into_data_stream();

let client_stream = byte_stream_to_client_stream::<E, _, _, _>(stream);
Ok(Self {
input_stream: Box::pin(futures::stream::empty()),
output_stream: Box::pin(stream.map(|byte| match byte {
Ok(bytes) => match decode_stream_frame::<T, E>(bytes) {
Some(res) => Ok(res),
None => Err(StreamingError::Decoding),
},
Err(_) => Err(StreamingError::Failed),
})),
output_stream: client_stream,
encoding: PhantomData,
})
}
Expand Down Expand Up @@ -514,21 +502,117 @@ pub fn encode_stream_frame<T: Serialize, E: Encoding>(data: T) -> Option<Bytes>
Some(Bytes::from(bytes).slice(offset..))
}

fn byte_stream_to_client_stream<E, T, S, E1>(
stream: S,
) -> Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>
where
S: Stream<Item = Result<Bytes, E1>> + 'static,
E: Encoding + 'static,
T: DeserializeOwned + 'static,
{
Box::pin(SendWrapper::new(stream.flat_map(|bytes| match bytes {
Ok(bytes) => iter_stream(DecodeIterator::<T, E>::new(bytes)),
Err(_) => iter_stream(DecodeIterator::<T, E>::failed()),
})))
}

enum DecodeIteratorState {
Empty,
Failed,
Checked(Bytes),
UnChecked(Bytes),
}

/// An iterator of T decoded from bytes
/// that return an error if it is created empty
struct DecodeIterator<T, E>(DecodeIteratorState, PhantomData<*const (T, E)>);

impl<T, E> DecodeIterator<T, E> {
fn new(bytes: Bytes) -> Self {
DecodeIterator(DecodeIteratorState::UnChecked(bytes), PhantomData)
}
fn failed() -> Self {
DecodeIterator(DecodeIteratorState::Failed, PhantomData)
}
}

impl<T, E> Iterator for DecodeIterator<T, E>
where
E: Encoding,
T: DeserializeOwned,
{
type Item = Result<T, StreamingError>;

fn next(&mut self) -> Option<Self::Item> {
match std::mem::replace(&mut self.0, DecodeIteratorState::Empty) {
DecodeIteratorState::Empty => None,
DecodeIteratorState::Failed => Some(Err(StreamingError::Failed)),
DecodeIteratorState::Checked(mut bytes) => {
let r = decode_stream_frame_multi::<T, E>(&mut bytes);
if r.is_some() {
self.0 = DecodeIteratorState::Checked(bytes)
}
r
}
DecodeIteratorState::UnChecked(mut bytes) => {
let r = decode_stream_frame_multi::<T, E>(&mut bytes);
if r.is_some() {
self.0 = DecodeIteratorState::Checked(bytes);
r
} else {
Some(Err(StreamingError::Decoding))
}
}
}
}
}

/// Decode a websocket-framed streaming payload produced by [`encode_stream_frame`].
///
/// This function returns `None` if the frame is invalid or cannot be decoded.
///
/// It cannot handle masked frames, as those are not produced by our encoding function.
pub fn decode_stream_frame<T, E>(frame: Bytes) -> Option<T>
pub fn decode_stream_frame<T, E>(mut frame: Bytes) -> Option<T>
where
E: Encoding,
T: DeserializeOwned,
{
decode_stream_frame_multi::<T, E>(&mut frame).and_then(|r| r.ok())
}

/// Decode one value and advance the bytes pointer
///
/// If the frame is empty return None.
///
/// Otherwise, if the initial opcode is not the one expected for binary stream
/// or the frame is not large enough return error StreamingError::Decoding
fn decode_stream_frame_multi<T, E>(frame: &mut Bytes) -> Option<Result<T, StreamingError>>
where
E: Encoding,
T: DeserializeOwned,
{
let (offset, payload_len) = match offset_payload_len(frame)? {
Ok(r) => r,
Err(e) => return Some(Err(e)),
};

let r = E::decode(frame.slice(offset..offset + payload_len));

frame.advance(offset + payload_len);

r.map(|r| Ok(r))
}

/// Compute (offset,len) for decoding data
fn offset_payload_len(frame: &Bytes) -> Option<Result<(usize, usize), StreamingError>> {
let data = frame.as_ref();

if data.len() < 2 {
if data.is_empty() {
return None;
}
if data.len() < 2 {
return Some(Err(StreamingError::Decoding));
}

let first = data[0];
let second = data[1];
Expand All @@ -538,44 +622,43 @@ where
let opcode = first & 0x0F;
let rsv = first & 0x70;
if !fin || opcode != 0x02 || rsv != 0 {
return None;
return Some(Err(StreamingError::Decoding));
}

// Mask bit must be zero for our framing
if second & 0x80 != 0 {
return None;
return Some(Err(StreamingError::Decoding));
}

let mut offset = 2usize;
let mut payload_len = (second & 0x7F) as usize;

if payload_len == 126 {
if data.len() < offset + 2 {
return None;
return Some(Err(StreamingError::Decoding));
}

payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
offset += 2;
} else if payload_len == 127 {
if data.len() < offset + 8 {
return None;
return Some(Err(StreamingError::Decoding));
}

let mut len_bytes = [0u8; 8];
len_bytes.copy_from_slice(&data[offset..offset + 8]);
let len_u64 = u64::from_be_bytes(len_bytes);

if len_u64 > usize::MAX as u64 {
return None;
return Some(Err(StreamingError::Decoding));
}

payload_len = len_u64 as usize;
offset += 8;
}

if data.len() < offset + payload_len {
return None;
return Some(Err(StreamingError::Decoding));
}

E::decode(frame.slice(offset..offset + payload_len))
Some(Ok((offset, payload_len)))
}