@@ -5,7 +5,6 @@ use ppp::model::{Addresses, Header};
55use std:: future:: Future ;
66use std:: io:: IoSlice ;
77use std:: net:: { SocketAddr , SocketAddrV4 , SocketAddrV6 } ;
8- use std:: ops:: { Deref , DerefMut } ;
98use std:: pin:: Pin ;
109use std:: task:: { Context , Poll } ;
1110use tokio:: io:: { AsyncRead , AsyncWrite , Error as IoError , ErrorKind , ReadBuf , Result as IoResult } ;
@@ -27,6 +26,12 @@ impl RemoteAddr for TcpStream {
2726 }
2827}
2928
29+ impl < T : RemoteAddr > RemoteAddr for ProxyStream < T > {
30+ fn remote_addr ( & self ) -> IoResult < SocketAddr > {
31+ self . stream . remote_addr ( )
32+ }
33+ }
34+
3035pub ( super ) fn wrap (
3136 listener : TcpListener ,
3237 proxy : ProxyProtocol ,
@@ -144,54 +149,40 @@ where
144149 }
145150}
146151
147- impl < T > Deref for ProxyStream < T > {
148- type Target = T ;
149-
150- fn deref ( & self ) -> & Self :: Target {
151- & self . stream
152- }
153- }
154-
155- impl < T > DerefMut for ProxyStream < T > {
156- fn deref_mut ( & mut self ) -> & mut Self :: Target {
157- & mut self . stream
158- }
159- }
160-
161152struct RealAddrFuture < ' a , T > {
162153 proxy_stream : & ' a mut ProxyStream < T > ,
163154}
164155
165- impl < ' a , T > RealAddrFuture < ' a , T > {
166- fn format_header ( & self , res : Header ) -> IoResult < Option < SocketAddr > > {
167- let addr = match res. addresses {
168- Addresses :: IPv4 {
169- source_address,
170- source_port,
171- ..
172- } => {
173- let port = source_port. unwrap_or_default ( ) ;
174- SocketAddrV4 :: new ( source_address. into ( ) , port) . into ( )
175- }
176- Addresses :: IPv6 {
177- source_address,
178- source_port,
179- ..
180- } => {
181- let port = source_port. unwrap_or_default ( ) ;
182- SocketAddrV6 :: new ( source_address. into ( ) , port, 0 , 0 ) . into ( )
183- }
184- address => {
185- return Err ( IoError :: new (
186- ErrorKind :: Other ,
187- format ! ( "Cannot convert {:?} to a SocketAddr" , address) ,
188- ) )
189- }
190- } ;
156+ fn format_header ( res : Header ) -> IoResult < SocketAddr > {
157+ let addr = match res. addresses {
158+ Addresses :: IPv4 {
159+ source_address,
160+ source_port,
161+ ..
162+ } => {
163+ let port = source_port. unwrap_or_default ( ) ;
164+ SocketAddrV4 :: new ( source_address. into ( ) , port) . into ( )
165+ }
166+ Addresses :: IPv6 {
167+ source_address,
168+ source_port,
169+ ..
170+ } => {
171+ let port = source_port. unwrap_or_default ( ) ;
172+ SocketAddrV6 :: new ( source_address. into ( ) , port, 0 , 0 ) . into ( )
173+ }
174+ address => {
175+ return Err ( IoError :: new (
176+ ErrorKind :: Other ,
177+ format ! ( "Cannot convert {:?} to a SocketAddr" , address) ,
178+ ) )
179+ }
180+ } ;
191181
192- Ok ( Some ( addr) )
193- }
182+ Ok ( addr)
183+ }
194184
185+ impl < ' a , T > RealAddrFuture < ' a , T > {
195186 fn get_header ( & mut self ) -> Poll < IoResult < Option < SocketAddr > > > {
196187 let data = match & mut self . proxy_stream . data {
197188 Some ( data) => data,
@@ -212,7 +203,7 @@ impl<'a, T> RealAddrFuture<'a, T> {
212203 }
213204 } ;
214205
215- Poll :: Ready ( self . format_header ( res) )
206+ Poll :: Ready ( format_header ( res) . map ( Some ) )
216207 }
217208}
218209
@@ -235,7 +226,7 @@ where
235226 Ok ( 0 ) => {
236227 return Poll :: Ready ( Err ( IoError :: new (
237228 ErrorKind :: UnexpectedEof ,
238- "Streamed finished before end of proxy protocol header" ,
229+ "Stream finished before end of proxy protocol header" ,
239230 ) ) )
240231 }
241232 Ok ( _) => { }
@@ -250,25 +241,79 @@ where
250241mod tests {
251242 use ppp:: model:: { Addresses , Command , Header , Protocol , Version } ;
252243 use std:: io:: Cursor ;
244+ use std:: net:: SocketAddr ;
245+ use tokio:: io:: AsyncReadExt ;
253246
254- use super :: ToProxyStream ;
247+ use super :: { format_header , ToProxyStream } ;
255248 use crate :: config:: ProxyProtocol ;
256- use std:: net:: SocketAddr ;
257249
258250 #[ tokio:: test]
259- async fn test_header_parsing ( ) {
260- let header = Header :: new (
251+ async fn test_disabled ( ) {
252+ let mut proxy_stream = Cursor :: new ( vec ! [ ] ) . source ( ProxyProtocol :: Disabled ) ;
253+ assert ! ( proxy_stream. real_addr( ) . await . unwrap( ) . is_none( ) ) ;
254+ }
255+
256+ fn generate_header ( addresses : Addresses ) -> Header {
257+ Header :: new (
261258 Version :: Two ,
262259 Command :: Proxy ,
263260 Protocol :: Stream ,
264261 vec ! [ ] ,
265- Addresses :: from ( ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 , 2 , 2 ] , 24034 , 443 ) ) ,
266- ) ;
267- let header = ppp:: to_bytes ( header) . unwrap ( ) ;
262+ addresses,
263+ )
264+ }
265+
266+ fn generate_ipv4 ( ) -> Header {
267+ let adresses = Addresses :: from ( ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 , 2 , 2 ] , 24034 , 443 ) ) ;
268+ generate_header ( adresses)
269+ }
270+
271+ #[ tokio:: test]
272+ async fn test_header_parsing ( ) {
273+ let mut header = ppp:: to_bytes ( generate_ipv4 ( ) ) . unwrap ( ) ;
274+ header. extend_from_slice ( "Test" . as_ref ( ) ) ;
268275 let mut header = Cursor :: new ( header) . source ( ProxyProtocol :: Enabled ) ;
269276
270277 let actual = header. real_addr ( ) . await . unwrap ( ) . unwrap ( ) ;
271278
272279 assert_eq ! ( SocketAddr :: from( ( [ 1 , 1 , 1 , 1 ] , 24034 ) ) , actual) ;
280+
281+ let mut actual = String :: new ( ) ;
282+ let size = header. read_to_string ( & mut actual) . await . unwrap ( ) ;
283+ assert_eq ! ( 4 , size) ;
284+ assert_eq ! ( "Test" , actual) ;
285+ }
286+
287+ #[ tokio:: test]
288+ async fn test_incomplete ( ) {
289+ let header = ppp:: to_bytes ( generate_ipv4 ( ) ) . unwrap ( ) ;
290+ let header = & mut & header[ ..10 ] ;
291+ let mut header = header. source ( ProxyProtocol :: Enabled ) ;
292+
293+ let actual = header. real_addr ( ) . await . unwrap_err ( ) ;
294+ assert_eq ! (
295+ format!( "{}" , actual) ,
296+ "Stream finished before end of proxy protocol header"
297+ ) ;
298+ }
299+
300+ #[ tokio:: test]
301+ async fn test_failure ( ) {
302+ let invalid = Vec :: from ( "invalid header" ) ;
303+ let invalid = & mut & invalid[ ..] ;
304+
305+ let mut invalid = invalid. source ( ProxyProtocol :: Enabled ) ;
306+
307+ let actual = invalid. real_addr ( ) . await . unwrap_err ( ) ;
308+ assert_eq ! ( format!( "{}" , actual) , "Proxy Parser Error" ) ;
309+ }
310+
311+ #[ test]
312+ fn test_adresses ( ) {
313+ let address = [ 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ] ;
314+ let addresses = Addresses :: from ( ( address, address, 24034 , 443 ) ) ;
315+
316+ let actual = format_header ( generate_header ( addresses) ) . unwrap ( ) ;
317+ assert_eq ! ( SocketAddr :: from( ( address, 24034 ) ) , actual) ;
273318 }
274319}
0 commit comments