@@ -5,7 +5,6 @@ use ppp::model::{Addresses, Header};
55use std:: future:: Future ;
66use std:: io:: IoSlice ;
77use std:: net:: { SocketAddr , SocketAddrV4 , SocketAddrV6 } ;
8- use std:: ops:: { Deref , DerefMut } ;
98use std:: pin:: Pin ;
109use std:: task:: { Context , Poll } ;
1110use tokio:: io:: { AsyncRead , AsyncWrite , Error as IoError , ErrorKind , ReadBuf , Result as IoResult } ;
@@ -17,15 +16,35 @@ use tracing::{error, Instrument, Span};
1716
1817use crate :: config:: ProxyProtocol ;
1918
19+ trait RemoteAddr {
20+ fn remote_addr ( & self ) -> IoResult < SocketAddr > ;
21+ }
22+
23+ impl RemoteAddr for TcpStream {
24+ fn remote_addr ( & self ) -> IoResult < SocketAddr > {
25+ self . peer_addr ( )
26+ }
27+ }
28+
29+ impl < T : RemoteAddr > RemoteAddr for ProxyStream < T > {
30+ fn remote_addr ( & self ) -> IoResult < SocketAddr > {
31+ self . stream . remote_addr ( )
32+ }
33+ }
34+
2035pub ( super ) fn wrap (
2136 listener : TcpListener ,
2237 proxy : ProxyProtocol ,
23- ) -> impl Stream < Item = IoResult < impl Future < Output = IoResult < ProxyStream > > > > + Send {
38+ ) -> impl Stream <
39+ Item = IoResult <
40+ impl Future < Output = IoResult < impl AsyncRead + AsyncWrite + Send + Unpin + ' static > > ,
41+ > ,
42+ > + Send {
2443 TcpListenerStream :: new ( listener)
2544 . map_ok ( move |conn| conn. source ( proxy) )
2645 . map_ok ( |mut conn| {
2746 let span = Span :: current ( ) ;
28- span. record ( "remote.addr" , & debug ( conn. peer_addr ( ) ) ) ;
47+ span. record ( "remote.addr" , & debug ( conn. remote_addr ( ) ) ) ;
2948 let span_clone = span. clone ( ) ;
3049
3150 async move {
@@ -44,12 +63,15 @@ pub(super) fn wrap(
4463 } )
4564}
4665
47- trait ToProxyStream {
48- fn source ( self , proxy : ProxyProtocol ) -> ProxyStream ;
66+ trait ToProxyStream : Sized {
67+ fn source ( self , proxy : ProxyProtocol ) -> ProxyStream < Self > ;
4968}
5069
51- impl ToProxyStream for TcpStream {
52- fn source ( self , proxy : ProxyProtocol ) -> ProxyStream {
70+ impl < T > ToProxyStream for T
71+ where
72+ T : AsyncRead + Unpin ,
73+ {
74+ fn source ( self , proxy : ProxyProtocol ) -> ProxyStream < T > {
5375 let data = match proxy {
5476 ProxyProtocol :: Enabled => Some ( Default :: default ( ) ) ,
5577 ProxyProtocol :: Disabled => None ,
@@ -62,19 +84,25 @@ impl ToProxyStream for TcpStream {
6284 }
6385}
6486
65- pub ( super ) struct ProxyStream {
66- stream : TcpStream ,
87+ pub ( super ) struct ProxyStream < T > {
88+ stream : T ,
6789 data : Option < Vec < u8 > > ,
6890 start_of_data : usize ,
6991}
7092
71- impl ProxyStream {
72- fn real_addr ( & mut self ) -> RealAddrFuture < ' _ > {
93+ impl < T > ProxyStream < T >
94+ where
95+ T : AsyncRead + Unpin ,
96+ {
97+ fn real_addr ( & mut self ) -> RealAddrFuture < ' _ , T > {
7398 RealAddrFuture { proxy_stream : self }
7499 }
75100}
76101
77- impl AsyncRead for ProxyStream {
102+ impl < T > AsyncRead for ProxyStream < T >
103+ where
104+ T : AsyncRead + Unpin ,
105+ {
78106 fn poll_read (
79107 self : Pin < & mut Self > ,
80108 cx : & mut Context < ' _ > ,
@@ -88,7 +116,10 @@ impl AsyncRead for ProxyStream {
88116 }
89117}
90118
91- impl AsyncWrite for ProxyStream {
119+ impl < T > AsyncWrite for ProxyStream < T >
120+ where
121+ T : AsyncWrite + Unpin ,
122+ {
92123 fn poll_write (
93124 mut self : Pin < & mut Self > ,
94125 cx : & mut Context < ' _ > ,
@@ -118,55 +149,41 @@ impl AsyncWrite for ProxyStream {
118149 }
119150}
120151
121- impl Deref for ProxyStream {
122- type Target = TcpStream ;
123-
124- fn deref ( & self ) -> & Self :: Target {
125- & self . stream
126- }
152+ struct RealAddrFuture < ' a , T > {
153+ proxy_stream : & ' a mut ProxyStream < T > ,
127154}
128155
129- impl DerefMut for ProxyStream {
130- fn deref_mut ( & mut self ) -> & mut Self :: Target {
131- & mut self . stream
132- }
133- }
156+ fn format_header ( res : Header ) -> IoResult < SocketAddr > {
157+ let addr = match res. addresses {
158+ Addresses :: IPv4 {
159+ source_address,
160+ source_port,
161+ ..
162+ } => {
163+ let port = source_port. unwrap_or_default ( ) ;
164+ SocketAddrV4 :: new ( source_address. into ( ) , port) . into ( )
165+ }
166+ Addresses :: IPv6 {
167+ source_address,
168+ source_port,
169+ ..
170+ } => {
171+ let port = source_port. unwrap_or_default ( ) ;
172+ SocketAddrV6 :: new ( source_address. into ( ) , port, 0 , 0 ) . into ( )
173+ }
174+ address => {
175+ return Err ( IoError :: new (
176+ ErrorKind :: Other ,
177+ format ! ( "Cannot convert {:?} to a SocketAddr" , address) ,
178+ ) )
179+ }
180+ } ;
134181
135- struct RealAddrFuture < ' a > {
136- proxy_stream : & ' a mut ProxyStream ,
182+ Ok ( addr)
137183}
138184
139- impl < ' a > RealAddrFuture < ' a > {
140- fn format_header ( & self , res : Header ) -> <Self as Future >:: Output {
141- let addr = match res. addresses {
142- Addresses :: IPv4 {
143- source_address,
144- source_port,
145- ..
146- } => {
147- let port = source_port. unwrap_or_default ( ) ;
148- SocketAddrV4 :: new ( source_address. into ( ) , port) . into ( )
149- }
150- Addresses :: IPv6 {
151- source_address,
152- source_port,
153- ..
154- } => {
155- let port = source_port. unwrap_or_default ( ) ;
156- SocketAddrV6 :: new ( source_address. into ( ) , port, 0 , 0 ) . into ( )
157- }
158- address => {
159- return Err ( IoError :: new (
160- ErrorKind :: Other ,
161- format ! ( "Cannot convert {:?} to a SocketAddr" , address) ,
162- ) )
163- }
164- } ;
165-
166- Ok ( Some ( addr) )
167- }
168-
169- fn get_header ( & mut self ) -> Poll < <Self as Future >:: Output > {
185+ impl < ' a , T > RealAddrFuture < ' a , T > {
186+ fn get_header ( & mut self ) -> Poll < IoResult < Option < SocketAddr > > > {
170187 let data = match & mut self . proxy_stream . data {
171188 Some ( data) => data,
172189 None => unreachable ! ( "Future cannot be pulled anymore" ) ,
@@ -186,11 +203,14 @@ impl<'a> RealAddrFuture<'a> {
186203 }
187204 } ;
188205
189- Poll :: Ready ( self . format_header ( res) )
206+ Poll :: Ready ( format_header ( res) . map ( Some ) )
190207 }
191208}
192209
193- impl Future for RealAddrFuture < ' _ > {
210+ impl < T > Future for RealAddrFuture < ' _ , T >
211+ where
212+ T : AsyncRead + Unpin ,
213+ {
194214 type Output = IoResult < Option < SocketAddr > > ;
195215
196216 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
@@ -206,7 +226,7 @@ impl Future for RealAddrFuture<'_> {
206226 Ok ( 0 ) => {
207227 return Poll :: Ready ( Err ( IoError :: new (
208228 ErrorKind :: UnexpectedEof ,
209- "Streamed finished before end of proxy protocol header" ,
229+ "Stream finished before end of proxy protocol header" ,
210230 ) ) )
211231 }
212232 Ok ( _) => { }
@@ -216,3 +236,84 @@ impl Future for RealAddrFuture<'_> {
216236 this. get_header ( )
217237 }
218238}
239+
240+ #[ cfg( test) ]
241+ mod tests {
242+ use ppp:: model:: { Addresses , Command , Header , Protocol , Version } ;
243+ use std:: io:: Cursor ;
244+ use std:: net:: SocketAddr ;
245+ use tokio:: io:: AsyncReadExt ;
246+
247+ use super :: { format_header, ToProxyStream } ;
248+ use crate :: config:: ProxyProtocol ;
249+
250+ #[ tokio:: test]
251+ async fn test_disabled ( ) {
252+ let mut proxy_stream = Cursor :: new ( vec ! [ ] ) . source ( ProxyProtocol :: Disabled ) ;
253+ assert ! ( proxy_stream. real_addr( ) . await . unwrap( ) . is_none( ) ) ;
254+ }
255+
256+ fn generate_header ( addresses : Addresses ) -> Header {
257+ Header :: new (
258+ Version :: Two ,
259+ Command :: Proxy ,
260+ Protocol :: Stream ,
261+ vec ! [ ] ,
262+ addresses,
263+ )
264+ }
265+
266+ fn generate_ipv4 ( ) -> Header {
267+ let adresses = Addresses :: from ( ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 , 2 , 2 ] , 24034 , 443 ) ) ;
268+ generate_header ( adresses)
269+ }
270+
271+ #[ tokio:: test]
272+ async fn test_header_parsing ( ) {
273+ let mut header = ppp:: to_bytes ( generate_ipv4 ( ) ) . unwrap ( ) ;
274+ header. extend_from_slice ( "Test" . as_ref ( ) ) ;
275+ let mut header = Cursor :: new ( header) . source ( ProxyProtocol :: Enabled ) ;
276+
277+ let actual = header. real_addr ( ) . await . unwrap ( ) . unwrap ( ) ;
278+
279+ assert_eq ! ( SocketAddr :: from( ( [ 1 , 1 , 1 , 1 ] , 24034 ) ) , actual) ;
280+
281+ let mut actual = String :: new ( ) ;
282+ let size = header. read_to_string ( & mut actual) . await . unwrap ( ) ;
283+ assert_eq ! ( 4 , size) ;
284+ assert_eq ! ( "Test" , actual) ;
285+ }
286+
287+ #[ tokio:: test]
288+ async fn test_incomplete ( ) {
289+ let header = ppp:: to_bytes ( generate_ipv4 ( ) ) . unwrap ( ) ;
290+ let header = & mut & header[ ..10 ] ;
291+ let mut header = header. source ( ProxyProtocol :: Enabled ) ;
292+
293+ let actual = header. real_addr ( ) . await . unwrap_err ( ) ;
294+ assert_eq ! (
295+ format!( "{}" , actual) ,
296+ "Stream finished before end of proxy protocol header"
297+ ) ;
298+ }
299+
300+ #[ tokio:: test]
301+ async fn test_failure ( ) {
302+ let invalid = Vec :: from ( "invalid header" ) ;
303+ let invalid = & mut & invalid[ ..] ;
304+
305+ let mut invalid = invalid. source ( ProxyProtocol :: Enabled ) ;
306+
307+ let actual = invalid. real_addr ( ) . await . unwrap_err ( ) ;
308+ assert_eq ! ( format!( "{}" , actual) , "Proxy Parser Error" ) ;
309+ }
310+
311+ #[ test]
312+ fn test_adresses ( ) {
313+ let address = [ 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ] ;
314+ let addresses = Addresses :: from ( ( address, address, 24034 , 443 ) ) ;
315+
316+ let actual = format_header ( generate_header ( addresses) ) . unwrap ( ) ;
317+ assert_eq ! ( SocketAddr :: from( ( address, 24034 ) ) , actual) ;
318+ }
319+ }
0 commit comments