Skip to content

Commit b27e639

Browse files
gpgabrielfisherdarling
authored andcommitted
Add unix domain socket for telemetry server
An old socket path will be removed if it exists.
1 parent 191b593 commit b27e639

File tree

9 files changed

+297
-45
lines changed

9 files changed

+297
-45
lines changed

examples/http_server/example_conf.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ telemetry:
5757
# Enables telemetry server
5858
enabled: true
5959
# Telemetry server address.
60+
# Can be either a TCP socket address (e.g., "127.0.0.1:8080")
61+
# or a Unix domain socket path (e.g., "/tmp/telemetry.sock") on Unix systems.
6062
addr: "127.0.0.1:0"
63+
# Example Unix socket configuration (uncomment to use):
64+
# addr: "/tmp/telemetry.sock"
6165
# HTTP endpoints configuration.
6266
endpoints:
6367
Example endpoint:

examples/http_server/main.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod settings;
1313

1414
use self::settings::{EndpointSettings, HttpServerSettings, ResponseSettings};
1515
use anyhow::anyhow;
16+
use foundations::addr::ListenAddr;
1617
use foundations::cli::{Arg, ArgAction, Cli};
1718
use foundations::settings::collections::Map;
1819
use foundations::telemetry::{self, log, tracing, TelemetryConfig, TelemetryContext};
@@ -56,8 +57,11 @@ async fn main() -> BootstrapResult<()> {
5657
custom_server_routes: vec![],
5758
})?;
5859

59-
if let Some(tele_serv_addr) = tele_driver.server_addr() {
60-
log::info!("Telemetry server is listening on http://{}", tele_serv_addr);
60+
if let Some(addr) = tele_driver.server_addr() {
61+
match addr {
62+
ListenAddr::Tcp(addr) => log::info!("Telemetry server is listening on http://{addr}"),
63+
ListenAddr::Unix(path) => log::info!("Telemetry server is listening on {path:?}"),
64+
}
6165
}
6266

6367
// Spawn TCP listeners for each endpoint. Note that `Map<EndpointsSettings>` is ordered, so

foundations/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ telemetry-server = [
8282
"dep:hyper-util",
8383
"dep:socket2",
8484
"dep:percent-encoding",
85+
"dep:serde",
8586
]
8687

8788
# Enables telemetry reporting over gRPC

foundations/src/addr.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//! Network address types that support both TCP and Unix domain sockets.
2+
//!
3+
//! This module provides the [`ListenAddr`] enum, a flexible address type that can represent
4+
//! either TCP socket addresses or Unix domain socket paths.
5+
6+
#[cfg(feature = "settings")]
7+
use crate::settings::Settings;
8+
#[cfg(any(feature = "telemetry-server", feature = "settings"))]
9+
use serde::Deserialize;
10+
#[cfg(feature = "settings")]
11+
use serde::Serialize;
12+
use std::fmt;
13+
use std::net::{Ipv4Addr, SocketAddr};
14+
15+
/// Address that can be either TCP socket or Unix domain socket endpoint
16+
#[derive(Clone, Debug)]
17+
#[cfg_attr(
18+
any(feature = "telemetry-server", feature = "settings"),
19+
derive(Deserialize)
20+
)]
21+
#[cfg_attr(feature = "settings", derive(Serialize))]
22+
#[cfg_attr(
23+
any(feature = "telemetry-server", feature = "settings"),
24+
serde(untagged)
25+
)]
26+
pub enum ListenAddr {
27+
/// TCP network socket address
28+
Tcp(std::net::SocketAddr),
29+
/// Unix domain socket path
30+
#[cfg(unix)]
31+
Unix(std::path::PathBuf),
32+
}
33+
34+
impl Default for ListenAddr {
35+
fn default() -> Self {
36+
ListenAddr::Tcp((Ipv4Addr::LOCALHOST, 0).into())
37+
}
38+
}
39+
40+
#[cfg(feature = "settings")]
41+
impl From<crate::settings::net::SocketAddr> for ListenAddr {
42+
fn from(addr: crate::settings::net::SocketAddr) -> Self {
43+
ListenAddr::Tcp(addr.into())
44+
}
45+
}
46+
47+
impl From<SocketAddr> for ListenAddr {
48+
fn from(addr: SocketAddr) -> Self {
49+
ListenAddr::Tcp(addr)
50+
}
51+
}
52+
53+
#[cfg(unix)]
54+
impl From<std::path::PathBuf> for ListenAddr {
55+
fn from(path: std::path::PathBuf) -> Self {
56+
ListenAddr::Unix(path)
57+
}
58+
}
59+
60+
impl fmt::Display for ListenAddr {
61+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62+
match self {
63+
ListenAddr::Tcp(addr) => write!(f, "{addr}"),
64+
#[cfg(unix)]
65+
ListenAddr::Unix(path) => write!(f, "{}", path.display()),
66+
}
67+
}
68+
}
69+
70+
#[cfg(feature = "settings")]
71+
impl Settings for ListenAddr {}

foundations/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666

6767
mod utils;
6868

69+
pub mod addr;
70+
6971
#[cfg(feature = "cli")]
7072
pub mod cli;
7173

foundations/src/telemetry/driver.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#[cfg(feature = "telemetry-server")]
2+
use crate::addr::ListenAddr;
13
use crate::utils::feature_use;
24
use crate::BootstrapResult;
35
use futures_util::future::BoxFuture;
@@ -9,7 +11,6 @@ use std::task::{Context, Poll};
911

1012
feature_use!(cfg(feature = "telemetry-server"), {
1113
use super::server::TelemetryServerFuture;
12-
use std::net::SocketAddr;
1314
});
1415

1516
/// A future that drives async telemetry functionality and that is returned
@@ -21,7 +22,7 @@ feature_use!(cfg(feature = "telemetry-server"), {
2122
/// [security syscall-related]: `crate::security`
2223
pub struct TelemetryDriver {
2324
#[cfg(feature = "telemetry-server")]
24-
server_addr: Option<SocketAddr>,
25+
server_addr: Option<ListenAddr>,
2526

2627
#[cfg(feature = "telemetry-server")]
2728
server_fut: Option<TelemetryServerFuture>,
@@ -36,7 +37,7 @@ impl TelemetryDriver {
3637
) -> Self {
3738
Self {
3839
#[cfg(feature = "telemetry-server")]
39-
server_addr: server_fut.as_ref().map(|fut| fut.local_addr()),
40+
server_addr: server_fut.as_ref().and_then(|fut| fut.local_addr().ok()),
4041

4142
#[cfg(feature = "telemetry-server")]
4243
server_fut,
@@ -49,8 +50,8 @@ impl TelemetryDriver {
4950
///
5051
/// Returns `None` if the server wasn't spawned.
5152
#[cfg(feature = "telemetry-server")]
52-
pub fn server_addr(&self) -> Option<SocketAddr> {
53-
self.server_addr
53+
pub fn server_addr(&self) -> Option<&ListenAddr> {
54+
self.server_addr.as_ref()
5455
}
5556

5657
/// Instructs the telemetry driver and server to perform an orderly shutdown when the given

foundations/src/telemetry/server/mod.rs

Lines changed: 141 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#[cfg(feature = "metrics")]
22
use super::metrics;
33
use super::settings::TelemetrySettings;
4+
use crate::addr::ListenAddr;
45
use crate::telemetry::log;
56
use crate::BootstrapResult;
67
use anyhow::Context as _;
@@ -14,18 +15,127 @@ use std::net::SocketAddr;
1415
use std::pin::Pin;
1516
use std::sync::Arc;
1617
use std::task::{Context, Poll};
18+
use tokio::io::{AsyncRead, AsyncWrite};
1719
use tokio::net::TcpListener;
20+
#[cfg(unix)]
21+
use tokio::net::{TcpStream, UnixListener, UnixStream};
1822
use tokio::sync::watch;
1923

2024
mod router;
2125

2226
use router::Router;
27+
28+
enum TelemetryStream {
29+
Tcp(TcpStream),
30+
#[cfg(unix)]
31+
Unix(UnixStream),
32+
}
33+
34+
impl AsyncRead for TelemetryStream {
35+
fn poll_read(
36+
self: Pin<&mut Self>,
37+
cx: &mut Context<'_>,
38+
buf: &mut tokio::io::ReadBuf<'_>,
39+
) -> Poll<std::io::Result<()>> {
40+
match self.get_mut() {
41+
TelemetryStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
42+
#[cfg(unix)]
43+
TelemetryStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
44+
}
45+
}
46+
}
47+
48+
impl AsyncWrite for TelemetryStream {
49+
fn poll_write(
50+
self: Pin<&mut Self>,
51+
cx: &mut Context<'_>,
52+
buf: &[u8],
53+
) -> Poll<Result<usize, std::io::Error>> {
54+
match self.get_mut() {
55+
TelemetryStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
56+
#[cfg(unix)]
57+
TelemetryStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
58+
}
59+
}
60+
61+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
62+
match self.get_mut() {
63+
TelemetryStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
64+
#[cfg(unix)]
65+
TelemetryStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
66+
}
67+
}
68+
69+
fn poll_shutdown(
70+
self: Pin<&mut Self>,
71+
cx: &mut Context<'_>,
72+
) -> Poll<Result<(), std::io::Error>> {
73+
match self.get_mut() {
74+
TelemetryStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
75+
#[cfg(unix)]
76+
TelemetryStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
77+
}
78+
}
79+
}
80+
81+
enum TelemetryListener {
82+
Tcp(TcpListener),
83+
#[cfg(unix)]
84+
Unix(UnixListener),
85+
}
86+
87+
impl TelemetryListener {
88+
pub(crate) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
89+
match self {
90+
TelemetryListener::Tcp(listener) => Ok(listener.local_addr()?.into()),
91+
#[cfg(unix)]
92+
TelemetryListener::Unix(listener) => match listener.local_addr()?.as_pathname() {
93+
Some(path) => Ok(path.to_path_buf().into()),
94+
None => Err(anyhow::anyhow!("unix socket listener has no pathname")),
95+
},
96+
}
97+
}
98+
99+
pub(crate) async fn accept(&self) -> std::io::Result<TelemetryStream> {
100+
match self {
101+
TelemetryListener::Tcp(listener) => listener
102+
.accept()
103+
.await
104+
.map(|(conn, _)| TelemetryStream::Tcp(conn)),
105+
#[cfg(unix)]
106+
TelemetryListener::Unix(listener) => listener
107+
.accept()
108+
.await
109+
.map(|(conn, _)| TelemetryStream::Unix(conn)),
110+
}
111+
}
112+
113+
pub(crate) fn poll_accept(
114+
&mut self,
115+
cx: &mut std::task::Context<'_>,
116+
) -> std::task::Poll<std::io::Result<TelemetryStream>> {
117+
match self {
118+
TelemetryListener::Tcp(listener) => match std::task::ready!(listener.poll_accept(cx)) {
119+
Ok((conn, _)) => std::task::Poll::Ready(Ok(TelemetryStream::Tcp(conn))),
120+
Err(e) => std::task::Poll::Ready(Err(e)),
121+
},
122+
#[cfg(unix)]
123+
TelemetryListener::Unix(listener) => {
124+
match std::task::ready!(listener.poll_accept(cx)) {
125+
Ok((conn, _)) => std::task::Poll::Ready(Ok(TelemetryStream::Unix(conn))),
126+
Err(e) => std::task::Poll::Ready(Err(e)),
127+
}
128+
}
129+
}
130+
}
131+
}
132+
23133
pub use router::{
24134
BoxError, TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute,
25135
};
26136

27137
pub(super) struct TelemetryServerFuture {
28-
listener: TcpListener,
138+
listener: TelemetryListener,
29139
router: Router,
30140
}
31141

@@ -47,27 +157,38 @@ impl TelemetryServerFuture {
47157
.map_err(|err| anyhow::anyhow!(err))?;
48158
}
49159

50-
let addr = settings.server.addr;
51-
52-
#[cfg(feature = "settings")]
53-
let addr = SocketAddr::from(addr);
54-
55-
let router = Router::new(custom_routes, settings);
56-
57-
let listener = {
58-
let std_listener = std::net::TcpListener::from(
59-
bind_socket(addr).with_context(|| format!("binding to socket {addr:?}"))?,
60-
);
61-
62-
std_listener.set_nonblocking(true)?;
160+
let router = Router::new(custom_routes, Arc::clone(&settings));
161+
162+
let listener = match &settings.server.addr {
163+
ListenAddr::Tcp(addr) => {
164+
let std_listener = std::net::TcpListener::from(
165+
bind_socket(*addr)
166+
.with_context(|| format!("binding to TCP socket {addr:?}"))?,
167+
);
168+
std_listener.set_nonblocking(true)?;
169+
let tokio_listener = tokio::net::TcpListener::from_std(std_listener)?;
170+
TelemetryListener::Tcp(tokio_listener)
171+
}
172+
#[cfg(unix)]
173+
ListenAddr::Unix(path) => {
174+
// Remove existing socket file if it exists to avoid bind errors
175+
if path.exists() {
176+
if let Err(e) = std::fs::remove_file(path) {
177+
log::warn!("failed to remove existing Unix socket file"; "path" => %path.display(), "error" => e);
178+
}
179+
}
63180

64-
tokio::net::TcpListener::from_std(std_listener)?
181+
let unix_listener = UnixListener::bind(path)
182+
.with_context(|| format!("binding to Unix socket {path:?}"))?;
183+
TelemetryListener::Unix(unix_listener)
184+
}
65185
};
66186

67187
Ok(Some(TelemetryServerFuture { listener, router }))
68188
}
69-
pub(super) fn local_addr(&self) -> SocketAddr {
70-
self.listener.local_addr().unwrap()
189+
190+
pub(super) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
191+
self.listener.local_addr()
71192
}
72193

73194
// Adapted from Hyper 0.14 Server stuff and axum::serve::serve.
@@ -87,15 +208,12 @@ impl TelemetryServerFuture {
87208
let (close_tx, close_rx) = watch::channel(());
88209
let listener = self.listener;
89210

90-
pin_mut!(listener);
91-
92211
loop {
93212
let socket = tokio::select! {
94213
conn = listener.accept() => match conn {
95-
Ok((conn, _)) => TokioIo::new(conn),
214+
Ok(conn) => TokioIo::new(conn),
96215
Err(e) => {
97216
log::warn!("failed to accept connection"; "error" => e);
98-
99217
continue;
100218
}
101219
},
@@ -140,11 +258,10 @@ impl Future for TelemetryServerFuture {
140258
let this = &mut *self;
141259

142260
loop {
143-
let socket = match ready!(Pin::new(&mut this.listener).poll_accept(cx)) {
144-
Ok((conn, _)) => TokioIo::new(conn),
261+
let socket = match ready!(this.listener.poll_accept(cx)) {
262+
Ok(conn) => TokioIo::new(conn),
145263
Err(e) => {
146264
log::warn!("failed to accept connection"; "error" => e);
147-
148265
continue;
149266
}
150267
};

0 commit comments

Comments
 (0)