Skip to content

Commit 70c2a79

Browse files
authored
Fix #487 (#501)
1 parent 33e4ee6 commit 70c2a79

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

pgdog/src/net/error.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ use tokio_rustls::rustls;
77

88
#[derive(Debug, Error)]
99
pub enum Error {
10-
#[error("{0}")]
10+
#[error("io: {0}")]
1111
Io(#[from] std::io::Error),
1212

13+
#[error("connection closed by peer")]
14+
UnexpectedEof,
15+
1316
#[error("unsupported startup request: {0}")]
1417
UnsupportedStartup(i32),
1518

@@ -58,9 +61,6 @@ pub enum Error {
5861
#[error("unknown transaction state identifier: {0}")]
5962
UnknownTransactionStateIdentifier(char),
6063

61-
#[error("eof")]
62-
Eof,
63-
6464
#[error("not text encoding")]
6565
NotTextEncoding,
6666

pgdog/src/net/stream.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, R
66
use tokio::net::TcpStream;
77
use tracing::{debug, enabled, trace, Level};
88

9-
use std::io::Error;
9+
use std::io::{Error, ErrorKind};
1010
use std::net::SocketAddr;
1111
use std::ops::Deref;
1212
use std::pin::Pin;
@@ -108,8 +108,8 @@ impl Stream {
108108
pub async fn check(&mut self) -> Result<(), crate::net::Error> {
109109
let mut buf = [0u8; 1];
110110
match self {
111-
Self::Plain(plain) => plain.get_mut().peek(&mut buf).await?,
112-
Self::Tls(tls) => tls.get_mut().get_mut().0.peek(&mut buf).await?,
111+
Self::Plain(plain) => eof(plain.get_mut().peek(&mut buf).await)?,
112+
Self::Tls(tls) => eof(tls.get_mut().get_mut().0.peek(&mut buf).await)?,
113113
Self::DevNull => 0,
114114
};
115115

@@ -126,8 +126,8 @@ impl Stream {
126126
let bytes = message.to_bytes()?;
127127

128128
match self {
129-
Stream::Plain(ref mut stream) => stream.write_all(&bytes).await?,
130-
Stream::Tls(ref mut stream) => stream.write_all(&bytes).await?,
129+
Stream::Plain(ref mut stream) => eof(stream.write_all(&bytes).await)?,
130+
Stream::Tls(ref mut stream) => eof(stream.write_all(&bytes).await)?,
131131
Self::DevNull => (),
132132
}
133133

@@ -165,7 +165,7 @@ impl Stream {
165165
message: &impl Protocol,
166166
) -> Result<usize, crate::net::Error> {
167167
let sent = self.send(message).await?;
168-
self.flush().await?;
168+
eof(self.flush().await)?;
169169
trace!("😳");
170170

171171
Ok(sent)
@@ -180,7 +180,7 @@ impl Stream {
180180
for message in messages {
181181
sent += self.send(message).await?;
182182
}
183-
self.flush().await?;
183+
eof(self.flush().await)?;
184184
trace!("😳");
185185
Ok(sent)
186186
}
@@ -199,15 +199,15 @@ impl Stream {
199199

200200
/// Read data into a buffer, avoiding unnecessary allocations.
201201
pub async fn read_buf(&mut self, bytes: &mut BytesMut) -> Result<Message, crate::net::Error> {
202-
let code = self.read_u8().await?;
203-
let len = self.read_i32().await?;
202+
let code = eof(self.read_u8().await)?;
203+
let len = eof(self.read_i32().await)?;
204204

205205
bytes.put_u8(code);
206206
bytes.put_i32(len);
207207

208208
// Length must be at least 4 bytes.
209209
if len < 4 {
210-
return Err(crate::net::Error::Eof);
210+
return Err(crate::net::Error::UnexpectedEof);
211211
}
212212

213213
let capacity = len as usize + 1;
@@ -218,7 +218,7 @@ impl Stream {
218218
bytes.set_len(capacity);
219219
}
220220

221-
self.read_exact(&mut bytes[5..capacity]).await?;
221+
eof(self.read_exact(&mut bytes[5..capacity]).await)?;
222222

223223
let message = Message::new(bytes.split().freeze());
224224

@@ -261,6 +261,19 @@ impl Stream {
261261
}
262262
}
263263

264+
fn eof<T>(result: std::io::Result<T>) -> Result<T, crate::net::Error> {
265+
match result {
266+
Ok(val) => Ok(val),
267+
Err(err) => {
268+
if err.kind() == ErrorKind::UnexpectedEof {
269+
Err(crate::net::Error::UnexpectedEof)
270+
} else {
271+
Err(crate::net::Error::Io(err))
272+
}
273+
}
274+
}
275+
}
276+
264277
/// Wrapper around SocketAddr
265278
/// to make it easier to debug.
266279
pub struct PeerAddr {

0 commit comments

Comments
 (0)