Skip to content

Commit a2df295

Browse files
authored
feat: log the number of read/written bytes on IMAP stream read error (#6924)
1 parent 6df1d16 commit a2df295

File tree

4 files changed

+237
-16
lines changed

4 files changed

+237
-16
lines changed

src/imap/client.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@ use tokio::io::BufWriter;
88

99
use super::capabilities::Capabilities;
1010
use crate::context::Context;
11-
use crate::log::{info, warn};
11+
use crate::log::{LoggingStream, info, warn};
1212
use crate::login_param::{ConnectionCandidate, ConnectionSecurity};
1313
use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp};
1414
use crate::net::proxy::ProxyConfig;
1515
use crate::net::session::SessionStream;
1616
use crate::net::tls::wrap_tls;
17-
use crate::net::{
18-
connect_tcp_inner, connect_tls_inner, run_connection_attempts, update_connection_history,
19-
};
17+
use crate::net::{connect_tcp_inner, run_connection_attempts, update_connection_history};
2018
use crate::tools::time;
2119

2220
#[derive(Debug)]
@@ -126,12 +124,12 @@ impl Client {
126124
);
127125
let res = match security {
128126
ConnectionSecurity::Tls => {
129-
Client::connect_secure(resolved_addr, host, strict_tls).await
127+
Client::connect_secure(context, resolved_addr, host, strict_tls).await
130128
}
131129
ConnectionSecurity::Starttls => {
132-
Client::connect_starttls(resolved_addr, host, strict_tls).await
130+
Client::connect_starttls(context, resolved_addr, host, strict_tls).await
133131
}
134-
ConnectionSecurity::Plain => Client::connect_insecure(resolved_addr).await,
132+
ConnectionSecurity::Plain => Client::connect_insecure(context, resolved_addr).await,
135133
};
136134
match res {
137135
Ok(client) => {
@@ -202,8 +200,17 @@ impl Client {
202200
}
203201
}
204202

205-
async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result<Self> {
206-
let tls_stream = connect_tls_inner(addr, hostname, strict_tls, alpn(addr.port())).await?;
203+
async fn connect_secure(
204+
context: &Context,
205+
addr: SocketAddr,
206+
hostname: &str,
207+
strict_tls: bool,
208+
) -> Result<Self> {
209+
let tcp_stream = connect_tcp_inner(addr).await?;
210+
let account_id = context.get_id();
211+
let events = context.events.clone();
212+
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
213+
let tls_stream = wrap_tls(strict_tls, hostname, alpn(addr.port()), logging_stream).await?;
207214
let buffered_stream = BufWriter::new(tls_stream);
208215
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
209216
let mut client = Client::new(session_stream);
@@ -214,9 +221,12 @@ impl Client {
214221
Ok(client)
215222
}
216223

217-
async fn connect_insecure(addr: SocketAddr) -> Result<Self> {
224+
async fn connect_insecure(context: &Context, addr: SocketAddr) -> Result<Self> {
218225
let tcp_stream = connect_tcp_inner(addr).await?;
219-
let buffered_stream = BufWriter::new(tcp_stream);
226+
let account_id = context.get_id();
227+
let events = context.events.clone();
228+
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
229+
let buffered_stream = BufWriter::new(logging_stream);
220230
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
221231
let mut client = Client::new(session_stream);
222232
let _greeting = client
@@ -226,9 +236,18 @@ impl Client {
226236
Ok(client)
227237
}
228238

229-
async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result<Self> {
239+
async fn connect_starttls(
240+
context: &Context,
241+
addr: SocketAddr,
242+
host: &str,
243+
strict_tls: bool,
244+
) -> Result<Self> {
230245
let tcp_stream = connect_tcp_inner(addr).await?;
231246

247+
let account_id = context.get_id();
248+
let events = context.events.clone();
249+
let tcp_stream = LoggingStream::new(tcp_stream, account_id, events);
250+
232251
// Run STARTTLS command and convert the client back into a stream.
233252
let buffered_tcp_stream = BufWriter::new(tcp_stream);
234253
let mut client = async_imap::Client::new(buffered_tcp_stream);
@@ -246,7 +265,6 @@ impl Client {
246265
let tls_stream = wrap_tls(strict_tls, host, &[], tcp_stream)
247266
.await
248267
.context("STARTTLS upgrade failed")?;
249-
250268
let buffered_stream = BufWriter::new(tls_stream);
251269
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
252270
let client = Client::new(session_stream);

src/log.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
use crate::context::Context;
66

7+
mod stream;
8+
9+
pub(crate) use stream::LoggingStream;
10+
711
macro_rules! info {
812
($ctx:expr, $msg:expr) => {
913
info!($ctx, $msg,)

src/log/stream.rs

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//! Stream that logs errors as events.
2+
//!
3+
//! This stream can be used to wrap IMAP,
4+
//! SMTP and HTTP streams so errors
5+
//! that occur are logged before
6+
//! they are processed.
7+
8+
use std::net::SocketAddr;
9+
use std::pin::Pin;
10+
use std::task::{Context, Poll};
11+
use std::time::Duration;
12+
13+
use anyhow::Result;
14+
use pin_project::pin_project;
15+
16+
use crate::events::{Event, EventType, Events};
17+
use crate::net::session::SessionStream;
18+
19+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20+
21+
#[derive(Debug)]
22+
struct Metrics {
23+
/// Total number of bytes read.
24+
pub total_read: usize,
25+
26+
/// Total number of bytes written.
27+
pub total_written: usize,
28+
}
29+
30+
impl Metrics {
31+
fn new() -> Self {
32+
Self {
33+
total_read: 0,
34+
total_written: 0,
35+
}
36+
}
37+
}
38+
39+
/// Stream that logs errors to the event channel.
40+
#[derive(Debug)]
41+
#[pin_project]
42+
pub(crate) struct LoggingStream<S: SessionStream> {
43+
#[pin]
44+
inner: S,
45+
46+
/// Account ID for logging.
47+
account_id: u32,
48+
49+
/// Event channel.
50+
events: Events,
51+
52+
/// Metrics for this stream.
53+
metrics: Metrics,
54+
}
55+
56+
impl<S: SessionStream> LoggingStream<S> {
57+
pub fn new(inner: S, account_id: u32, events: Events) -> Self {
58+
Self {
59+
inner,
60+
account_id,
61+
events,
62+
metrics: Metrics::new(),
63+
}
64+
}
65+
}
66+
67+
impl<S: SessionStream> AsyncRead for LoggingStream<S> {
68+
fn poll_read(
69+
self: Pin<&mut Self>,
70+
cx: &mut Context<'_>,
71+
buf: &mut ReadBuf<'_>,
72+
) -> Poll<std::io::Result<()>> {
73+
let this = self.project();
74+
let peer_addr = this.inner.peer_addr();
75+
let old_remaining = buf.remaining();
76+
77+
let res = this.inner.poll_read(cx, buf);
78+
79+
if let Poll::Ready(Err(ref err)) = res {
80+
debug_assert!(
81+
peer_addr.is_ok(),
82+
"Logging stream should be created over a bound socket"
83+
);
84+
let log_message = match peer_addr {
85+
Ok(peer_addr) => format!(
86+
"Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
87+
this.metrics.total_read, this.metrics.total_written
88+
),
89+
Err(_) => {
90+
format!("Read error on a stream that does not have a peer address: {err}.")
91+
}
92+
};
93+
this.events.emit(Event {
94+
id: *this.account_id,
95+
typ: EventType::Warning(log_message),
96+
});
97+
}
98+
99+
let n = old_remaining - buf.remaining();
100+
this.metrics.total_read = this.metrics.total_read.saturating_add(n);
101+
102+
res
103+
}
104+
}
105+
106+
impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
107+
fn poll_write(
108+
self: Pin<&mut Self>,
109+
cx: &mut std::task::Context<'_>,
110+
buf: &[u8],
111+
) -> Poll<std::io::Result<usize>> {
112+
let this = self.project();
113+
let res = this.inner.poll_write(cx, buf);
114+
if let Poll::Ready(Ok(n)) = res {
115+
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
116+
}
117+
res
118+
}
119+
120+
fn poll_flush(
121+
self: Pin<&mut Self>,
122+
cx: &mut std::task::Context<'_>,
123+
) -> Poll<std::io::Result<()>> {
124+
self.project().inner.poll_flush(cx)
125+
}
126+
127+
fn poll_shutdown(
128+
self: Pin<&mut Self>,
129+
cx: &mut std::task::Context<'_>,
130+
) -> Poll<std::io::Result<()>> {
131+
self.project().inner.poll_shutdown(cx)
132+
}
133+
134+
fn poll_write_vectored(
135+
self: Pin<&mut Self>,
136+
cx: &mut Context<'_>,
137+
bufs: &[std::io::IoSlice<'_>],
138+
) -> Poll<std::io::Result<usize>> {
139+
let this = self.project();
140+
let res = this.inner.poll_write_vectored(cx, bufs);
141+
if let Poll::Ready(Ok(n)) = res {
142+
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
143+
}
144+
res
145+
}
146+
147+
fn is_write_vectored(&self) -> bool {
148+
self.inner.is_write_vectored()
149+
}
150+
}
151+
152+
impl<S: SessionStream> SessionStream for LoggingStream<S> {
153+
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
154+
self.inner.set_read_timeout(timeout)
155+
}
156+
157+
fn peer_addr(&self) -> Result<SocketAddr> {
158+
self.inner.peer_addr()
159+
}
160+
}

src/net/session.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,101 @@
1+
use anyhow::Result;
12
use fast_socks5::client::Socks5Stream;
3+
use std::net::SocketAddr;
24
use std::pin::Pin;
35
use std::time::Duration;
46
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
7+
use tokio::net::TcpStream;
58
use tokio_io_timeout::TimeoutStream;
69

710
pub(crate) trait SessionStream:
811
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
912
{
1013
/// Change the read timeout on the session stream.
1114
fn set_read_timeout(&mut self, timeout: Option<Duration>);
15+
16+
/// Returns the remote address that this stream is connected to.
17+
fn peer_addr(&self) -> Result<SocketAddr>;
1218
}
1319

1420
impl SessionStream for Box<dyn SessionStream> {
1521
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
1622
self.as_mut().set_read_timeout(timeout);
1723
}
24+
25+
fn peer_addr(&self) -> Result<SocketAddr> {
26+
self.as_ref().peer_addr()
27+
}
1828
}
1929
impl<T: SessionStream> SessionStream for async_native_tls::TlsStream<T> {
2030
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
2131
self.get_mut().set_read_timeout(timeout);
2232
}
33+
34+
fn peer_addr(&self) -> Result<SocketAddr> {
35+
self.get_ref().peer_addr()
36+
}
2337
}
2438
impl<T: SessionStream> SessionStream for tokio_rustls::client::TlsStream<T> {
2539
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
2640
self.get_mut().0.set_read_timeout(timeout);
2741
}
42+
fn peer_addr(&self) -> Result<SocketAddr> {
43+
self.get_ref().0.peer_addr()
44+
}
2845
}
2946
impl<T: SessionStream> SessionStream for BufStream<T> {
3047
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
3148
self.get_mut().set_read_timeout(timeout);
3249
}
50+
51+
fn peer_addr(&self) -> Result<SocketAddr> {
52+
self.get_ref().peer_addr()
53+
}
3354
}
3455
impl<T: SessionStream> SessionStream for BufWriter<T> {
3556
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
3657
self.get_mut().set_read_timeout(timeout);
3758
}
59+
60+
fn peer_addr(&self) -> Result<SocketAddr> {
61+
self.get_ref().peer_addr()
62+
}
3863
}
39-
impl<T: AsyncRead + AsyncWrite + Send + Sync + std::fmt::Debug> SessionStream
40-
for Pin<Box<TimeoutStream<T>>>
41-
{
64+
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
4265
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
4366
self.as_mut().set_read_timeout_pinned(timeout);
4467
}
68+
69+
fn peer_addr(&self) -> Result<SocketAddr> {
70+
Ok(self.get_ref().peer_addr()?)
71+
}
4572
}
4673
impl<T: SessionStream> SessionStream for Socks5Stream<T> {
4774
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
4875
self.get_socket_mut().set_read_timeout(timeout)
4976
}
77+
78+
fn peer_addr(&self) -> Result<SocketAddr> {
79+
self.get_socket_ref().peer_addr()
80+
}
5081
}
5182
impl<T: SessionStream> SessionStream for shadowsocks::ProxyClientStream<T> {
5283
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
5384
self.get_mut().set_read_timeout(timeout)
5485
}
86+
87+
fn peer_addr(&self) -> Result<SocketAddr> {
88+
self.get_ref().peer_addr()
89+
}
5590
}
5691
impl<T: SessionStream> SessionStream for async_imap::DeflateStream<T> {
5792
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
5893
self.get_mut().set_read_timeout(timeout)
5994
}
95+
96+
fn peer_addr(&self) -> Result<SocketAddr> {
97+
self.get_ref().peer_addr()
98+
}
6099
}
61100

62101
/// Session stream with a read buffer.

0 commit comments

Comments
 (0)