diff --git a/mitmproxy-rs/src/server/udp.rs b/mitmproxy-rs/src/server/udp.rs index b30eba47..b7b3c6ba 100644 --- a/mitmproxy-rs/src/server/udp.rs +++ b/mitmproxy-rs/src/server/udp.rs @@ -1,10 +1,9 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use mitmproxy::packet_sources::udp::UdpConf; -use pyo3::prelude::*; - use crate::server::base::Server; +use pyo3::prelude::*; /// A running UDP server. /// @@ -50,17 +49,19 @@ impl UdpServer { /// Start a UDP server that is configured with the given parameters: /// -/// - `host`: The host address. +/// - `host`: The host IP address. /// - `port`: The listen port. /// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. #[pyfunction] pub fn start_udp_server( py: Python<'_>, - host: String, + host: IpAddr, port: u16, handle_udp_stream: PyObject, ) -> PyResult> { - let conf = UdpConf { host, port }; + let conf = UdpConf { + listen_addr: SocketAddr::from((host, port)), + }; let handle_tcp_stream = py.None(); pyo3_async_runtimes::tokio::future_into_py(py, async move { let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; diff --git a/mitmproxy-rs/src/server/wireguard.rs b/mitmproxy-rs/src/server/wireguard.rs index 1d5ea8b5..c54ee384 100644 --- a/mitmproxy-rs/src/server/wireguard.rs +++ b/mitmproxy-rs/src/server/wireguard.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use crate::util::string_to_key; @@ -63,7 +63,7 @@ impl WireGuardServer { #[pyfunction] pub fn start_wireguard_server( py: Python<'_>, - host: String, + host: IpAddr, port: u16, private_key: String, peer_public_keys: Vec, @@ -76,8 +76,7 @@ pub fn start_wireguard_server( .map(string_to_key) .collect::>>()?; let conf = WireGuardConf { - host, - port, + listen_addr: SocketAddr::from((host, port)), private_key, peer_public_keys, }; diff --git a/src/network/udp.rs b/src/network/udp.rs index 74502a77..3527757b 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -264,6 +264,7 @@ mod tests { use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; use crate::shutdown; use std::net::{IpAddr, Ipv4Addr}; + use std::str::FromStr; use tokio::net::UdpSocket; #[test] @@ -316,8 +317,7 @@ mod tests { let (events_tx, mut events_rx) = tokio::sync::mpsc::channel(1); let (shutdown_tx, shutdown_rx) = shutdown::channel(); let (task, addr) = UdpConf { - host: "127.0.0.1".to_string(), - port: 0, + listen_addr: SocketAddr::from_str("127.0.0.1:0").unwrap(), } .build(events_tx, commands_rx, shutdown_rx) .await?; diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs index d654239a..9fee2ee4 100644 --- a/src/packet_sources/udp.rs +++ b/src/packet_sources/udp.rs @@ -1,5 +1,4 @@ use std::net::{Ipv4Addr, SocketAddr}; -use std::str::FromStr; use anyhow::{Context, Result}; @@ -23,9 +22,39 @@ pub fn remote_host_closed_conn(_res: &Result) -> bool { false } +/// Creates a nonblocking UDP socket bound to the specified address, restricted to either IPv4 or IPv6 only. +pub(crate) fn create_and_bind_udp_socket(addr: SocketAddr) -> Result { + let domain = if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }; + + // We use socket2::Socket to set IPV6_V6ONLY and convert back to std::net::UdpSocket + let sock2 = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + // Ensure that IPv6 sockets listen on IPv6 only + if addr.is_ipv6() { + sock2 + .set_only_v6(true) + .context("Failed to set IPV6_V6ONLY flag")?; + } + + sock2 + .bind(&addr.into()) + .context(format!("Failed to bind UDP socket to {}", addr))?; + + let std_sock: std::net::UdpSocket = sock2.into(); + std_sock + .set_nonblocking(true) + .context("Failed to make UDP socket non-blocking")?; + let socket = UdpSocket::from_std(std_sock)?; + + Ok(socket) +} + pub struct UdpConf { - pub host: String, - pub port: u16, + pub listen_addr: SocketAddr, } impl PacketSourceConf for UdpConf { @@ -42,31 +71,8 @@ impl PacketSourceConf for UdpConf { transport_commands_rx: UnboundedReceiver, shutdown: shutdown::Receiver, ) -> Result<(Self::Task, Self::Data)> { - let addr = format!("{}:{}", self.host, self.port); - let sock_addr = SocketAddr::from_str(&addr).context("Invalid listen address specified")?; - - let domain = if sock_addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }; - let sock2 = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; - - // Ensure that IPv6 sockets listen on IPv6 only - if sock_addr.is_ipv6() { - sock2 - .set_only_v6(true) - .context("Failed to set IPV6_V6ONLY flag")?; - } - - sock2 - .bind(&sock_addr.into()) - .context(format!("Failed to bind UDP socket to {}", addr))?; - - let std_sock: std::net::UdpSocket = sock2.into(); - std_sock.set_nonblocking(true)?; - let socket = UdpSocket::from_std(std_sock)?; - let local_addr = socket.local_addr()?; + let socket = create_and_bind_udp_socket(self.listen_addr)?; + let local_addr: SocketAddr = socket.local_addr()?; log::debug!("UDP server listening on {} ...", local_addr); diff --git a/src/packet_sources/wireguard.rs b/src/packet_sources/wireguard.rs index 19fb5ef4..86ee0498 100755 --- a/src/packet_sources/wireguard.rs +++ b/src/packet_sources/wireguard.rs @@ -23,7 +23,7 @@ use tokio::{ }, }; -use crate::packet_sources::udp::remote_host_closed_conn; +use crate::packet_sources::udp::{create_and_bind_udp_socket, remote_host_closed_conn}; use crate::shutdown; // WireGuard headers are 60 bytes for IPv4 and 80 bytes for IPv6 @@ -36,8 +36,7 @@ pub struct WireGuardPeer { } pub struct WireGuardConf { - pub host: String, - pub port: u16, + pub listen_addr: SocketAddr, pub private_key: StaticSecret, pub peer_public_keys: Vec, } @@ -84,26 +83,12 @@ impl PacketSourceConf for WireGuardConf { peers_by_key.insert(public_key, peer); } - // bind to UDP socket(s) - let socket_addrs = if self.host.is_empty() { - vec![ - SocketAddr::new("0.0.0.0".parse().unwrap(), self.port), - SocketAddr::new("::".parse().unwrap(), self.port), - ] - } else { - vec![SocketAddr::new(self.host.parse()?, self.port)] - }; - - let socket = UdpSocket::bind(socket_addrs.as_slice()).await?; + let socket = create_and_bind_udp_socket(self.listen_addr)?; let local_addr = socket.local_addr()?; log::debug!( "WireGuard server listening for UDP connections on {} ...", - socket_addrs - .iter() - .map(|addr| addr.to_string()) - .collect::>() - .join(" and ") + local_addr ); let public_key = PublicKey::from(&self.private_key);