Skip to content

Commit 8a7bfca

Browse files
committed
Merge branch 'dev' of github.com:conblem/acme-dns-rust into dev
2 parents 9c5f761 + 7ef72b1 commit 8a7bfca

File tree

2 files changed

+164
-63
lines changed

2 files changed

+164
-63
lines changed

src/api/proxy.rs

Lines changed: 161 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use ppp::model::{Addresses, Header};
55
use std::future::Future;
66
use std::io::IoSlice;
77
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
8-
use std::ops::{Deref, DerefMut};
98
use std::pin::Pin;
109
use std::task::{Context, Poll};
1110
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Result as IoResult};
@@ -17,15 +16,35 @@ use tracing::{error, Instrument, Span};
1716

1817
use crate::config::ProxyProtocol;
1918

19+
trait RemoteAddr {
20+
fn remote_addr(&self) -> IoResult<SocketAddr>;
21+
}
22+
23+
impl RemoteAddr for TcpStream {
24+
fn remote_addr(&self) -> IoResult<SocketAddr> {
25+
self.peer_addr()
26+
}
27+
}
28+
29+
impl<T: RemoteAddr> RemoteAddr for ProxyStream<T> {
30+
fn remote_addr(&self) -> IoResult<SocketAddr> {
31+
self.stream.remote_addr()
32+
}
33+
}
34+
2035
pub(super) fn wrap(
2136
listener: TcpListener,
2237
proxy: ProxyProtocol,
23-
) -> impl Stream<Item = IoResult<impl Future<Output = IoResult<ProxyStream>>>> + Send {
38+
) -> impl Stream<
39+
Item = IoResult<
40+
impl Future<Output = IoResult<impl AsyncRead + AsyncWrite + Send + Unpin + 'static>>,
41+
>,
42+
> + Send {
2443
TcpListenerStream::new(listener)
2544
.map_ok(move |conn| conn.source(proxy))
2645
.map_ok(|mut conn| {
2746
let span = Span::current();
28-
span.record("remote.addr", &debug(conn.peer_addr()));
47+
span.record("remote.addr", &debug(conn.remote_addr()));
2948
let span_clone = span.clone();
3049

3150
async move {
@@ -44,12 +63,15 @@ pub(super) fn wrap(
4463
})
4564
}
4665

47-
trait ToProxyStream {
48-
fn source(self, proxy: ProxyProtocol) -> ProxyStream;
66+
trait ToProxyStream: Sized {
67+
fn source(self, proxy: ProxyProtocol) -> ProxyStream<Self>;
4968
}
5069

51-
impl ToProxyStream for TcpStream {
52-
fn source(self, proxy: ProxyProtocol) -> ProxyStream {
70+
impl<T> ToProxyStream for T
71+
where
72+
T: AsyncRead + Unpin,
73+
{
74+
fn source(self, proxy: ProxyProtocol) -> ProxyStream<T> {
5375
let data = match proxy {
5476
ProxyProtocol::Enabled => Some(Default::default()),
5577
ProxyProtocol::Disabled => None,
@@ -62,19 +84,25 @@ impl ToProxyStream for TcpStream {
6284
}
6385
}
6486

65-
pub(super) struct ProxyStream {
66-
stream: TcpStream,
87+
pub(super) struct ProxyStream<T> {
88+
stream: T,
6789
data: Option<Vec<u8>>,
6890
start_of_data: usize,
6991
}
7092

71-
impl ProxyStream {
72-
fn real_addr(&mut self) -> RealAddrFuture<'_> {
93+
impl<T> ProxyStream<T>
94+
where
95+
T: AsyncRead + Unpin,
96+
{
97+
fn real_addr(&mut self) -> RealAddrFuture<'_, T> {
7398
RealAddrFuture { proxy_stream: self }
7499
}
75100
}
76101

77-
impl AsyncRead for ProxyStream {
102+
impl<T> AsyncRead for ProxyStream<T>
103+
where
104+
T: AsyncRead + Unpin,
105+
{
78106
fn poll_read(
79107
self: Pin<&mut Self>,
80108
cx: &mut Context<'_>,
@@ -88,7 +116,10 @@ impl AsyncRead for ProxyStream {
88116
}
89117
}
90118

91-
impl AsyncWrite for ProxyStream {
119+
impl<T> AsyncWrite for ProxyStream<T>
120+
where
121+
T: AsyncWrite + Unpin,
122+
{
92123
fn poll_write(
93124
mut self: Pin<&mut Self>,
94125
cx: &mut Context<'_>,
@@ -118,55 +149,41 @@ impl AsyncWrite for ProxyStream {
118149
}
119150
}
120151

121-
impl Deref for ProxyStream {
122-
type Target = TcpStream;
123-
124-
fn deref(&self) -> &Self::Target {
125-
&self.stream
126-
}
152+
struct RealAddrFuture<'a, T> {
153+
proxy_stream: &'a mut ProxyStream<T>,
127154
}
128155

129-
impl DerefMut for ProxyStream {
130-
fn deref_mut(&mut self) -> &mut Self::Target {
131-
&mut self.stream
132-
}
133-
}
156+
fn format_header(res: Header) -> IoResult<SocketAddr> {
157+
let addr = match res.addresses {
158+
Addresses::IPv4 {
159+
source_address,
160+
source_port,
161+
..
162+
} => {
163+
let port = source_port.unwrap_or_default();
164+
SocketAddrV4::new(source_address.into(), port).into()
165+
}
166+
Addresses::IPv6 {
167+
source_address,
168+
source_port,
169+
..
170+
} => {
171+
let port = source_port.unwrap_or_default();
172+
SocketAddrV6::new(source_address.into(), port, 0, 0).into()
173+
}
174+
address => {
175+
return Err(IoError::new(
176+
ErrorKind::Other,
177+
format!("Cannot convert {:?} to a SocketAddr", address),
178+
))
179+
}
180+
};
134181

135-
struct RealAddrFuture<'a> {
136-
proxy_stream: &'a mut ProxyStream,
182+
Ok(addr)
137183
}
138184

139-
impl<'a> RealAddrFuture<'a> {
140-
fn format_header(&self, res: Header) -> <Self as Future>::Output {
141-
let addr = match res.addresses {
142-
Addresses::IPv4 {
143-
source_address,
144-
source_port,
145-
..
146-
} => {
147-
let port = source_port.unwrap_or_default();
148-
SocketAddrV4::new(source_address.into(), port).into()
149-
}
150-
Addresses::IPv6 {
151-
source_address,
152-
source_port,
153-
..
154-
} => {
155-
let port = source_port.unwrap_or_default();
156-
SocketAddrV6::new(source_address.into(), port, 0, 0).into()
157-
}
158-
address => {
159-
return Err(IoError::new(
160-
ErrorKind::Other,
161-
format!("Cannot convert {:?} to a SocketAddr", address),
162-
))
163-
}
164-
};
165-
166-
Ok(Some(addr))
167-
}
168-
169-
fn get_header(&mut self) -> Poll<<Self as Future>::Output> {
185+
impl<'a, T> RealAddrFuture<'a, T> {
186+
fn get_header(&mut self) -> Poll<IoResult<Option<SocketAddr>>> {
170187
let data = match &mut self.proxy_stream.data {
171188
Some(data) => data,
172189
None => unreachable!("Future cannot be pulled anymore"),
@@ -186,11 +203,14 @@ impl<'a> RealAddrFuture<'a> {
186203
}
187204
};
188205

189-
Poll::Ready(self.format_header(res))
206+
Poll::Ready(format_header(res).map(Some))
190207
}
191208
}
192209

193-
impl Future for RealAddrFuture<'_> {
210+
impl<T> Future for RealAddrFuture<'_, T>
211+
where
212+
T: AsyncRead + Unpin,
213+
{
194214
type Output = IoResult<Option<SocketAddr>>;
195215

196216
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@@ -206,7 +226,7 @@ impl Future for RealAddrFuture<'_> {
206226
Ok(0) => {
207227
return Poll::Ready(Err(IoError::new(
208228
ErrorKind::UnexpectedEof,
209-
"Streamed finished before end of proxy protocol header",
229+
"Stream finished before end of proxy protocol header",
210230
)))
211231
}
212232
Ok(_) => {}
@@ -216,3 +236,84 @@ impl Future for RealAddrFuture<'_> {
216236
this.get_header()
217237
}
218238
}
239+
240+
#[cfg(test)]
241+
mod tests {
242+
use ppp::model::{Addresses, Command, Header, Protocol, Version};
243+
use std::io::Cursor;
244+
use std::net::SocketAddr;
245+
use tokio::io::AsyncReadExt;
246+
247+
use super::{format_header, ToProxyStream};
248+
use crate::config::ProxyProtocol;
249+
250+
#[tokio::test]
251+
async fn test_disabled() {
252+
let mut proxy_stream = Cursor::new(vec![]).source(ProxyProtocol::Disabled);
253+
assert!(proxy_stream.real_addr().await.unwrap().is_none());
254+
}
255+
256+
fn generate_header(addresses: Addresses) -> Header {
257+
Header::new(
258+
Version::Two,
259+
Command::Proxy,
260+
Protocol::Stream,
261+
vec![],
262+
addresses,
263+
)
264+
}
265+
266+
fn generate_ipv4() -> Header {
267+
let adresses = Addresses::from(([1, 1, 1, 1], [2, 2, 2, 2], 24034, 443));
268+
generate_header(adresses)
269+
}
270+
271+
#[tokio::test]
272+
async fn test_header_parsing() {
273+
let mut header = ppp::to_bytes(generate_ipv4()).unwrap();
274+
header.extend_from_slice("Test".as_ref());
275+
let mut header = Cursor::new(header).source(ProxyProtocol::Enabled);
276+
277+
let actual = header.real_addr().await.unwrap().unwrap();
278+
279+
assert_eq!(SocketAddr::from(([1, 1, 1, 1], 24034)), actual);
280+
281+
let mut actual = String::new();
282+
let size = header.read_to_string(&mut actual).await.unwrap();
283+
assert_eq!(4, size);
284+
assert_eq!("Test", actual);
285+
}
286+
287+
#[tokio::test]
288+
async fn test_incomplete() {
289+
let header = ppp::to_bytes(generate_ipv4()).unwrap();
290+
let header = &mut &header[..10];
291+
let mut header = header.source(ProxyProtocol::Enabled);
292+
293+
let actual = header.real_addr().await.unwrap_err();
294+
assert_eq!(
295+
format!("{}", actual),
296+
"Stream finished before end of proxy protocol header"
297+
);
298+
}
299+
300+
#[tokio::test]
301+
async fn test_failure() {
302+
let invalid = Vec::from("invalid header");
303+
let invalid = &mut &invalid[..];
304+
305+
let mut invalid = invalid.source(ProxyProtocol::Enabled);
306+
307+
let actual = invalid.real_addr().await.unwrap_err();
308+
assert_eq!(format!("{}", actual), "Proxy Parser Error");
309+
}
310+
311+
#[test]
312+
fn test_adresses() {
313+
let address = [1, 1, 1, 1, 1, 1, 1, 1];
314+
let addresses = Addresses::from((address, address, 24034, 443));
315+
316+
let actual = format_header(generate_header(addresses)).unwrap();
317+
assert_eq!(SocketAddr::from((address, 24034)), actual);
318+
}
319+
}

src/api/tls.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ use tokio::io::{AsyncRead, AsyncWrite, Result as IoResult};
1111
use tokio_rustls::TlsAcceptor;
1212
use tracing::{error, info};
1313

14-
use super::proxy::ProxyStream;
1514
use crate::cert::{Cert, CertFacade};
1615
use crate::util::to_u64;
1716

18-
pub(super) fn wrap<L, I>(
17+
pub(super) fn wrap<L, I, S>(
1918
listener: L,
2019
pool: PgPool,
2120
) -> impl Stream<
@@ -25,7 +24,8 @@ pub(super) fn wrap<L, I>(
2524
> + Send
2625
where
2726
L: Stream<Item = IoResult<I>> + Send,
28-
I: Future<Output = IoResult<ProxyStream>> + Send,
27+
I: Future<Output = IoResult<S>> + Send,
28+
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
2929
{
3030
let acceptor = Acceptor::new(pool);
3131

0 commit comments

Comments
 (0)