@@ -17,15 +17,29 @@ use tracing::{error, Instrument, Span};
1717
1818use crate :: config:: ProxyProtocol ;
1919
20+ trait RemoteAddr {
21+ fn remote_addr ( & self ) -> IoResult < SocketAddr > ;
22+ }
23+
24+ impl RemoteAddr for TcpStream {
25+ fn remote_addr ( & self ) -> IoResult < SocketAddr > {
26+ self . peer_addr ( )
27+ }
28+ }
29+
2030pub ( super ) fn wrap (
2131 listener : TcpListener ,
2232 proxy : ProxyProtocol ,
23- ) -> impl Stream < Item = IoResult < impl Future < Output = IoResult < ProxyStream > > > > + Send {
33+ ) -> impl Stream <
34+ Item = IoResult <
35+ impl Future < Output = IoResult < impl AsyncRead + AsyncWrite + Send + Unpin + ' static > > ,
36+ > ,
37+ > + Send {
2438 TcpListenerStream :: new ( listener)
2539 . map_ok ( move |conn| conn. source ( proxy) )
2640 . map_ok ( |mut conn| {
2741 let span = Span :: current ( ) ;
28- span. record ( "remote.addr" , & debug ( conn. peer_addr ( ) ) ) ;
42+ span. record ( "remote.addr" , & debug ( conn. remote_addr ( ) ) ) ;
2943 let span_clone = span. clone ( ) ;
3044
3145 async move {
@@ -44,12 +58,15 @@ pub(super) fn wrap(
4458 } )
4559}
4660
47- trait ToProxyStream {
48- fn source ( self , proxy : ProxyProtocol ) -> ProxyStream ;
61+ trait ToProxyStream : Sized {
62+ fn source ( self , proxy : ProxyProtocol ) -> ProxyStream < Self > ;
4963}
5064
51- impl ToProxyStream for TcpStream {
52- fn source ( self , proxy : ProxyProtocol ) -> ProxyStream {
65+ impl < T > ToProxyStream for T
66+ where
67+ T : AsyncRead + Unpin ,
68+ {
69+ fn source ( self , proxy : ProxyProtocol ) -> ProxyStream < T > {
5370 let data = match proxy {
5471 ProxyProtocol :: Enabled => Some ( Default :: default ( ) ) ,
5572 ProxyProtocol :: Disabled => None ,
@@ -62,19 +79,25 @@ impl ToProxyStream for TcpStream {
6279 }
6380}
6481
65- pub ( super ) struct ProxyStream {
66- stream : TcpStream ,
82+ pub ( super ) struct ProxyStream < T > {
83+ stream : T ,
6784 data : Option < Vec < u8 > > ,
6885 start_of_data : usize ,
6986}
7087
71- impl ProxyStream {
72- fn real_addr ( & mut self ) -> RealAddrFuture < ' _ > {
88+ impl < T > ProxyStream < T >
89+ where
90+ T : AsyncRead + Unpin ,
91+ {
92+ fn real_addr ( & mut self ) -> RealAddrFuture < ' _ , T > {
7393 RealAddrFuture { proxy_stream : self }
7494 }
7595}
7696
77- impl AsyncRead for ProxyStream {
97+ impl < T > AsyncRead for ProxyStream < T >
98+ where
99+ T : AsyncRead + Unpin ,
100+ {
78101 fn poll_read (
79102 self : Pin < & mut Self > ,
80103 cx : & mut Context < ' _ > ,
@@ -88,7 +111,10 @@ impl AsyncRead for ProxyStream {
88111 }
89112}
90113
91- impl AsyncWrite for ProxyStream {
114+ impl < T > AsyncWrite for ProxyStream < T >
115+ where
116+ T : AsyncWrite + Unpin ,
117+ {
92118 fn poll_write (
93119 mut self : Pin < & mut Self > ,
94120 cx : & mut Context < ' _ > ,
@@ -118,26 +144,26 @@ impl AsyncWrite for ProxyStream {
118144 }
119145}
120146
121- impl Deref for ProxyStream {
122- type Target = TcpStream ;
147+ impl < T > Deref for ProxyStream < T > {
148+ type Target = T ;
123149
124150 fn deref ( & self ) -> & Self :: Target {
125151 & self . stream
126152 }
127153}
128154
129- impl DerefMut for ProxyStream {
155+ impl < T > DerefMut for ProxyStream < T > {
130156 fn deref_mut ( & mut self ) -> & mut Self :: Target {
131157 & mut self . stream
132158 }
133159}
134160
135- struct RealAddrFuture < ' a > {
136- proxy_stream : & ' a mut ProxyStream ,
161+ struct RealAddrFuture < ' a , T > {
162+ proxy_stream : & ' a mut ProxyStream < T > ,
137163}
138164
139- impl < ' a > RealAddrFuture < ' a > {
140- fn format_header ( & self , res : Header ) -> < Self as Future > :: Output {
165+ impl < ' a , T > RealAddrFuture < ' a , T > {
166+ fn format_header ( & self , res : Header ) -> IoResult < Option < SocketAddr > > {
141167 let addr = match res. addresses {
142168 Addresses :: IPv4 {
143169 source_address,
@@ -166,7 +192,7 @@ impl<'a> RealAddrFuture<'a> {
166192 Ok ( Some ( addr) )
167193 }
168194
169- fn get_header ( & mut self ) -> Poll < < Self as Future > :: Output > {
195+ fn get_header ( & mut self ) -> Poll < IoResult < Option < SocketAddr > > > {
170196 let data = match & mut self . proxy_stream . data {
171197 Some ( data) => data,
172198 None => unreachable ! ( "Future cannot be pulled anymore" ) ,
@@ -190,7 +216,10 @@ impl<'a> RealAddrFuture<'a> {
190216 }
191217}
192218
193- impl Future for RealAddrFuture < ' _ > {
219+ impl < T > Future for RealAddrFuture < ' _ , T >
220+ where
221+ T : AsyncRead + Unpin ,
222+ {
194223 type Output = IoResult < Option < SocketAddr > > ;
195224
196225 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
@@ -216,3 +245,30 @@ impl Future for RealAddrFuture<'_> {
216245 this. get_header ( )
217246 }
218247}
248+
249+ #[ cfg( test) ]
250+ mod tests {
251+ use ppp:: model:: { Addresses , Command , Header , Protocol , Version } ;
252+ use std:: io:: Cursor ;
253+
254+ use super :: ToProxyStream ;
255+ use crate :: config:: ProxyProtocol ;
256+ use std:: net:: SocketAddr ;
257+
258+ #[ tokio:: test]
259+ async fn test_header_parsing ( ) {
260+ let header = Header :: new (
261+ Version :: Two ,
262+ Command :: Proxy ,
263+ Protocol :: Stream ,
264+ vec ! [ ] ,
265+ Addresses :: from ( ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 , 2 , 2 ] , 24034 , 443 ) ) ,
266+ ) ;
267+ let header = ppp:: to_bytes ( header) . unwrap ( ) ;
268+ let mut header = Cursor :: new ( header) . source ( ProxyProtocol :: Enabled ) ;
269+
270+ let actual = header. real_addr ( ) . await . unwrap ( ) . unwrap ( ) ;
271+
272+ assert_eq ! ( SocketAddr :: from( ( [ 1 , 1 , 1 , 1 ] , 24034 ) ) , actual) ;
273+ }
274+ }
0 commit comments