diff --git a/crates/async-compression/src/futures/write/buf_writer.rs b/crates/async-compression/src/futures/write/buf_writer.rs index 13f23f16..847437ff 100644 --- a/crates/async-compression/src/futures/write/buf_writer.rs +++ b/crates/async-compression/src/futures/write/buf_writer.rs @@ -2,218 +2,15 @@ // the `AsyncBufWrite` impl can access its internals, and changed a bit to make it more efficient // with those methods. -use super::AsyncBufWrite; -use futures_core::ready; +use crate::generic::write::impl_buf_writer; use futures_io::{AsyncSeek, AsyncWrite, SeekFrom}; -use pin_project_lite::pin_project; use std::{ - cmp::min, - fmt, io, + io, pin::Pin, task::{Context, Poll}, }; -const DEFAULT_BUF_SIZE: usize = 8192; - -pin_project! { - pub struct BufWriter { - #[pin] - inner: W, - buf: Box<[u8]>, - written: usize, - buffered: usize, - } -} - -impl BufWriter { - /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, - /// but may change in the future. - pub fn new(inner: W) -> Self { - Self::with_capacity(DEFAULT_BUF_SIZE, inner) - } - - /// Creates a new `BufWriter` with the specified buffer capacity. - pub fn with_capacity(cap: usize, inner: W) -> Self { - Self { - inner, - buf: vec![0; cap].into(), - written: 0, - buffered: 0, - } - } - - fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - let mut ret = Ok(()); - while *this.written < *this.buffered { - match this - .inner - .as_mut() - .poll_write(cx, &this.buf[*this.written..*this.buffered]) - { - Poll::Pending => { - break; - } - Poll::Ready(Ok(0)) => { - ret = Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write the buffered data", - )); - break; - } - Poll::Ready(Ok(n)) => *this.written += n, - Poll::Ready(Err(e)) => { - ret = Err(e); - break; - } - } - } - - if *this.written > 0 { - this.buf.copy_within(*this.written..*this.buffered, 0); - *this.buffered -= *this.written; - *this.written = 0; - - Poll::Ready(ret) - } else if *this.buffered == 0 { - Poll::Ready(ret) - } else { - ret?; - Poll::Pending - } - } - - fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - let mut ret = Ok(()); - while *this.written < *this.buffered { - match ready!(this - .inner - .as_mut() - .poll_write(cx, &this.buf[*this.written..*this.buffered])) - { - Ok(0) => { - ret = Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write the buffered data", - )); - break; - } - Ok(n) => *this.written += n, - Err(e) => { - ret = Err(e); - break; - } - } - } - this.buf.copy_within(*this.written..*this.buffered, 0); - *this.buffered -= *this.written; - *this.written = 0; - Poll::Ready(ret) - } -} - -impl BufWriter { - /// Gets a reference to the underlying writer. - pub fn get_ref(&self) -> &W { - &self.inner - } - - /// Gets a mutable reference to the underlying writer. - /// - /// It is inadvisable to directly write to the underlying writer. - pub fn get_mut(&mut self) -> &mut W { - &mut self.inner - } - - /// Gets a pinned mutable reference to the underlying writer. - /// - /// It is inadvisable to directly write to the underlying writer. - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { - self.project().inner - } - - /// Consumes this `BufWriter`, returning the underlying writer. - /// - /// Note that any leftover data in the internal buffer is lost. - pub fn into_inner(self) -> W { - self.inner - } -} - -impl AsyncWrite for BufWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.as_mut().project(); - if *this.buffered + buf.len() > this.buf.len() { - ready!(self.as_mut().partial_flush_buf(cx))?; - } - - let this = self.as_mut().project(); - if buf.len() >= this.buf.len() { - if *this.buffered == 0 { - this.inner.poll_write(cx, buf) - } else { - // The only way that `partial_flush_buf` would have returned with - // `this.buffered != 0` is if it were Pending, so our waker was already queued - Poll::Pending - } - } else { - let len = min(this.buf.len() - *this.buffered, buf.len()); - this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]); - *this.buffered += len; - Poll::Ready(Ok(len)) - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().flush_buf(cx))?; - self.project().inner.poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().flush_buf(cx))?; - self.project().inner.poll_close(cx) - } -} - -impl AsyncBufWrite for BufWriter { - fn poll_partial_flush_buf( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - ready!(self.as_mut().partial_flush_buf(cx))?; - let this = self.project(); - Poll::Ready(Ok(&mut this.buf[*this.buffered..])) - } - - fn produce(self: Pin<&mut Self>, amt: usize) { - let this = self.project(); - debug_assert!( - *this.buffered + amt <= this.buf.len(), - "produce called with amt exceeding buffer capacity" - ); - *this.buffered += amt; - } -} - -impl fmt::Debug for BufWriter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BufWriter") - .field("writer", &self.inner) - .field( - "buffer", - &format_args!("{}/{}", self.buffered, self.buf.len()), - ) - .field("written", &self.written) - .finish() - } -} +impl_buf_writer!(poll_close); impl AsyncSeek for BufWriter { /// Seek to the offset, in bytes, in the underlying writer. @@ -225,6 +22,6 @@ impl AsyncSeek for BufWriter { pos: SeekFrom, ) -> Poll> { ready!(self.as_mut().flush_buf(cx))?; - self.project().inner.poll_seek(cx, pos) + self.project().writer.poll_seek(cx, pos) } } diff --git a/crates/async-compression/src/futures/write/generic/decoder.rs b/crates/async-compression/src/futures/write/generic/decoder.rs index 4a099f01..8657ab63 100644 --- a/crates/async-compression/src/futures/write/generic/decoder.rs +++ b/crates/async-compression/src/futures/write/generic/decoder.rs @@ -1,184 +1,12 @@ -use crate::codecs::Decode; -use crate::core::util::PartialBuffer; -use crate::futures::write::{AsyncBufWrite, BufWriter}; -use futures_core::ready; +use crate::{futures::write::BufWriter, generic::write::impl_decoder}; use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, IoSliceMut}; -use pin_project_lite::pin_project; use std::{ io, pin::Pin, task::{Context, Poll}, }; -#[derive(Debug)] -enum State { - Decoding, - Finishing, - Done, -} - -pin_project! { - #[derive(Debug)] - pub struct Decoder { - #[pin] - writer: BufWriter, - decoder: D, - state: State, - } -} - -impl Decoder { - pub fn new(writer: W, decoder: D) -> Self { - Self { - writer: BufWriter::new(writer), - decoder, - state: State::Decoding, - } - } -} - -impl Decoder { - pub fn get_ref(&self) -> &W { - self.writer.get_ref() - } - - pub fn get_mut(&mut self) -> &mut W { - self.writer.get_mut() - } - - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { - self.project().writer.get_pin_mut() - } - - pub fn into_inner(self) -> W { - self.writer.into_inner() - } -} - -impl Decoder { - fn do_poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - input: &mut PartialBuffer<&[u8]>, - ) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Decoding => { - if this.decoder.decode(input, &mut output)? { - State::Finishing - } else { - State::Decoding - } - } - - State::Finishing => { - if this.decoder.finish(&mut output)? { - State::Done - } else { - State::Finishing - } - } - - State::Done => { - return Poll::Ready(Err(io::Error::other("Write after end of stream"))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if let State::Done = this.state { - return Poll::Ready(Ok(())); - } - - if input.unwritten().is_empty() { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - let (state, done) = match this.state { - State::Decoding => { - let done = this.decoder.flush(&mut output)?; - (State::Decoding, done) - } - - State::Finishing => { - if this.decoder.finish(&mut output)? { - (State::Done, false) - } else { - (State::Finishing, false) - } - } - - State::Done => (State::Done, true), - }; - - *this.state = state; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if done { - return Poll::Ready(Ok(())); - } - } - } -} - -impl AsyncWrite for Decoder { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - let mut input = PartialBuffer::new(buf); - - match self.do_poll_write(cx, &mut input)? { - Poll::Pending if input.written().is_empty() => Poll::Pending, - _ => Poll::Ready(Ok(input.written().len())), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_flush(cx))?; - ready!(self.project().writer.as_mut().poll_flush(cx))?; - Poll::Ready(Ok(())) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let State::Decoding = self.as_mut().project().state { - *self.as_mut().project().state = State::Finishing; - } - - ready!(self.as_mut().do_poll_flush(cx))?; - - if let State::Done = self.as_mut().project().state { - ready!(self.as_mut().project().writer.as_mut().poll_close(cx))?; - Poll::Ready(Ok(())) - } else { - Poll::Ready(Err(io::Error::other( - "Attempt to close before finishing input", - ))) - } - } -} +impl_decoder!(poll_close); impl AsyncRead for Decoder { fn poll_read( @@ -197,13 +25,3 @@ impl AsyncRead for Decoder { self.get_pin_mut().poll_read_vectored(cx, bufs) } } - -impl AsyncBufRead for Decoder { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_pin_mut().poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.get_pin_mut().consume(amt) - } -} diff --git a/crates/async-compression/src/futures/write/mod.rs b/crates/async-compression/src/futures/write/mod.rs index 831328d6..f312aaaa 100644 --- a/crates/async-compression/src/futures/write/mod.rs +++ b/crates/async-compression/src/futures/write/mod.rs @@ -5,13 +5,12 @@ mod macros; mod generic; -mod buf_write; mod buf_writer; use self::{ - buf_write::AsyncBufWrite, buf_writer::BufWriter, generic::{Decoder, Encoder}, }; +use crate::generic::write::AsyncBufWrite; algos!(futures::write); diff --git a/crates/async-compression/src/generic/mod.rs b/crates/async-compression/src/generic/mod.rs index b480f00f..2f5140cc 100644 --- a/crates/async-compression/src/generic/mod.rs +++ b/crates/async-compression/src/generic/mod.rs @@ -1 +1,2 @@ pub(crate) mod bufread; +pub(crate) mod write; diff --git a/crates/async-compression/src/futures/write/buf_write.rs b/crates/async-compression/src/generic/write/buf_write.rs similarity index 100% rename from crates/async-compression/src/futures/write/buf_write.rs rename to crates/async-compression/src/generic/write/buf_write.rs diff --git a/crates/async-compression/src/generic/write/buf_writer.rs b/crates/async-compression/src/generic/write/buf_writer.rs new file mode 100644 index 00000000..3dc05591 --- /dev/null +++ b/crates/async-compression/src/generic/write/buf_writer.rs @@ -0,0 +1,260 @@ +// Originally sourced from `futures_util::io::buf_writer`, needs to be redefined locally so that +// the `AsyncBufWrite` impl can access its internals, and changed a bit to make it more efficient +// with those methods. + +use super::AsyncBufWrite; +use futures_core::ready; +use std::{ + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; + +const DEFAULT_BUF_SIZE: usize = 8192; + +pub struct BufWriter { + buf: Box<[u8]>, + written: usize, + buffered: usize, +} + +impl fmt::Debug for BufWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GenericBufWriter") + .field( + "buffer", + &format_args!("{}/{}", self.buffered, self.buf.len()), + ) + .field("written", &self.written) + .finish() + } +} + +impl BufWriter { + /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new() -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE) + } + + /// Creates a new `BufWriter` with the specified buffer capacity. + pub fn with_capacity(cap: usize) -> Self { + Self { + buf: vec![0; cap].into(), + written: 0, + buffered: 0, + } + } + + /// Remove the already written data + fn reshuffle_and_remove_written(&mut self) { + self.buf.copy_within(self.written..self.buffered, 0); + self.buffered -= self.written; + self.written = 0; + } + + fn do_flush( + &mut self, + poll_write: &mut dyn FnMut(&[u8]) -> Poll>, + ) -> Poll> { + while self.written < self.buffered { + let bytes_written = ready!(poll_write(&self.buf[self.written..self.buffered]))?; + if bytes_written == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffered data", + ))); + } + + self.written += bytes_written; + } + + Poll::Ready(Ok(())) + } + + fn partial_flush_buf( + &mut self, + poll_write: &mut dyn FnMut(&[u8]) -> Poll>, + ) -> Poll> { + let ret = if let Poll::Ready(res) = self.do_flush(poll_write) { + res + } else { + Ok(()) + }; + + if self.written > 0 { + self.reshuffle_and_remove_written(); + + Poll::Ready(ret) + } else if self.buffered == 0 { + Poll::Ready(ret) + } else { + ret?; + Poll::Pending + } + } + + pub fn flush_buf( + &mut self, + poll_write: &mut dyn FnMut(&[u8]) -> Poll>, + ) -> Poll> { + let ret = ready!(self.do_flush(poll_write)); + self.reshuffle_and_remove_written(); + Poll::Ready(ret) + } + + pub fn poll_write( + &mut self, + buf: &[u8], + poll_write: &mut dyn FnMut(&[u8]) -> Poll>, + ) -> Poll> { + if self.buffered + buf.len() > self.buf.len() { + ready!(self.partial_flush_buf(poll_write))?; + } + + if buf.len() >= self.buf.len() { + if self.buffered == 0 { + poll_write(buf) + } else { + // The only way that `partial_flush_buf` would have returned with + // `this.buffered != 0` is if it were Pending, so our waker was already queued + Poll::Pending + } + } else { + let len = buf.len().min(self.buf.len() - self.buffered); + self.buf[self.buffered..self.buffered + len].copy_from_slice(&buf[..len]); + self.buffered += len; + Poll::Ready(Ok(len)) + } + } + + pub fn poll_partial_flush_buf( + &mut self, + poll_write: &mut dyn FnMut(&[u8]) -> Poll>, + ) -> Poll> { + ready!(self.partial_flush_buf(poll_write))?; + Poll::Ready(Ok(&mut self.buf[self.buffered..])) + } + + pub fn produce(&mut self, amt: usize) { + debug_assert!( + self.buffered + amt <= self.buf.len(), + "produce called with amt exceeding buffer capacity" + ); + self.buffered += amt; + } +} + +macro_rules! impl_buf_writer { + ($poll_close: tt) => { + use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter}; + use futures_core::ready; + use pin_project_lite::pin_project; + + pin_project! { + #[derive(Debug)] + pub struct BufWriter { + #[pin] + writer: W, + inner: GenericBufWriter, + } + } + + impl BufWriter { + /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, + /// but may change in the future. + pub fn new(writer: W) -> Self { + Self { + writer, + inner: GenericBufWriter::new(), + } + } + + /// Creates a new `BufWriter` with the specified buffer capacity. + pub fn with_capacity(cap: usize, writer: W) -> Self { + Self { + writer, + inner: GenericBufWriter::with_capacity(cap), + } + } + + /// Gets a reference to the underlying writer. + pub fn get_ref(&self) -> &W { + &self.writer + } + + /// Gets a mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_mut(&mut self) -> &mut W { + &mut self.writer + } + + /// Gets a pinned mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().writer + } + + /// Consumes this `BufWriter`, returning the underlying writer. + /// + /// Note that any leftover data in the internal buffer is lost. + pub fn into_inner(self) -> W { + self.writer + } + } + + fn get_poll_write<'a, 'b, W: AsyncWrite>( + mut writer: Pin<&'a mut W>, + cx: &'a mut Context<'b>, + ) -> impl for<'buf> FnMut(&'buf [u8]) -> Poll> + use<'a, 'b, W> { + move |buf| writer.as_mut().poll_write(cx, buf) + } + + impl BufWriter { + fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.inner.flush_buf(&mut get_poll_write(this.writer, cx)) + } + } + + impl AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + this.inner + .poll_write(buf, &mut get_poll_write(this.writer, cx)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().flush_buf(cx))?; + self.project().writer.poll_flush(cx) + } + + fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().flush_buf(cx))?; + self.project().writer.$poll_close(cx) + } + } + + impl AsyncBufWrite for BufWriter { + fn poll_partial_flush_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + this.inner + .poll_partial_flush_buf(&mut get_poll_write(this.writer, cx)) + } + + fn produce(self: Pin<&mut Self>, amt: usize) { + self.project().inner.produce(amt) + } + } + }; +} +pub(crate) use impl_buf_writer; diff --git a/crates/async-compression/src/generic/write/decoder.rs b/crates/async-compression/src/generic/write/decoder.rs new file mode 100644 index 00000000..757ad475 --- /dev/null +++ b/crates/async-compression/src/generic/write/decoder.rs @@ -0,0 +1,245 @@ +use crate::{codecs::Decode, core::util::PartialBuffer, generic::write::AsyncBufWrite}; +use futures_core::ready; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +#[derive(Debug)] +enum State { + Decoding, + Finishing, + Done, +} + +#[derive(Debug)] +pub struct Decoder { + state: State, +} + +impl Default for Decoder { + fn default() -> Self { + Self { + state: State::Decoding, + } + } +} + +impl Decoder { + pub fn do_poll_write( + &mut self, + cx: &mut Context<'_>, + input: &mut PartialBuffer<&[u8]>, + mut writer: Pin<&mut dyn AsyncBufWrite>, + decoder: &mut impl Decode, + ) -> Poll> { + loop { + let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let mut output = PartialBuffer::new(output); + + self.state = match self.state { + State::Decoding => { + if decoder.decode(input, &mut output)? { + State::Finishing + } else { + State::Decoding + } + } + + State::Finishing => { + if decoder.finish(&mut output)? { + State::Done + } else { + State::Finishing + } + } + + State::Done => { + return Poll::Ready(Err(io::Error::other("Write after end of stream"))) + } + }; + + let produced = output.written().len(); + writer.as_mut().produce(produced); + + if let State::Done = self.state { + return Poll::Ready(Ok(())); + } + + if input.unwritten().is_empty() { + return Poll::Ready(Ok(())); + } + } + } + + pub fn do_poll_flush( + &mut self, + cx: &mut Context<'_>, + mut writer: Pin<&mut dyn AsyncBufWrite>, + decoder: &mut impl Decode, + ) -> Poll> { + loop { + let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let mut output = PartialBuffer::new(output); + + let (state, done) = match self.state { + State::Decoding => { + let done = decoder.flush(&mut output)?; + (State::Decoding, done) + } + + State::Finishing => { + if decoder.finish(&mut output)? { + (State::Done, false) + } else { + (State::Finishing, false) + } + } + + State::Done => (State::Done, true), + }; + + self.state = state; + + let produced = output.written().len(); + writer.as_mut().produce(produced); + + if done { + break Poll::Ready(Ok(())); + } + } + } + + pub fn do_close(&mut self) { + if let State::Decoding = self.state { + self.state = State::Finishing; + } + } + + pub fn is_done(&self) -> bool { + matches!(self.state, State::Done) + } +} + +macro_rules! impl_decoder { + ($poll_close: tt) => { + use crate::{ + codecs::Decode, core::util::PartialBuffer, generic::write::Decoder as GenericDecoder, + }; + use futures_core::ready; + use pin_project_lite::pin_project; + + pin_project! { + #[derive(Debug)] + pub struct Decoder { + #[pin] + writer: BufWriter, + decoder: D, + inner: GenericDecoder, + } + } + + impl Decoder { + pub fn new(writer: W, decoder: D) -> Self { + Self { + writer: BufWriter::new(writer), + decoder, + inner: Default::default(), + } + } + } + + impl Decoder { + pub fn get_ref(&self) -> &W { + self.writer.get_ref() + } + + pub fn get_mut(&mut self) -> &mut W { + self.writer.get_mut() + } + + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().writer.get_pin_mut() + } + + pub fn into_inner(self) -> W { + self.writer.into_inner() + } + } + + impl Decoder { + fn do_poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &mut PartialBuffer<&[u8]>, + ) -> Poll> { + let mut this = self.project(); + + this.inner + .do_poll_write(cx, input, this.writer, this.decoder) + } + + fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + this.inner.do_poll_flush(cx, this.writer, this.decoder) + } + } + + impl AsyncWrite for Decoder { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + let mut input = PartialBuffer::new(buf); + + match self.do_poll_write(cx, &mut input)? { + Poll::Pending if input.written().is_empty() => Poll::Pending, + _ => Poll::Ready(Ok(input.written().len())), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().do_poll_flush(cx))?; + ready!(self.project().writer.as_mut().poll_flush(cx))?; + Poll::Ready(Ok(())) + } + + fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().inner.do_close(); + + ready!(self.as_mut().do_poll_flush(cx))?; + + let this = self.project(); + if this.inner.is_done() { + ready!(this.writer.$poll_close(cx))?; + Poll::Ready(Ok(())) + } else { + Poll::Ready(Err(io::Error::other( + "Attempt to close before finishing input", + ))) + } + } + } + + impl AsyncBufRead for Decoder { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.get_pin_mut().poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_pin_mut().consume(amt) + } + } + }; +} +pub(crate) use impl_decoder; diff --git a/crates/async-compression/src/generic/write/mod.rs b/crates/async-compression/src/generic/write/mod.rs new file mode 100644 index 00000000..cf72319e --- /dev/null +++ b/crates/async-compression/src/generic/write/mod.rs @@ -0,0 +1,7 @@ +mod buf_write; +mod buf_writer; +mod decoder; + +pub(crate) use buf_write::*; +pub(crate) use buf_writer::*; +pub(crate) use decoder::*; diff --git a/crates/async-compression/src/tokio/write/buf_write.rs b/crates/async-compression/src/tokio/write/buf_write.rs deleted file mode 100644 index 5ca99731..00000000 --- a/crates/async-compression/src/tokio/write/buf_write.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::{ - io, - pin::Pin, - task::{Context, Poll}, -}; - -pub(crate) trait AsyncBufWrite { - /// Attempt to return an internal buffer to write to, flushing data out to the inner reader if - /// it is full. - /// - /// On success, returns `Poll::Ready(Ok(buf))`. - /// - /// If the buffer is full and cannot be flushed, the method returns `Poll::Pending` and - /// arranges for the current task context (`cx`) to receive a notification when the object - /// becomes readable or is closed. - fn poll_partial_flush_buf( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>; - - /// Tells this buffer that `amt` bytes have been written to its buffer, so they should be - /// written out to the underlying IO when possible. - /// - /// This function is a lower-level call. It needs to be paired with the `poll_flush_buf` method to - /// function properly. This function does not perform any I/O, it simply informs this object - /// that some amount of its buffer, returned from `poll_flush_buf`, has been written to and should - /// be sent. As such, this function may do odd things if `poll_flush_buf` isn't - /// called before calling it. - /// - /// The `amt` must be `<=` the number of bytes in the buffer returned by `poll_flush_buf`. - fn produce(self: Pin<&mut Self>, amt: usize); -} diff --git a/crates/async-compression/src/tokio/write/buf_writer.rs b/crates/async-compression/src/tokio/write/buf_writer.rs index c56c7e64..c7725ba7 100644 --- a/crates/async-compression/src/tokio/write/buf_writer.rs +++ b/crates/async-compression/src/tokio/write/buf_writer.rs @@ -2,215 +2,12 @@ // the `AsyncBufWrite` impl can access its internals, and changed a bit to make it more efficient // with those methods. -use super::AsyncBufWrite; -use futures_core::ready; -use pin_project_lite::pin_project; +use crate::generic::write::impl_buf_writer; use std::{ - cmp::min, - fmt, io, + io, pin::Pin, task::{Context, Poll}, }; use tokio::io::AsyncWrite; -const DEFAULT_BUF_SIZE: usize = 8192; - -pin_project! { - pub struct BufWriter { - #[pin] - inner: W, - buf: Box<[u8]>, - written: usize, - buffered: usize, - } -} - -impl BufWriter { - /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, - /// but may change in the future. - pub fn new(inner: W) -> Self { - Self::with_capacity(DEFAULT_BUF_SIZE, inner) - } - - /// Creates a new `BufWriter` with the specified buffer capacity. - pub fn with_capacity(cap: usize, inner: W) -> Self { - Self { - inner, - buf: vec![0; cap].into(), - written: 0, - buffered: 0, - } - } - - fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - let mut ret = Ok(()); - while *this.written < *this.buffered { - match this - .inner - .as_mut() - .poll_write(cx, &this.buf[*this.written..*this.buffered]) - { - Poll::Pending => { - break; - } - Poll::Ready(Ok(0)) => { - ret = Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write the buffered data", - )); - break; - } - Poll::Ready(Ok(n)) => *this.written += n, - Poll::Ready(Err(e)) => { - ret = Err(e); - break; - } - } - } - - if *this.written > 0 { - this.buf.copy_within(*this.written..*this.buffered, 0); - *this.buffered -= *this.written; - *this.written = 0; - - Poll::Ready(ret) - } else if *this.buffered == 0 { - Poll::Ready(ret) - } else { - ret?; - Poll::Pending - } - } - - fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - let mut ret = Ok(()); - while *this.written < *this.buffered { - match ready!(this - .inner - .as_mut() - .poll_write(cx, &this.buf[*this.written..*this.buffered])) - { - Ok(0) => { - ret = Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write the buffered data", - )); - break; - } - Ok(n) => *this.written += n, - Err(e) => { - ret = Err(e); - break; - } - } - } - this.buf.copy_within(*this.written..*this.buffered, 0); - *this.buffered -= *this.written; - *this.written = 0; - Poll::Ready(ret) - } -} - -impl BufWriter { - /// Gets a reference to the underlying writer. - pub fn get_ref(&self) -> &W { - &self.inner - } - - /// Gets a mutable reference to the underlying writer. - /// - /// It is inadvisable to directly write to the underlying writer. - pub fn get_mut(&mut self) -> &mut W { - &mut self.inner - } - - /// Gets a pinned mutable reference to the underlying writer. - /// - /// It is inadvisable to directly write to the underlying writer. - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { - self.project().inner - } - - /// Consumes this `BufWriter`, returning the underlying writer. - /// - /// Note that any leftover data in the internal buffer is lost. - pub fn into_inner(self) -> W { - self.inner - } -} - -impl AsyncWrite for BufWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.as_mut().project(); - if *this.buffered + buf.len() > this.buf.len() { - ready!(self.as_mut().partial_flush_buf(cx))?; - } - - let this = self.as_mut().project(); - if buf.len() >= this.buf.len() { - if *this.buffered == 0 { - this.inner.poll_write(cx, buf) - } else { - // The only way that `partial_flush_buf` would have returned with - // `this.buffered != 0` is if it were Pending, so our waker was already queued - Poll::Pending - } - } else { - let len = min(this.buf.len() - *this.buffered, buf.len()); - this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]); - *this.buffered += len; - Poll::Ready(Ok(len)) - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().flush_buf(cx))?; - self.project().inner.poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().flush_buf(cx))?; - self.project().inner.poll_shutdown(cx) - } -} - -impl AsyncBufWrite for BufWriter { - fn poll_partial_flush_buf( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - ready!(self.as_mut().partial_flush_buf(cx))?; - let this = self.project(); - Poll::Ready(Ok(&mut this.buf[*this.buffered..])) - } - - fn produce(self: Pin<&mut Self>, amt: usize) { - let this = self.project(); - debug_assert!( - *this.buffered + amt <= this.buf.len(), - "produce called with amt exceeding buffer capacity" - ); - *this.buffered += amt; - } -} - -impl fmt::Debug for BufWriter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BufWriter") - .field("writer", &self.inner) - .field( - "buffer", - &format_args!("{}/{}", self.buffered, self.buf.len()), - ) - .field("written", &self.written) - .finish() - } -} +impl_buf_writer!(poll_shutdown); diff --git a/crates/async-compression/src/tokio/write/generic/decoder.rs b/crates/async-compression/src/tokio/write/generic/decoder.rs index 56f96656..07468435 100644 --- a/crates/async-compression/src/tokio/write/generic/decoder.rs +++ b/crates/async-compression/src/tokio/write/generic/decoder.rs @@ -1,8 +1,4 @@ -use crate::codecs::Decode; -use crate::core::util::PartialBuffer; -use crate::tokio::write::{AsyncBufWrite, BufWriter}; -use futures_core::ready; -use pin_project_lite::pin_project; +use crate::{generic::write::impl_decoder, tokio::write::BufWriter}; use std::{ io, pin::Pin, @@ -10,175 +6,7 @@ use std::{ }; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -#[derive(Debug)] -enum State { - Decoding, - Finishing, - Done, -} - -pin_project! { - #[derive(Debug)] - pub struct Decoder { - #[pin] - writer: BufWriter, - decoder: D, - state: State, - } -} - -impl Decoder { - pub fn new(writer: W, decoder: D) -> Self { - Self { - writer: BufWriter::new(writer), - decoder, - state: State::Decoding, - } - } -} - -impl Decoder { - pub fn get_ref(&self) -> &W { - self.writer.get_ref() - } - - pub fn get_mut(&mut self) -> &mut W { - self.writer.get_mut() - } - - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { - self.project().writer.get_pin_mut() - } - - pub fn into_inner(self) -> W { - self.writer.into_inner() - } -} - -impl Decoder { - fn do_poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - input: &mut PartialBuffer<&[u8]>, - ) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Decoding => { - if this.decoder.decode(input, &mut output)? { - State::Finishing - } else { - State::Decoding - } - } - - State::Finishing => { - if this.decoder.finish(&mut output)? { - State::Done - } else { - State::Finishing - } - } - - State::Done => { - return Poll::Ready(Err(io::Error::other("Write after end of stream"))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if let State::Done = this.state { - return Poll::Ready(Ok(())); - } - - if input.unwritten().is_empty() { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - let (state, done) = match this.state { - State::Decoding => { - let done = this.decoder.flush(&mut output)?; - (State::Decoding, done) - } - - State::Finishing => { - if this.decoder.finish(&mut output)? { - (State::Done, false) - } else { - (State::Finishing, false) - } - } - - State::Done => (State::Done, true), - }; - - *this.state = state; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if done { - return Poll::Ready(Ok(())); - } - } - } -} - -impl AsyncWrite for Decoder { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - let mut input = PartialBuffer::new(buf); - - match self.do_poll_write(cx, &mut input)? { - Poll::Pending if input.written().is_empty() => Poll::Pending, - _ => Poll::Ready(Ok(input.written().len())), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_flush(cx))?; - ready!(self.project().writer.as_mut().poll_flush(cx))?; - Poll::Ready(Ok(())) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let State::Decoding = self.as_mut().project().state { - *self.as_mut().project().state = State::Finishing; - } - - ready!(self.as_mut().do_poll_flush(cx))?; - - if let State::Done = self.as_mut().project().state { - ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?; - Poll::Ready(Ok(())) - } else { - Poll::Ready(Err(io::Error::other( - "Attempt to shutdown before finishing input", - ))) - } - } -} +impl_decoder!(poll_shutdown); impl AsyncRead for Decoder { fn poll_read( @@ -189,13 +17,3 @@ impl AsyncRead for Decoder { self.get_pin_mut().poll_read(cx, buf) } } - -impl AsyncBufRead for Decoder { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_pin_mut().poll_fill_buf(cx) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - self.get_pin_mut().consume(amt) - } -} diff --git a/crates/async-compression/src/tokio/write/mod.rs b/crates/async-compression/src/tokio/write/mod.rs index 409cd670..bb6b95a2 100644 --- a/crates/async-compression/src/tokio/write/mod.rs +++ b/crates/async-compression/src/tokio/write/mod.rs @@ -5,13 +5,12 @@ mod macros; mod generic; -mod buf_write; mod buf_writer; use self::{ - buf_write::AsyncBufWrite, buf_writer::BufWriter, generic::{Decoder, Encoder}, }; +use crate::generic::write::AsyncBufWrite; algos!(tokio::write);