Skip to content

Commit 7ef72b1

Browse files
committed
fix: improve test coverage
1 parent 9941578 commit 7ef72b1

File tree

1 file changed

+97
-52
lines changed

1 file changed

+97
-52
lines changed

src/api/proxy.rs

Lines changed: 97 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use ppp::model::{Addresses, Header};
55
use std::future::Future;
66
use std::io::IoSlice;
77
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
8-
use std::ops::{Deref, DerefMut};
98
use std::pin::Pin;
109
use std::task::{Context, Poll};
1110
use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ErrorKind, ReadBuf, Result as IoResult};
@@ -27,6 +26,12 @@ impl RemoteAddr for TcpStream {
2726
}
2827
}
2928

29+
impl<T: RemoteAddr> RemoteAddr for ProxyStream<T> {
30+
fn remote_addr(&self) -> IoResult<SocketAddr> {
31+
self.stream.remote_addr()
32+
}
33+
}
34+
3035
pub(super) fn wrap(
3136
listener: TcpListener,
3237
proxy: ProxyProtocol,
@@ -144,54 +149,40 @@ where
144149
}
145150
}
146151

147-
impl<T> Deref for ProxyStream<T> {
148-
type Target = T;
149-
150-
fn deref(&self) -> &Self::Target {
151-
&self.stream
152-
}
153-
}
154-
155-
impl<T> DerefMut for ProxyStream<T> {
156-
fn deref_mut(&mut self) -> &mut Self::Target {
157-
&mut self.stream
158-
}
159-
}
160-
161152
struct RealAddrFuture<'a, T> {
162153
proxy_stream: &'a mut ProxyStream<T>,
163154
}
164155

165-
impl<'a, T> RealAddrFuture<'a, T> {
166-
fn format_header(&self, res: Header) -> IoResult<Option<SocketAddr>> {
167-
let addr = match res.addresses {
168-
Addresses::IPv4 {
169-
source_address,
170-
source_port,
171-
..
172-
} => {
173-
let port = source_port.unwrap_or_default();
174-
SocketAddrV4::new(source_address.into(), port).into()
175-
}
176-
Addresses::IPv6 {
177-
source_address,
178-
source_port,
179-
..
180-
} => {
181-
let port = source_port.unwrap_or_default();
182-
SocketAddrV6::new(source_address.into(), port, 0, 0).into()
183-
}
184-
address => {
185-
return Err(IoError::new(
186-
ErrorKind::Other,
187-
format!("Cannot convert {:?} to a SocketAddr", address),
188-
))
189-
}
190-
};
156+
fn format_header(res: Header) -> IoResult<SocketAddr> {
157+
let addr = match res.addresses {
158+
Addresses::IPv4 {
159+
source_address,
160+
source_port,
161+
..
162+
} => {
163+
let port = source_port.unwrap_or_default();
164+
SocketAddrV4::new(source_address.into(), port).into()
165+
}
166+
Addresses::IPv6 {
167+
source_address,
168+
source_port,
169+
..
170+
} => {
171+
let port = source_port.unwrap_or_default();
172+
SocketAddrV6::new(source_address.into(), port, 0, 0).into()
173+
}
174+
address => {
175+
return Err(IoError::new(
176+
ErrorKind::Other,
177+
format!("Cannot convert {:?} to a SocketAddr", address),
178+
))
179+
}
180+
};
191181

192-
Ok(Some(addr))
193-
}
182+
Ok(addr)
183+
}
194184

185+
impl<'a, T> RealAddrFuture<'a, T> {
195186
fn get_header(&mut self) -> Poll<IoResult<Option<SocketAddr>>> {
196187
let data = match &mut self.proxy_stream.data {
197188
Some(data) => data,
@@ -212,7 +203,7 @@ impl<'a, T> RealAddrFuture<'a, T> {
212203
}
213204
};
214205

215-
Poll::Ready(self.format_header(res))
206+
Poll::Ready(format_header(res).map(Some))
216207
}
217208
}
218209

@@ -235,7 +226,7 @@ where
235226
Ok(0) => {
236227
return Poll::Ready(Err(IoError::new(
237228
ErrorKind::UnexpectedEof,
238-
"Streamed finished before end of proxy protocol header",
229+
"Stream finished before end of proxy protocol header",
239230
)))
240231
}
241232
Ok(_) => {}
@@ -250,25 +241,79 @@ where
250241
mod tests {
251242
use ppp::model::{Addresses, Command, Header, Protocol, Version};
252243
use std::io::Cursor;
244+
use std::net::SocketAddr;
245+
use tokio::io::AsyncReadExt;
253246

254-
use super::ToProxyStream;
247+
use super::{format_header, ToProxyStream};
255248
use crate::config::ProxyProtocol;
256-
use std::net::SocketAddr;
257249

258250
#[tokio::test]
259-
async fn test_header_parsing() {
260-
let header = Header::new(
251+
async fn test_disabled() {
252+
let mut proxy_stream = Cursor::new(vec![]).source(ProxyProtocol::Disabled);
253+
assert!(proxy_stream.real_addr().await.unwrap().is_none());
254+
}
255+
256+
fn generate_header(addresses: Addresses) -> Header {
257+
Header::new(
261258
Version::Two,
262259
Command::Proxy,
263260
Protocol::Stream,
264261
vec![],
265-
Addresses::from(([1, 1, 1, 1], [2, 2, 2, 2], 24034, 443)),
266-
);
267-
let header = ppp::to_bytes(header).unwrap();
262+
addresses,
263+
)
264+
}
265+
266+
fn generate_ipv4() -> Header {
267+
let adresses = Addresses::from(([1, 1, 1, 1], [2, 2, 2, 2], 24034, 443));
268+
generate_header(adresses)
269+
}
270+
271+
#[tokio::test]
272+
async fn test_header_parsing() {
273+
let mut header = ppp::to_bytes(generate_ipv4()).unwrap();
274+
header.extend_from_slice("Test".as_ref());
268275
let mut header = Cursor::new(header).source(ProxyProtocol::Enabled);
269276

270277
let actual = header.real_addr().await.unwrap().unwrap();
271278

272279
assert_eq!(SocketAddr::from(([1, 1, 1, 1], 24034)), actual);
280+
281+
let mut actual = String::new();
282+
let size = header.read_to_string(&mut actual).await.unwrap();
283+
assert_eq!(4, size);
284+
assert_eq!("Test", actual);
285+
}
286+
287+
#[tokio::test]
288+
async fn test_incomplete() {
289+
let header = ppp::to_bytes(generate_ipv4()).unwrap();
290+
let header = &mut &header[..10];
291+
let mut header = header.source(ProxyProtocol::Enabled);
292+
293+
let actual = header.real_addr().await.unwrap_err();
294+
assert_eq!(
295+
format!("{}", actual),
296+
"Stream finished before end of proxy protocol header"
297+
);
298+
}
299+
300+
#[tokio::test]
301+
async fn test_failure() {
302+
let invalid = Vec::from("invalid header");
303+
let invalid = &mut &invalid[..];
304+
305+
let mut invalid = invalid.source(ProxyProtocol::Enabled);
306+
307+
let actual = invalid.real_addr().await.unwrap_err();
308+
assert_eq!(format!("{}", actual), "Proxy Parser Error");
309+
}
310+
311+
#[test]
312+
fn test_adresses() {
313+
let address = [1, 1, 1, 1, 1, 1, 1, 1];
314+
let addresses = Addresses::from((address, address, 24034, 443));
315+
316+
let actual = format_header(generate_header(addresses)).unwrap();
317+
assert_eq!(SocketAddr::from((address, 24034)), actual);
273318
}
274319
}

0 commit comments

Comments
 (0)