Skip to content
13 changes: 7 additions & 6 deletions mitmproxy-rs/src/server/udp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};

use mitmproxy::packet_sources::udp::UdpConf;

use pyo3::prelude::*;

use crate::server::base::Server;
use pyo3::prelude::*;

/// A running UDP server.
///
Expand Down Expand Up @@ -50,17 +49,19 @@ impl UdpServer {

/// Start a UDP server that is configured with the given parameters:
///
/// - `host`: The host address.
/// - `host`: The host IP address.
/// - `port`: The listen port.
/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`.
#[pyfunction]
pub fn start_udp_server(
py: Python<'_>,
host: String,
host: IpAddr,
port: u16,
handle_udp_stream: PyObject,
) -> PyResult<Bound<PyAny>> {
let conf = UdpConf { host, port };
let conf = UdpConf {
listen_addr: SocketAddr::from((host, port)),
};
let handle_tcp_stream = py.None();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;
Expand Down
7 changes: 3 additions & 4 deletions mitmproxy-rs/src/server/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};

use crate::util::string_to_key;

Expand Down Expand Up @@ -63,7 +63,7 @@ impl WireGuardServer {
#[pyfunction]
pub fn start_wireguard_server(
py: Python<'_>,
host: String,
host: IpAddr,
port: u16,
private_key: String,
peer_public_keys: Vec<String>,
Expand All @@ -76,8 +76,7 @@ pub fn start_wireguard_server(
.map(string_to_key)
.collect::<PyResult<Vec<PublicKey>>>()?;
let conf = WireGuardConf {
host,
port,
listen_addr: SocketAddr::from((host, port)),
private_key,
peer_public_keys,
};
Expand Down
4 changes: 2 additions & 2 deletions src/network/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ mod tests {
use crate::packet_sources::{PacketSourceConf, PacketSourceTask};
use crate::shutdown;
use std::net::{IpAddr, Ipv4Addr};
use std::str::FromStr;
use tokio::net::UdpSocket;

#[test]
Expand Down Expand Up @@ -316,8 +317,7 @@ mod tests {
let (events_tx, mut events_rx) = tokio::sync::mpsc::channel(1);
let (shutdown_tx, shutdown_rx) = shutdown::channel();
let (task, addr) = UdpConf {
host: "127.0.0.1".to_string(),
port: 0,
listen_addr: SocketAddr::from_str("127.0.0.1:0").unwrap(),
}
.build(events_tx, commands_rx, shutdown_rx)
.await?;
Expand Down
62 changes: 34 additions & 28 deletions src/packet_sources/udp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::net::{Ipv4Addr, SocketAddr};
use std::str::FromStr;

use anyhow::{Context, Result};

Expand All @@ -23,9 +22,39 @@ pub fn remote_host_closed_conn<T>(_res: &Result<T, std::io::Error>) -> bool {
false
}

/// Creates a nonblocking UDP socket bound to the specified address, restricted to either IPv4 or IPv6 only.
pub(crate) fn create_and_bind_udp_socket(addr: SocketAddr) -> Result<UdpSocket> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};

// We use socket2::Socket to set IPV6_V6ONLY and convert back to std::net::UdpSocket
let sock2 = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;

// Ensure that IPv6 sockets listen on IPv6 only
if addr.is_ipv6() {
sock2
.set_only_v6(true)
.context("Failed to set IPV6_V6ONLY flag")?;
}

sock2
.bind(&addr.into())
.context(format!("Failed to bind UDP socket to {}", addr))?;

let std_sock: std::net::UdpSocket = sock2.into();
std_sock
.set_nonblocking(true)
.context("Failed to make UDP socket non-blocking")?;
let socket = UdpSocket::from_std(std_sock)?;

Ok(socket)
}

pub struct UdpConf {
pub host: String,
pub port: u16,
pub listen_addr: SocketAddr,
}

impl PacketSourceConf for UdpConf {
Expand All @@ -42,31 +71,8 @@ impl PacketSourceConf for UdpConf {
transport_commands_rx: UnboundedReceiver<TransportCommand>,
shutdown: shutdown::Receiver,
) -> Result<(Self::Task, Self::Data)> {
let addr = format!("{}:{}", self.host, self.port);
let sock_addr = SocketAddr::from_str(&addr).context("Invalid listen address specified")?;

let domain = if sock_addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let sock2 = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;

// Ensure that IPv6 sockets listen on IPv6 only
if sock_addr.is_ipv6() {
sock2
.set_only_v6(true)
.context("Failed to set IPV6_V6ONLY flag")?;
}

sock2
.bind(&sock_addr.into())
.context(format!("Failed to bind UDP socket to {}", addr))?;

let std_sock: std::net::UdpSocket = sock2.into();
std_sock.set_nonblocking(true)?;
let socket = UdpSocket::from_std(std_sock)?;
let local_addr = socket.local_addr()?;
let socket = create_and_bind_udp_socket(self.listen_addr)?;
let local_addr: SocketAddr = socket.local_addr()?;

log::debug!("UDP server listening on {} ...", local_addr);

Expand Down
23 changes: 4 additions & 19 deletions src/packet_sources/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use tokio::{
},
};

use crate::packet_sources::udp::remote_host_closed_conn;
use crate::packet_sources::udp::{create_and_bind_udp_socket, remote_host_closed_conn};
use crate::shutdown;

// WireGuard headers are 60 bytes for IPv4 and 80 bytes for IPv6
Expand All @@ -36,8 +36,7 @@ pub struct WireGuardPeer {
}

pub struct WireGuardConf {
pub host: String,
pub port: u16,
pub listen_addr: SocketAddr,
pub private_key: StaticSecret,
pub peer_public_keys: Vec<PublicKey>,
}
Expand Down Expand Up @@ -84,26 +83,12 @@ impl PacketSourceConf for WireGuardConf {
peers_by_key.insert(public_key, peer);
}

// bind to UDP socket(s)
let socket_addrs = if self.host.is_empty() {
vec![
SocketAddr::new("0.0.0.0".parse().unwrap(), self.port),
SocketAddr::new("::".parse().unwrap(), self.port),
]
} else {
vec![SocketAddr::new(self.host.parse()?, self.port)]
};

let socket = UdpSocket::bind(socket_addrs.as_slice()).await?;
let socket = create_and_bind_udp_socket(self.listen_addr)?;
let local_addr = socket.local_addr()?;

log::debug!(
"WireGuard server listening for UDP connections on {} ...",
socket_addrs
.iter()
.map(|addr| addr.to_string())
.collect::<Vec<String>>()
.join(" and ")
local_addr
);

let public_key = PublicKey::from(&self.private_key);
Expand Down