Skip to content

Commit 9941578

Browse files
committed
fix: improve test coverage
1 parent cad0e9a commit 9941578

File tree

2 files changed

+80
-24
lines changed

2 files changed

+80
-24
lines changed

src/api/proxy.rs

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,29 @@ use tracing::{error, Instrument, Span};
1717

1818
use crate::config::ProxyProtocol;
1919

20+
trait RemoteAddr {
21+
fn remote_addr(&self) -> IoResult<SocketAddr>;
22+
}
23+
24+
impl RemoteAddr for TcpStream {
25+
fn remote_addr(&self) -> IoResult<SocketAddr> {
26+
self.peer_addr()
27+
}
28+
}
29+
2030
pub(super) fn wrap(
2131
listener: TcpListener,
2232
proxy: ProxyProtocol,
23-
) -> impl Stream<Item = IoResult<impl Future<Output = IoResult<ProxyStream>>>> + Send {
33+
) -> impl Stream<
34+
Item = IoResult<
35+
impl Future<Output = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>>,
36+
>,
37+
> + Send {
2438
TcpListenerStream::new(listener)
2539
.map_ok(move |conn| conn.source(proxy))
2640
.map_ok(|mut conn| {
2741
let span = Span::current();
28-
span.record("remote.addr", &debug(conn.peer_addr()));
42+
span.record("remote.addr", &debug(conn.remote_addr()));
2943
let span_clone = span.clone();
3044

3145
async move {
@@ -44,12 +58,15 @@ pub(super) fn wrap(
4458
})
4559
}
4660

47-
trait ToProxyStream {
48-
fn source(self, proxy: ProxyProtocol) -> ProxyStream;
61+
trait ToProxyStream: Sized {
62+
fn source(self, proxy: ProxyProtocol) -> ProxyStream<Self>;
4963
}
5064

51-
impl ToProxyStream for TcpStream {
52-
fn source(self, proxy: ProxyProtocol) -> ProxyStream {
65+
impl<T> ToProxyStream for T
66+
where
67+
T: AsyncRead + Unpin,
68+
{
69+
fn source(self, proxy: ProxyProtocol) -> ProxyStream<T> {
5370
let data = match proxy {
5471
ProxyProtocol::Enabled => Some(Default::default()),
5572
ProxyProtocol::Disabled => None,
@@ -62,19 +79,25 @@ impl ToProxyStream for TcpStream {
6279
}
6380
}
6481

65-
pub(super) struct ProxyStream {
66-
stream: TcpStream,
82+
pub(super) struct ProxyStream<T> {
83+
stream: T,
6784
data: Option<Vec<u8>>,
6885
start_of_data: usize,
6986
}
7087

71-
impl ProxyStream {
72-
fn real_addr(&mut self) -> RealAddrFuture<'_> {
88+
impl<T> ProxyStream<T>
89+
where
90+
T: AsyncRead + Unpin,
91+
{
92+
fn real_addr(&mut self) -> RealAddrFuture<'_, T> {
7393
RealAddrFuture { proxy_stream: self }
7494
}
7595
}
7696

77-
impl AsyncRead for ProxyStream {
97+
impl<T> AsyncRead for ProxyStream<T>
98+
where
99+
T: AsyncRead + Unpin,
100+
{
78101
fn poll_read(
79102
self: Pin<&mut Self>,
80103
cx: &mut Context<'_>,
@@ -88,7 +111,10 @@ impl AsyncRead for ProxyStream {
88111
}
89112
}
90113

91-
impl AsyncWrite for ProxyStream {
114+
impl<T> AsyncWrite for ProxyStream<T>
115+
where
116+
T: AsyncWrite + Unpin,
117+
{
92118
fn poll_write(
93119
mut self: Pin<&mut Self>,
94120
cx: &mut Context<'_>,
@@ -118,26 +144,26 @@ impl AsyncWrite for ProxyStream {
118144
}
119145
}
120146

121-
impl Deref for ProxyStream {
122-
type Target = TcpStream;
147+
impl<T> Deref for ProxyStream<T> {
148+
type Target = T;
123149

124150
fn deref(&self) -> &Self::Target {
125151
&self.stream
126152
}
127153
}
128154

129-
impl DerefMut for ProxyStream {
155+
impl<T> DerefMut for ProxyStream<T> {
130156
fn deref_mut(&mut self) -> &mut Self::Target {
131157
&mut self.stream
132158
}
133159
}
134160

135-
struct RealAddrFuture<'a> {
136-
proxy_stream: &'a mut ProxyStream,
161+
struct RealAddrFuture<'a, T> {
162+
proxy_stream: &'a mut ProxyStream<T>,
137163
}
138164

139-
impl<'a> RealAddrFuture<'a> {
140-
fn format_header(&self, res: Header) -> <Self as Future>::Output {
165+
impl<'a, T> RealAddrFuture<'a, T> {
166+
fn format_header(&self, res: Header) -> IoResult<Option<SocketAddr>> {
141167
let addr = match res.addresses {
142168
Addresses::IPv4 {
143169
source_address,
@@ -166,7 +192,7 @@ impl<'a> RealAddrFuture<'a> {
166192
Ok(Some(addr))
167193
}
168194

169-
fn get_header(&mut self) -> Poll<<Self as Future>::Output> {
195+
fn get_header(&mut self) -> Poll<IoResult<Option<SocketAddr>>> {
170196
let data = match &mut self.proxy_stream.data {
171197
Some(data) => data,
172198
None => unreachable!("Future cannot be pulled anymore"),
@@ -190,7 +216,10 @@ impl<'a> RealAddrFuture<'a> {
190216
}
191217
}
192218

193-
impl Future for RealAddrFuture<'_> {
219+
impl<T> Future for RealAddrFuture<'_, T>
220+
where
221+
T: AsyncRead + Unpin,
222+
{
194223
type Output = IoResult<Option<SocketAddr>>;
195224

196225
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@@ -216,3 +245,30 @@ impl Future for RealAddrFuture<'_> {
216245
this.get_header()
217246
}
218247
}
248+
249+
#[cfg(test)]
250+
mod tests {
251+
use ppp::model::{Addresses, Command, Header, Protocol, Version};
252+
use std::io::Cursor;
253+
254+
use super::ToProxyStream;
255+
use crate::config::ProxyProtocol;
256+
use std::net::SocketAddr;
257+
258+
#[tokio::test]
259+
async fn test_header_parsing() {
260+
let header = Header::new(
261+
Version::Two,
262+
Command::Proxy,
263+
Protocol::Stream,
264+
vec![],
265+
Addresses::from(([1, 1, 1, 1], [2, 2, 2, 2], 24034, 443)),
266+
);
267+
let header = ppp::to_bytes(header).unwrap();
268+
let mut header = Cursor::new(header).source(ProxyProtocol::Enabled);
269+
270+
let actual = header.real_addr().await.unwrap().unwrap();
271+
272+
assert_eq!(SocketAddr::from(([1, 1, 1, 1], 24034)), actual);
273+
}
274+
}

src/api/tls.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ use tokio::io::{AsyncRead, AsyncWrite, Result as IoResult};
1111
use tokio_rustls::TlsAcceptor;
1212
use tracing::{error, info};
1313

14-
use super::proxy::ProxyStream;
1514
use crate::cert::{Cert, CertFacade};
1615
use crate::util::to_u64;
1716

18-
pub(super) fn wrap<L, I>(
17+
pub(super) fn wrap<L, I, S>(
1918
listener: L,
2019
pool: PgPool,
2120
) -> impl Stream<
@@ -25,7 +24,8 @@ pub(super) fn wrap<L, I>(
2524
> + Send
2625
where
2726
L: Stream<Item = IoResult<I>> + Send,
28-
I: Future<Output = IoResult<ProxyStream>> + Send,
27+
I: Future<Output = IoResult<S>> + Send,
28+
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
2929
{
3030
let acceptor = Acceptor::new(pool);
3131

0 commit comments

Comments
 (0)