|
| 1 | +// |
| 2 | +// Copyright (c) 2025 ZettaScale Technology |
| 3 | +// |
| 4 | +// This program and the accompanying materials are made available under the |
| 5 | +// terms of the Eclipse Public License 2.0 which is available at |
| 6 | +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 |
| 7 | +// which is available at https://www.apache.org/licenses/LICENSE-2.0. |
| 8 | +// |
| 9 | +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 |
| 10 | +// |
| 11 | +// Contributors: |
| 12 | +// ZettaScale Zenoh Team, <[email protected]> |
| 13 | +// |
| 14 | + |
| 15 | +/// mostly taken from https://github.com/pixsper/socket-pktinfo/blob/main/src/unix.rs |
| 16 | +use std::io::{Error, IoSliceMut}; |
| 17 | +use std::{ |
| 18 | + io, mem, |
| 19 | + mem::MaybeUninit, |
| 20 | + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, |
| 21 | + os::unix::io::{AsRawFd, RawFd}, |
| 22 | + ptr, |
| 23 | +}; |
| 24 | + |
| 25 | +use socket2::SockAddr; |
| 26 | +use tokio::{io::Interest, net::UdpSocket}; |
| 27 | + |
| 28 | +unsafe fn setsockopt<T>( |
| 29 | + socket: libc::c_int, |
| 30 | + level: libc::c_int, |
| 31 | + name: libc::c_int, |
| 32 | + value: T, |
| 33 | +) -> io::Result<()> |
| 34 | +where |
| 35 | + T: Copy, |
| 36 | +{ |
| 37 | + let value = &value as *const T as *const libc::c_void; |
| 38 | + if libc::setsockopt( |
| 39 | + socket, |
| 40 | + level, |
| 41 | + name, |
| 42 | + value, |
| 43 | + mem::size_of::<T>() as libc::socklen_t, |
| 44 | + ) == 0 |
| 45 | + { |
| 46 | + Ok(()) |
| 47 | + } else { |
| 48 | + Err(Error::last_os_error()) |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +#[derive(Clone)] |
| 53 | +pub(crate) struct PktInfoRetrievalData { |
| 54 | + port: u16, |
| 55 | +} |
| 56 | + |
| 57 | +pub(crate) fn enable_pktinfo(socket: &UdpSocket) -> io::Result<PktInfoRetrievalData> { |
| 58 | + let local_src_addr = socket.local_addr()?; |
| 59 | + match local_src_addr.is_ipv6() { |
| 60 | + false => unsafe { |
| 61 | + setsockopt(socket.as_raw_fd(), libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?; |
| 62 | + }, |
| 63 | + true => unsafe { |
| 64 | + setsockopt( |
| 65 | + socket.as_raw_fd(), |
| 66 | + libc::IPPROTO_IPV6, |
| 67 | + libc::IPV6_RECVPKTINFO, |
| 68 | + 1, |
| 69 | + )?; |
| 70 | + }, |
| 71 | + } |
| 72 | + Ok(PktInfoRetrievalData { |
| 73 | + port: local_src_addr.port(), |
| 74 | + }) |
| 75 | +} |
| 76 | + |
| 77 | +fn recv_with_dst_inner( |
| 78 | + fd: RawFd, |
| 79 | + local_port: u16, |
| 80 | + buf: &mut [u8], |
| 81 | +) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> { |
| 82 | + let mut addr_src: MaybeUninit<libc::sockaddr_storage> = MaybeUninit::uninit(); |
| 83 | + let mut msg_iov = IoSliceMut::new(buf); |
| 84 | + let mut cmsg = { |
| 85 | + let space = unsafe { |
| 86 | + libc::CMSG_SPACE(mem::size_of::<libc::in_pktinfo>() as libc::c_uint) as usize |
| 87 | + }; |
| 88 | + Vec::<u8>::with_capacity(space) |
| 89 | + }; |
| 90 | + |
| 91 | + let mut mhdr = unsafe { |
| 92 | + let mut mhdr = MaybeUninit::<libc::msghdr>::zeroed(); |
| 93 | + let p = mhdr.as_mut_ptr(); |
| 94 | + (*p).msg_name = addr_src.as_mut_ptr() as *mut libc::c_void; |
| 95 | + (*p).msg_namelen = mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t; |
| 96 | + (*p).msg_iov = &mut msg_iov as *mut IoSliceMut as *mut libc::iovec; |
| 97 | + (*p).msg_iovlen = 1; |
| 98 | + (*p).msg_control = cmsg.as_mut_ptr() as *mut libc::c_void; |
| 99 | + (*p).msg_controllen = cmsg.capacity() as _; |
| 100 | + (*p).msg_flags = 0; |
| 101 | + mhdr.assume_init() |
| 102 | + }; |
| 103 | + |
| 104 | + let bytes_recv = unsafe { libc::recvmsg(fd, &mut mhdr as *mut libc::msghdr, 0) }; |
| 105 | + if bytes_recv <= 0 { |
| 106 | + return Err(Error::last_os_error()); |
| 107 | + } |
| 108 | + |
| 109 | + let addr_src = unsafe { |
| 110 | + SockAddr::new( |
| 111 | + addr_src.assume_init(), |
| 112 | + mem::size_of::<libc::sockaddr_storage>() as _, |
| 113 | + ) |
| 114 | + } |
| 115 | + .as_socket() |
| 116 | + .unwrap(); |
| 117 | + |
| 118 | + let mut header = if mhdr.msg_controllen > 0 { |
| 119 | + debug_assert!(!mhdr.msg_control.is_null()); |
| 120 | + debug_assert!(cmsg.capacity() >= mhdr.msg_controllen as usize); |
| 121 | + |
| 122 | + Some(unsafe { |
| 123 | + libc::CMSG_FIRSTHDR(&mhdr as *const libc::msghdr) |
| 124 | + .as_ref() |
| 125 | + .unwrap() |
| 126 | + }) |
| 127 | + } else { |
| 128 | + None |
| 129 | + }; |
| 130 | + |
| 131 | + let mut addr_dst = None; |
| 132 | + |
| 133 | + while addr_dst.is_none() && header.is_some() { |
| 134 | + let h = header.unwrap(); |
| 135 | + let p = unsafe { libc::CMSG_DATA(h) }; |
| 136 | + |
| 137 | + match (h.cmsg_level, h.cmsg_type) { |
| 138 | + (libc::IPPROTO_IP, libc::IP_PKTINFO) => { |
| 139 | + let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in_pktinfo) }; |
| 140 | + addr_dst = Some(SocketAddr::new( |
| 141 | + IpAddr::V4(Ipv4Addr::from(u32::from_be(pktinfo.ipi_addr.s_addr))), |
| 142 | + local_port, |
| 143 | + )); |
| 144 | + } |
| 145 | + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { |
| 146 | + let pktinfo = unsafe { ptr::read_unaligned(p as *const libc::in6_pktinfo) }; |
| 147 | + addr_dst = Some(SocketAddr::new( |
| 148 | + IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr)), |
| 149 | + local_port, |
| 150 | + )); |
| 151 | + } |
| 152 | + _ => { |
| 153 | + header = unsafe { |
| 154 | + let p = libc::CMSG_NXTHDR(&mhdr as *const _, h as *const _); |
| 155 | + p.as_ref() |
| 156 | + }; |
| 157 | + } |
| 158 | + } |
| 159 | + } |
| 160 | + Ok((bytes_recv as _, addr_src, addr_dst)) |
| 161 | +} |
| 162 | + |
| 163 | +pub(crate) async fn recv_with_dst( |
| 164 | + socket: &UdpSocket, |
| 165 | + data: &PktInfoRetrievalData, |
| 166 | + buffer: &mut [u8], |
| 167 | +) -> io::Result<(usize, SocketAddr, Option<SocketAddr>)> { |
| 168 | + let fd = socket.as_raw_fd(); |
| 169 | + let local_port = data.port; |
| 170 | + |
| 171 | + socket |
| 172 | + .async_io(Interest::READABLE, || { |
| 173 | + recv_with_dst_inner(fd, local_port, buffer) |
| 174 | + }) |
| 175 | + .await |
| 176 | +} |
0 commit comments