Skip to content

Commit b8cc2d0

Browse files
authored
chore(transport): User socket2 to obtain orig_dst (#3626)
Signed-off-by: Zahari Dichev <[email protected]>
1 parent 30a9f24 commit b8cc2d0

File tree

1 file changed

+14
-142
lines changed

1 file changed

+14
-142
lines changed

linkerd/proxy/transport/src/orig_dst.rs

Lines changed: 14 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use futures::prelude::*;
66
use linkerd_error::Result;
77
use linkerd_io as io;
88
use linkerd_stack::Param;
9-
use std::{net::SocketAddr, pin::Pin};
9+
use std::pin::Pin;
1010
use tokio::net::TcpStream;
1111

1212
#[derive(Copy, Clone, Debug, Default)]
@@ -83,14 +83,7 @@ where
8383

8484
let incoming = incoming.map(|res| {
8585
let (inner, tcp) = res?;
86-
let orig_dst = match inner.param() {
87-
// IPv4-mapped IPv6 addresses are unwrapped by BindTcp::bind() and received here as
88-
// SocketAddr::V4. We must call getsockopt with IPv4 constants (via
89-
// orig_dst_addr_v4) even if it originally was an IPv6
90-
Remote(ClientAddr(SocketAddr::V4(_))) => orig_dst_addr_v4(&tcp)?,
91-
Remote(ClientAddr(SocketAddr::V6(_))) => orig_dst_addr_v6(&tcp)?,
92-
};
93-
let orig_dst = OrigDstAddr(orig_dst);
86+
let (orig_dst, tcp) = orig_dst(tcp)?;
9487
let addrs = Addrs { inner, orig_dst };
9588
Ok((addrs, tcp))
9689
});
@@ -99,139 +92,18 @@ where
9992
}
10093
}
10194

102-
#[cfg(target_os = "linux")]
103-
#[allow(unsafe_code)]
104-
fn orig_dst_addr_v4(sock: &TcpStream) -> io::Result<SocketAddr> {
105-
use std::os::unix::io::AsRawFd;
95+
fn orig_dst(sock: TcpStream) -> io::Result<(OrigDstAddr, TcpStream)> {
96+
let sock = {
97+
let stream = tokio::net::TcpStream::into_std(sock)?;
98+
socket2::Socket::from(stream)
99+
};
106100

107-
let fd = sock.as_raw_fd();
108-
unsafe { linux::so_original_dst_v4(fd) }
109-
}
110-
111-
#[cfg(target_os = "linux")]
112-
#[allow(unsafe_code)]
113-
fn orig_dst_addr_v6(sock: &TcpStream) -> io::Result<SocketAddr> {
114-
use std::os::unix::io::AsRawFd;
115-
116-
let fd = sock.as_raw_fd();
117-
unsafe { linux::so_original_dst_v6(fd) }
118-
}
119-
120-
#[cfg(not(target_os = "linux"))]
121-
fn orig_dst_addr_v4(_: &TcpStream) -> io::Result<SocketAddr> {
122-
Err(io::Error::new(
123-
io::ErrorKind::Other,
124-
"SO_ORIGINAL_DST not supported on this operating system",
125-
))
126-
}
127-
128-
#[cfg(not(target_os = "linux"))]
129-
fn orig_dst_addr_v6(_: &TcpStream) -> io::Result<SocketAddr> {
130-
Err(io::Error::new(
131-
io::ErrorKind::Other,
132-
"SO_ORIGINAL_DST not supported on this operating system",
133-
))
134-
}
135-
136-
#[cfg(target_os = "linux")]
137-
#[allow(unsafe_code)]
138-
mod linux {
139-
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
140-
use std::os::unix::io::RawFd;
141-
use std::{io, mem};
142-
143-
pub unsafe fn so_original_dst(fd: RawFd, level: i32, optname: i32) -> io::Result<SocketAddr> {
144-
let mut sockaddr: libc::sockaddr_storage = mem::zeroed();
145-
let mut sockaddr_len: libc::socklen_t = mem::size_of::<libc::sockaddr_storage>() as u32;
146-
147-
let ret = libc::getsockopt(
148-
fd,
149-
level,
150-
optname,
151-
&mut sockaddr as *mut _ as *mut _,
152-
&mut sockaddr_len as *mut _ as *mut _,
153-
);
154-
if ret != 0 {
155-
return Err(io::Error::last_os_error());
156-
}
157-
158-
mk_addr(&sockaddr, sockaddr_len)
159-
}
101+
let orig_dst = sock.original_dst()?.as_socket().ok_or(io::Error::new(
102+
io::ErrorKind::InvalidInput,
103+
"Invalid address format",
104+
))?;
160105

161-
pub unsafe fn so_original_dst_v4(fd: RawFd) -> io::Result<SocketAddr> {
162-
so_original_dst(fd, libc::SOL_IP, libc::SO_ORIGINAL_DST)
163-
}
164-
165-
pub unsafe fn so_original_dst_v6(fd: RawFd) -> io::Result<SocketAddr> {
166-
so_original_dst(fd, libc::SOL_IPV6, libc::IP6T_SO_ORIGINAL_DST)
167-
}
168-
169-
// Borrowed with love from net2-rs
170-
// https://github.com/rust-lang-nursery/net2-rs/blob/1b4cb4fb05fbad750b271f38221eab583b666e5e/src/socket.rs#L103
171-
//
172-
// Copyright (c) 2014 The Rust Project Developers
173-
fn mk_addr(storage: &libc::sockaddr_storage, len: libc::socklen_t) -> io::Result<SocketAddr> {
174-
match storage.ss_family as libc::c_int {
175-
libc::AF_INET => {
176-
assert!(len as usize >= mem::size_of::<libc::sockaddr_in>());
177-
178-
let sa = {
179-
let sa = storage as *const _ as *const libc::sockaddr_in;
180-
unsafe { *sa }
181-
};
182-
183-
let bits = ntoh32(sa.sin_addr.s_addr);
184-
let ip = Ipv4Addr::new(
185-
(bits >> 24) as u8,
186-
(bits >> 16) as u8,
187-
(bits >> 8) as u8,
188-
bits as u8,
189-
);
190-
let port = sa.sin_port;
191-
Ok(SocketAddr::V4(SocketAddrV4::new(ip, ntoh16(port))))
192-
}
193-
libc::AF_INET6 => {
194-
assert!(len as usize >= mem::size_of::<libc::sockaddr_in6>());
195-
196-
let sa = {
197-
let sa = storage as *const _ as *const libc::sockaddr_in6;
198-
unsafe { *sa }
199-
};
200-
201-
let arr = sa.sin6_addr.s6_addr;
202-
let ip = Ipv6Addr::new(
203-
(arr[0] as u16) << 8 | (arr[1] as u16),
204-
(arr[2] as u16) << 8 | (arr[3] as u16),
205-
(arr[4] as u16) << 8 | (arr[5] as u16),
206-
(arr[6] as u16) << 8 | (arr[7] as u16),
207-
(arr[8] as u16) << 8 | (arr[9] as u16),
208-
(arr[10] as u16) << 8 | (arr[11] as u16),
209-
(arr[12] as u16) << 8 | (arr[13] as u16),
210-
(arr[14] as u16) << 8 | (arr[15] as u16),
211-
);
212-
213-
let port = sa.sin6_port;
214-
let flowinfo = sa.sin6_flowinfo;
215-
let scope_id = sa.sin6_scope_id;
216-
Ok(SocketAddr::V6(SocketAddrV6::new(
217-
ip,
218-
ntoh16(port),
219-
flowinfo,
220-
scope_id,
221-
)))
222-
}
223-
_ => Err(io::Error::new(
224-
io::ErrorKind::InvalidInput,
225-
"invalid argument",
226-
)),
227-
}
228-
}
229-
230-
fn ntoh16(i: u16) -> u16 {
231-
<u16>::from_be(i)
232-
}
233-
234-
fn ntoh32(i: u32) -> u32 {
235-
<u32>::from_be(i)
236-
}
106+
let stream: std::net::TcpStream = socket2::Socket::into(sock);
107+
let stream = tokio::net::TcpStream::from_std(stream)?;
108+
Ok((OrigDstAddr(orig_dst), stream))
237109
}

0 commit comments

Comments
 (0)