Commit acf250f

duplex: Ensure written data is flushed (#944)
In testing Postgres dumps through Linkerd, we see some situations where data is written to rustls but does not become visible to the client. Rustls documents that written data may need to be flushed. This change updates the `HalfDuplex::copy_into` function to flush data when all available data has been read. This change also increases the copy buffer from 4KB to 64KB to reduce processing overhead.
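The heart of the fix is easy to state in isolation: when a read returns pending after data has been written, flush the destination before yielding, so that bytes a buffering wrapper like rustls is still holding actually reach the peer. The following is a minimal sketch of that pattern against tokio's `AsyncRead`/`AsyncWrite`; the `CopyHalf` type, its fields, and `poll_copy` are hypothetical names for illustration, not part of the Linkerd codebase.

```rust
use std::{
    io,
    pin::Pin,
    task::{Context, Poll},
};

use futures::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// One direction of a proxy loop: read into a buffer, write it out, and flush
/// whenever a read stalls while written data may still sit in the
/// destination's internal buffers.
struct CopyHalf {
    buf: Box<[u8]>,
    pos: usize,        // next byte of `buf` to write
    len: usize,        // number of bytes currently buffered
    needs_flush: bool, // data was written since the last completed flush
}

impl CopyHalf {
    fn poll_copy<S, D>(
        &mut self,
        mut src: Pin<&mut S>,
        mut dst: Pin<&mut D>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>>
    where
        S: AsyncRead,
        D: AsyncWrite,
    {
        loop {
            // Refill the buffer only once it has been fully drained.
            if self.pos == self.len {
                let mut rb = ReadBuf::new(&mut self.buf[..]);
                match src.as_mut().poll_read(cx, &mut rb) {
                    Poll::Pending => {
                        if self.needs_flush {
                            // The read stalled and unflushed data may be
                            // pending in `dst`: flush so the peer sees it.
                            ready!(dst.as_mut().poll_flush(cx))?;
                            self.needs_flush = false;
                        }
                        return Poll::Pending;
                    }
                    Poll::Ready(Ok(())) if rb.filled().is_empty() => {
                        // EOF: propagate shutdown to the destination.
                        return dst.as_mut().poll_shutdown(cx);
                    }
                    Poll::Ready(res) => {
                        res?;
                        self.pos = 0;
                        self.len = rb.filled().len();
                    }
                }
            }

            // Drain buffered data; a partial write stays buffered across polls.
            while self.pos < self.len {
                let n = ready!(dst.as_mut().poll_write(cx, &self.buf[self.pos..self.len]))?;
                if n == 0 {
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
                self.pos += n;
                self.needs_flush = true;
            }
        }
    }
}
```

The committed change goes further than this sketch: `drain_into` reports partial writes (`Drained::Partial`) so that a stalled destination is flushed to reclaim write capacity, and an incomplete flush sets `self.flushing` so the flush is resumed on the next poll instead of being lost.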
1 parent f481c66 commit acf250f

File tree

1 file changed: +154 -105 lines changed

linkerd/duplex/src/lib.rs

Lines changed: 154 additions & 105 deletions
@@ -2,12 +2,11 @@
 
 use bytes::{Buf, BufMut};
 use futures::ready;
-use io::{AsyncRead, AsyncWrite};
-use linkerd_io as io;
+use linkerd_io::{self as io, AsyncRead, AsyncWrite};
 use pin_project::pin_project;
 use std::task::{Context, Poll};
 use std::{future::Future, pin::Pin};
-use tracing::trace;
+use tracing::{error, trace};
 
 /// A future piping data bi-directionally to In and Out.
 #[pin_project]
@@ -24,6 +23,7 @@ struct HalfDuplex<T> {
     #[pin]
     io: T,
     direction: &'static str,
+    flushing: bool,
 }
 
 /// A buffer used to copy bytes from one IO to another.
@@ -39,6 +39,18 @@ struct CopyBuf {
     write_pos: usize,
 }
 
+enum Buffered {
+    NotEmpty,
+    Read(usize),
+    Eof,
+}
+
+enum Drained {
+    BufferEmpty,
+    Partial(usize),
+    All(usize),
+}
+
 impl<In, Out> Duplex<In, Out>
 where
     In: AsyncRead + AsyncWrite + Unpin,
@@ -57,9 +69,9 @@ where
     In: AsyncRead + AsyncWrite + Unpin,
     Out: AsyncRead + AsyncWrite + Unpin,
 {
-    type Output = Result<(), io::Error>;
+    type Output = io::Result<()>;
 
-    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> {
         let mut this = self.project();
         // This purposefully ignores the Async part, since we don't want to
         // return early if the first half isn't ready, but the other half
@@ -85,17 +97,19 @@ where
             is_shutdown: false,
             io,
             direction,
+            flushing: false,
         }
     }
 
-    fn copy_into<U>(
+    /// Reads data from `self`, buffering it, and writing it to `dst`.
+    ///
+    /// Returns ready when the stream has shut down such that no more data may be
+    /// proxied.
+    fn copy_into<U: AsyncWrite + Unpin>(
         &mut self,
         dst: &mut HalfDuplex<U>,
         cx: &mut Context<'_>,
-    ) -> Poll<Result<(), io::Error>>
-    where
-        U: AsyncWrite + Unpin,
-    {
+    ) -> io::Poll<()> {
         // Since Duplex::poll() intentionally ignores the Async part of our
         // return value, we may be polled again after returning Ready, if the
         // other half isn't ready. In that case, if the destination has
@@ -105,61 +119,156 @@ where
             trace!(direction = %self.direction, "already shutdown");
             return Poll::Ready(Ok(()));
         }
+
+        // If the last invocation returned pending while flushing, resume
+        // flushing and only proceed when the flush is complete.
+        if self.flushing {
+            ready!(self.poll_flush(dst, cx))?;
+        }
+
+        // `needs_flush` is set to true if the buffer is written so that, if a
+        // read returns pending, that data may be flushed.
+        let mut needs_flush = false;
+
         loop {
-            ready!(self.poll_read(cx))?;
-            ready!(self.poll_write_into(dst, cx))?;
-            if self.buf.is_none() {
-                trace!(direction = %self.direction, "shutting down");
-                debug_assert!(!dst.is_shutdown, "attempted to shut down destination twice");
-                ready!(Pin::new(&mut dst.io).poll_shutdown(cx))?;
-                dst.is_shutdown = true;
-
-                return Poll::Ready(Ok(()));
+            // As long as the underlying socket is alive, ensure we've read data
+            // from it into the local buffer.
+            match self.poll_buffer(cx)? {
+                Poll::Pending => {
+                    // If there's no data to be read and we've written data, try
+                    // flushing before returning pending.
+                    if needs_flush {
+                        // The poll status of the flush isn't relevant, as we
+                        // have registered interest in the read (and maybe the
+                        // write as well). If the flush did not complete,
+                        // `self.flushing` is true so that it may be resumed on
+                        // the next poll.
+                        let _ = self.poll_flush(dst, cx)?;
+                    }
+                    return Poll::Pending;
+                }
+
+                Poll::Ready(Buffered::NotEmpty) | Poll::Ready(Buffered::Read(_)) => {
+                    // Write buffered data to the destination.
+                    match self.drain_into(dst, cx)? {
+                        // All of the buffered data was written, so continue reading more.
+                        Drained::All(sz) => {
+                            debug_assert!(sz > 0);
+                            needs_flush = true;
+                        }
+                        // Only some of the buffered data could be written
+                        // before the destination became pending. Try to flush
+                        // the written data to get capacity.
+                        Drained::Partial(_) => {
+                            ready!(self.poll_flush(dst, cx))?;
+                            // If the flush completed, try writing again to
+                            // ensure that we have a notification registered. If
+                            // all of the buffered data still cannot be written,
+                            // return pending. Otherwise, continue.
+                            if let Drained::Partial(_) = self.drain_into(dst, cx)? {
+                                return Poll::Pending;
+                            }
+                            needs_flush = false;
+                        }
+                        Drained::BufferEmpty => {
+                            error!(
+                                direction = self.direction,
+                                "Invalid state: attempted to write from an empty buffer"
+                            );
+                            debug_assert!(false, "The write buffer should never be empty");
+                            return Poll::Ready(Ok(()));
+                        }
+                    }
+                }
+
+                // The socket closed, so initiate shutdown on the destination.
+                Poll::Ready(Buffered::Eof) => {
+                    trace!(direction = %self.direction, "shutting down");
+                    debug_assert!(!dst.is_shutdown, "attempted to shut down destination twice");
+                    ready!(Pin::new(&mut dst.io).poll_shutdown(cx))?;
+                    dst.is_shutdown = true;
+                    return Poll::Ready(Ok(()));
+                }
             }
         }
     }
 
-    fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
-        let mut is_eof = false;
-        if let Some(ref mut buf) = self.buf {
-            if !buf.has_remaining() {
-                buf.reset();
+    /// Attempts to read and buffer data from the underlying stream, returning
+    /// the number of bytes read. If the buffer already has data, no new data
+    /// will be read.
+    fn poll_buffer(&mut self, cx: &mut Context<'_>) -> io::Poll<Buffered> {
+        // Buffer data only if no data is buffered.
+        //
+        // TODO should we read more data as long as there's buffer capacity?
+        // To do this, we'd have to get more complex about handling EOF.
+        if let Some(buf) = self.buf.as_mut() {
+            if buf.has_remaining() {
+                // Data was already buffered, so just return immediately.
+                trace!(direction = %self.direction, remaining = buf.remaining(), "skipping read");
+                return Poll::Ready(Ok(Buffered::NotEmpty));
+            }
 
-                trace!(direction = %self.direction, "reading");
-                let n = ready!(io::poll_read_buf(Pin::new(&mut self.io), cx, buf))?;
-                trace!(direction = %self.direction, "read {}B", n);
+            buf.reset();
+            trace!(direction = %self.direction, "reading");
+            let sz = ready!(io::poll_read_buf(Pin::new(&mut self.io), cx, buf))?;
+            trace!(direction = %self.direction, "read {}B", sz);
 
-                is_eof = n == 0;
+            // If data was read, return the number of bytes read.
+            if sz > 0 {
+                return Poll::Ready(Ok(Buffered::Read(sz)));
             }
         }
-        if is_eof {
-            trace!("eof");
-            self.buf = None;
-        }
 
-        Poll::Ready(Ok(()))
+        // No more data can be read.
+        trace!("eof");
+        self.buf = None;
+        Poll::Ready(Ok(Buffered::Eof))
+    }
+
+    /// Attempts to flush the destination. `self.flushing` is set to true iff the
+    /// flush operation did not complete.
+    fn poll_flush<U: AsyncWrite + Unpin>(
+        &mut self,
+        dst: &mut HalfDuplex<U>,
+        cx: &mut Context<'_>,
+    ) -> io::Poll<()> {
+        trace!(direction = %self.direction, "flushing");
+        let poll = Pin::new(&mut dst.io).poll_flush(cx);
+        self.flushing = poll.is_pending();
+        if poll.is_ready() {
+            trace!(direction = %self.direction, "flushed");
+        }
+        poll
     }
 
-    fn poll_write_into<U>(
+    /// Writes as much buffered data as possible, returning the number of bytes written.
+    fn drain_into<U: AsyncWrite + Unpin>(
         &mut self,
         dst: &mut HalfDuplex<U>,
         cx: &mut Context<'_>,
-    ) -> Poll<Result<(), io::Error>>
-    where
-        U: AsyncWrite + Unpin,
-    {
-        if let Some(ref mut buf) = self.buf {
+    ) -> io::Result<Drained> {
+        let mut sz = 0;
+
+        if let Some(buf) = self.buf.as_mut() {
             while buf.has_remaining() {
                 trace!(direction = %self.direction, "writing {}B", buf.remaining());
-                let n = ready!(io::poll_write_buf(Pin::new(&mut dst.io), cx, buf))?;
+                let n = match io::poll_write_buf(Pin::new(&mut dst.io), cx, buf)? {
+                    Poll::Pending => return Ok(Drained::Partial(sz)),
+                    Poll::Ready(n) => n,
+                };
                 trace!(direction = %self.direction, "wrote {}B", n);
                 if n == 0 {
-                    return Poll::Ready(Err(write_zero()));
+                    return Err(write_zero());
                 }
+                sz += n;
             }
         }
 
-        Poll::Ready(Ok(()))
+        if sz == 0 {
+            Ok(Drained::BufferEmpty)
+        } else {
+            Ok(Drained::All(sz))
+        }
     }
 
     fn is_done(&self) -> bool {
@@ -174,7 +283,7 @@ fn write_zero() -> io::Error {
 impl CopyBuf {
     fn new() -> Self {
         CopyBuf {
-            buf: Box::new([0; 4096]),
+            buf: Box::new([0; 64 * 1024]),
            read_pos: 0,
            write_pos: 0,
        }
@@ -208,11 +317,9 @@ unsafe impl BufMut for CopyBuf {
     }
 
     fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
+        // Safety: The memory is initialized. This is the only way to turn a
+        // `&[T]` into a `&[MaybeUninit<T>]` without ptr casting.
         unsafe {
-            // this is, in fact, _totally fine and safe_: all the memory is
-            // initialized.
-            // there's just no way to turn a `&[T]` into a `&[MaybeUninit<T>]`
-            // without ptr casting.
             bytes::buf::UninitSlice::from_raw_parts_mut(
                 &mut self.buf[self.write_pos] as *mut _,
                 self.buf.len() - self.write_pos,
@@ -225,61 +332,3 @@ unsafe impl BufMut for CopyBuf {
         self.write_pos += cnt;
     }
 }
-
-// #[cfg(test)]
-// mod tests {
-//     use std::io::{Error, Read, Result, Write};
-//     use std::sync::atomic::{AtomicBool, Ordering};
-
-//     use super::*;
-//     use tokio::io::{AsyncRead, AsyncWrite};
-
-//     #[derive(Debug)]
-//     struct DoneIo(AtomicBool);
-
-//     impl<'a> Read for &'a DoneIo {
-//         fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
-//             if self.0.swap(false, Ordering::Relaxed) {
-//                 Ok(buf.len())
-//             } else {
-//                 Ok(0)
-//             }
-//         }
-//     }
-
-//     impl<'a> AsyncRead for &'a DoneIo {
-//         unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [u8]) -> bool {
-//             true
-//         }
-//     }
-
-//     impl<'a> Write for &'a DoneIo {
-//         fn write(&mut self, buf: &[u8]) -> Result<usize> {
-//             Ok(buf.len())
-//         }
-//         fn flush(&mut self) -> Result<()> {
-//             Ok(())
-//         }
-//     }
-//     impl<'a> AsyncWrite for &'a DoneIo {
-//         fn shutdown(&mut self) -> Poll<(), Error> {
-//             if self.0.swap(false, Ordering::Relaxed) {
-//                 Ok(Async::NotReady)
-//             } else {
-//                 Ok(Async::Ready(()))
-//             }
-//         }
-//     }
-
-//     #[test]
-//     fn duplex_doesnt_hang_when_one_half_finishes() {
-//         // Test reproducing an infinite loop in Duplex that caused issue #519,
-//         // where a Duplex would enter an infinite loop when one half finishes.
-//         let io_1 = DoneIo(AtomicBool::new(true));
-//         let io_2 = DoneIo(AtomicBool::new(true));
-//         let mut duplex = Duplex::new(&io_1, &io_2);
-
-//         assert_eq!(duplex.poll().unwrap(), Async::NotReady);
-//         assert_eq!(duplex.poll().unwrap(), Async::Ready(()));
-//     }
-// }
