diff --git a/crates/async-compression/src/generic/write/buf_write.rs b/crates/async-compression/src/generic/write/buf_write.rs index 5ca99731..5a12b3ab 100644 --- a/crates/async-compression/src/generic/write/buf_write.rs +++ b/crates/async-compression/src/generic/write/buf_write.rs @@ -1,3 +1,4 @@ +use super::Buffer; use std::{ io, pin::Pin, @@ -16,17 +17,5 @@ pub(crate) trait AsyncBufWrite { fn poll_partial_flush_buf( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>; - - /// Tells this buffer that `amt` bytes have been written to its buffer, so they should be - /// written out to the underlying IO when possible. - /// - /// This function is a lower-level call. It needs to be paired with the `poll_flush_buf` method to - /// function properly. This function does not perform any I/O, it simply informs this object - /// that some amount of its buffer, returned from `poll_flush_buf`, has been written to and should - /// be sent. As such, this function may do odd things if `poll_flush_buf` isn't - /// called before calling it. - /// - /// The `amt` must be `<=` the number of bytes in the buffer returned by `poll_flush_buf`. - fn produce(self: Pin<&mut Self>, amt: usize); + ) -> Poll>>; } diff --git a/crates/async-compression/src/generic/write/buf_writer.rs b/crates/async-compression/src/generic/write/buf_writer.rs index 8db1797b..ddb4a526 100644 --- a/crates/async-compression/src/generic/write/buf_writer.rs +++ b/crates/async-compression/src/generic/write/buf_writer.rs @@ -3,6 +3,7 @@ // with those methods. use super::AsyncBufWrite; +use compression_core::util::WriteBuffer; use futures_core::ready; use std::{ fmt, io, @@ -133,7 +134,7 @@ impl BufWriter { pub fn poll_partial_flush_buf( &mut self, poll_write: &mut dyn FnMut(&[u8]) -> Poll>, - ) -> Poll> { + ) -> Poll>> { ready!(self.partial_flush_buf(poll_write))?; // when the flushed data is larger than or equal to half of yet-to-be-flushed data, @@ -146,21 +147,27 @@ impl BufWriter { self.remove_written(); } - Poll::Ready(Ok(&mut self.buf[self.buffered..])) + Poll::Ready(Ok(Buffer { + write_buffer: WriteBuffer::new_initialized(&mut self.buf[self.buffered..]), + buffered: &mut self.buffered, + })) } +} + +pub struct Buffer<'a> { + buffered: &'a mut usize, + pub write_buffer: WriteBuffer<'a>, +} - pub fn produce(&mut self, amt: usize) { - debug_assert!( - self.buffered + amt <= self.buf.len(), - "produce called with amt exceeding buffer capacity" - ); - self.buffered += amt; +impl Drop for Buffer<'_> { + fn drop(&mut self) { + *self.buffered += self.write_buffer.written_len(); } } macro_rules! impl_buf_writer { ($poll_close: tt) => { - use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter}; + use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter, Buffer}; use futures_core::ready; use pin_project_lite::pin_project; @@ -258,15 +265,11 @@ macro_rules! impl_buf_writer { fn poll_partial_flush_buf( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll>> { let this = self.project(); this.inner .poll_partial_flush_buf(&mut get_poll_write(this.writer, cx)) } - - fn produce(self: Pin<&mut Self>, amt: usize) { - self.project().inner.produce(amt) - } } }; } diff --git a/crates/async-compression/src/generic/write/decoder.rs b/crates/async-compression/src/generic/write/decoder.rs index 0ab58bab..c9108278 100644 --- a/crates/async-compression/src/generic/write/decoder.rs +++ b/crates/async-compression/src/generic/write/decoder.rs @@ -39,12 +39,12 @@ impl Decoder { decoder: &mut dyn DecodeV2, ) -> Poll> { loop { - let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = WriteBuffer::new_initialized(output); + let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let output = &mut output.write_buffer; self.state = match self.state { State::Decoding => { - if decoder.decode(input, &mut output)? { + if decoder.decode(input, output)? { State::Finishing } else { State::Decoding @@ -52,7 +52,7 @@ impl Decoder { } State::Finishing => { - if decoder.finish(&mut output)? { + if decoder.finish(output)? { State::Done } else { State::Finishing @@ -64,9 +64,6 @@ impl Decoder { } }; - let produced = output.written_len(); - writer.as_mut().produce(produced); - if let State::Done = self.state { return Poll::Ready(Ok(())); } @@ -103,17 +100,17 @@ impl Decoder { decoder: &mut dyn DecodeV2, ) -> Poll> { loop { - let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = WriteBuffer::new_initialized(output); + let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let output = &mut output.write_buffer; let (state, done) = match self.state { State::Decoding => { - let done = decoder.flush(&mut output)?; + let done = decoder.flush(output)?; (State::Decoding, done) } State::Finishing => { - if decoder.finish(&mut output)? { + if decoder.finish(output)? { (State::Done, false) } else { (State::Finishing, false) @@ -125,9 +122,6 @@ impl Decoder { self.state = state; - let produced = output.written_len(); - writer.as_mut().produce(produced); - if done { break Poll::Ready(Ok(())); } diff --git a/crates/async-compression/src/generic/write/encoder.rs b/crates/async-compression/src/generic/write/encoder.rs index 8c030780..8537cc64 100644 --- a/crates/async-compression/src/generic/write/encoder.rs +++ b/crates/async-compression/src/generic/write/encoder.rs @@ -39,12 +39,12 @@ impl Encoder { encoder: &mut dyn EncodeV2, ) -> Poll> { loop { - let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = WriteBuffer::new_initialized(output); + let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let output = &mut output.write_buffer; self.state = match self.state { State::Encoding => { - encoder.encode(input, &mut output)?; + encoder.encode(input, output)?; State::Encoding } @@ -53,9 +53,6 @@ impl Encoder { } }; - let produced = output.written_len(); - writer.as_mut().produce(produced); - if input.unwritten().is_empty() { break Poll::Ready(Ok(())); } @@ -88,20 +85,17 @@ impl Encoder { encoder: &mut dyn EncodeV2, ) -> Poll> { loop { - let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = WriteBuffer::new_initialized(output); + let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let output = &mut output.write_buffer; let done = match self.state { - State::Encoding => encoder.flush(&mut output)?, + State::Encoding => encoder.flush(output)?, State::Finishing | State::Done => { break Poll::Ready(Err(io::Error::other("Flush after close"))) } }; - let produced = output.written_len(); - writer.as_mut().produce(produced); - if done { break Poll::Ready(Ok(())); } @@ -115,12 +109,12 @@ impl Encoder { encoder: &mut dyn EncodeV2, ) -> Poll> { loop { - let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = WriteBuffer::new_initialized(output); + let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?; + let output = &mut output.write_buffer; self.state = match self.state { State::Encoding | State::Finishing => { - if encoder.finish(&mut output)? { + if encoder.finish(output)? { State::Done } else { State::Finishing @@ -130,9 +124,6 @@ impl Encoder { State::Done => State::Done, }; - let produced = output.written_len(); - writer.as_mut().produce(produced); - if let State::Done = self.state { break Poll::Ready(Ok(())); }