Skip to content

Commit 9e9ca3e

Browse files
authored
feat(velo): add unix domain socket (#7197)
Signed-off-by: Patrick Riel <priel@nvidia.com> Signed-off-by: Ryan Olson <rolson@nvidia.com>
1 parent d5680ed commit 9e9ca3e

File tree

9 files changed

+1934
-8
lines changed

9 files changed

+1934
-8
lines changed

lib/velo-transports/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ mod address;
3131

3232
pub mod tcp;
3333

34+
#[cfg(unix)]
35+
pub mod uds;
36+
3437
// #[cfg(feature = "ucx")]
3538
// pub mod ucx;
3639

lib/velo-transports/src/tcp/transport.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,9 @@ async fn connection_writer_task(
392392
instance_id: crate::InstanceId,
393393
rx: flume::Receiver<SendTask>,
394394
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
395-
_cancel_token: CancellationToken,
395+
cancel_token: CancellationToken,
396396
) -> Result<()> {
397-
let result = connection_writer_inner(addr, instance_id, &rx).await;
397+
let result = connection_writer_inner(addr, instance_id, &rx, &cancel_token).await;
398398

399399
// Always drain queued messages and notify their error handlers.
400400
//
@@ -426,10 +426,14 @@ async fn connection_writer_inner(
426426
addr: SocketAddr,
427427
instance_id: crate::InstanceId,
428428
rx: &flume::Receiver<SendTask>,
429+
cancel_token: &CancellationToken,
429430
) -> Result<()> {
430431
debug!("Connecting to {}", addr);
431432

432-
let mut stream = TcpStream::connect(addr).await.context("connect failed")?;
433+
let mut stream = tokio::select! {
434+
_ = cancel_token.cancelled() => return Ok(()),
435+
res = TcpStream::connect(addr) => res.context("connect failed")?,
436+
};
433437

434438
if let Err(e) = stream.set_nodelay(true) {
435439
warn!("Failed to set TCP_NODELAY: {}", e);
@@ -450,7 +454,14 @@ async fn connection_writer_inner(
450454

451455
debug!("Connected to {}", addr);
452456

453-
while let Ok(msg) = rx.recv_async().await {
457+
loop {
458+
let msg = tokio::select! {
459+
_ = cancel_token.cancelled() => break,
460+
res = rx.recv_async() => match res {
461+
Ok(msg) => msg,
462+
Err(_) => break,
463+
},
464+
};
454465
if let Err(e) =
455466
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
456467
{
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! UDS listener for ActiveMessage transport
5+
//!
6+
//! Mirrors `tcp/listener.rs` but uses `UnixListener`/`UnixStream`.
7+
//! Reuses `TcpFrameCodec` for framing. Supports drain-aware frame handling
8+
//! via `ShutdownState`.
9+
10+
use anyhow::{Context, Result};
11+
use bytes::Bytes;
12+
use futures::StreamExt;
13+
use std::path::PathBuf;
14+
use std::sync::Arc;
15+
use tokio::net::UnixListener as TokioUnixListener;
16+
use tokio::net::UnixStream;
17+
use tokio_util::codec::Framed;
18+
use tracing::{debug, error, info, warn};
19+
20+
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
21+
22+
use crate::tcp::TcpFrameCodec;
23+
24+
/// UDS listener for ActiveMessage transport
25+
///
26+
/// Accepts incoming Unix domain socket connections and routes decoded frames
27+
/// to the appropriate transport streams. Supports graceful drain: during drain,
28+
/// new `Message` frames are rejected with a `ShuttingDown` response while
29+
/// `Response`/`Event`/`Ack` frames continue to flow.
30+
pub struct UdsListener {
31+
socket_path: PathBuf,
32+
adapter: TransportAdapter,
33+
error_handler: Arc<dyn TransportErrorHandler>,
34+
shutdown_state: ShutdownState,
35+
}
36+
37+
/// UDS listener that has been bound to a socket path, ready to accept connections.
38+
///
39+
/// Created by [`UdsListener::bind`]. Holding this value proves the OS-level bind
40+
/// succeeded, so callers can detect failures before spawning a task.
41+
pub struct BoundUdsListener {
42+
socket_path: PathBuf,
43+
adapter: TransportAdapter,
44+
error_handler: Arc<dyn TransportErrorHandler>,
45+
shutdown_state: ShutdownState,
46+
listener: TokioUnixListener,
47+
}
48+
49+
impl UdsListener {
50+
/// Create a new builder for UdsListener
51+
pub fn builder() -> UdsListenerBuilder {
52+
UdsListenerBuilder::new()
53+
}
54+
55+
/// Bind to the socket path and return a [`BoundUdsListener`] ready to serve.
56+
///
57+
/// `TokioUnixListener::bind` is synchronous, so this method is also
58+
/// synchronous. Callers that need to propagate bind failures before spawning
59+
/// a task should call `bind()` first, then spawn `bound.serve()`.
60+
pub fn bind(self) -> Result<BoundUdsListener> {
61+
let listener = TokioUnixListener::bind(&self.socket_path)
62+
.with_context(|| format!("Failed to bind UDS listener to {:?}", self.socket_path))?;
63+
info!("UDS listener bound to {:?}", self.socket_path);
64+
Ok(BoundUdsListener {
65+
socket_path: self.socket_path,
66+
adapter: self.adapter,
67+
error_handler: self.error_handler,
68+
shutdown_state: self.shutdown_state,
69+
listener,
70+
})
71+
}
72+
73+
/// Convenience shim: bind and serve in one call.
74+
pub async fn serve(self) -> Result<()> {
75+
self.bind()?.serve().await
76+
}
77+
78+
/// Handle a single UDS connection
79+
async fn handle_connection(
80+
stream: UnixStream,
81+
adapter: TransportAdapter,
82+
error_handler: Arc<dyn TransportErrorHandler>,
83+
shutdown_state: ShutdownState,
84+
) -> Result<()> {
85+
debug!("Configuring UDS connection");
86+
87+
// Create framed stream with zero-copy codec (same as TCP)
88+
let mut framed = Framed::new(stream, TcpFrameCodec::new());
89+
let teardown_token = shutdown_state.teardown_token().clone();
90+
91+
debug!("UDS connection ready for frames");
92+
93+
loop {
94+
tokio::select! {
95+
frame_result = framed.next() => {
96+
match frame_result {
97+
Some(Ok((msg_type, header, payload))) => {
98+
// During drain: reject new Message frames with ShuttingDown,
99+
// but always pass through Response/Ack/Event frames.
100+
if shutdown_state.is_draining() && msg_type == MessageType::Message {
101+
debug!(
102+
"Rejecting Message frame during drain (sending ShuttingDown)"
103+
);
104+
// Echo original header back for correlation, empty payload
105+
if let Err(e) = TcpFrameCodec::encode_frame(
106+
framed.get_mut(),
107+
MessageType::ShuttingDown,
108+
&header,
109+
&[],
110+
)
111+
.await
112+
{
113+
warn!(
114+
"Failed to send ShuttingDown frame: {}",
115+
e
116+
);
117+
}
118+
continue;
119+
}
120+
121+
if let Err(e) = Self::route_frame(
122+
msg_type,
123+
header,
124+
payload,
125+
&adapter,
126+
&error_handler,
127+
)
128+
.await
129+
{
130+
warn!(
131+
"Failed to route {:?} frame from UDS: {}",
132+
msg_type, e
133+
);
134+
}
135+
}
136+
Some(Err(e)) => {
137+
error!("Frame decode error from UDS: {}", e);
138+
break;
139+
}
140+
None => {
141+
debug!("UDS connection closed gracefully");
142+
break;
143+
}
144+
}
145+
}
146+
_ = teardown_token.cancelled() => {
147+
debug!("UDS connection handler torn down");
148+
break;
149+
}
150+
}
151+
}
152+
153+
Ok(())
154+
}
155+
156+
/// Route a decoded frame to the appropriate stream
157+
async fn route_frame(
158+
msg_type: MessageType,
159+
header: Bytes,
160+
payload: Bytes,
161+
adapter: &TransportAdapter,
162+
error_handler: &Arc<dyn TransportErrorHandler>,
163+
) -> Result<()> {
164+
let sender = match msg_type {
165+
MessageType::Message => &adapter.message_stream,
166+
MessageType::Response => &adapter.response_stream,
167+
MessageType::Ack | MessageType::Event => &adapter.event_stream,
168+
MessageType::ShuttingDown => {
169+
// ShuttingDown is an outbound-only frame type; receiving it here
170+
// means a remote peer rejected our request. Route to the response
171+
// stream so higher layers can handle the rejection via correlation.
172+
&adapter.response_stream
173+
}
174+
};
175+
176+
match sender.send_async((header, payload)).await {
177+
Ok(_) => Ok(()),
178+
Err(e) => {
179+
error_handler.on_error(e.0.0, e.0.1, format!("Failed to route {:?}", msg_type));
180+
Err(anyhow::anyhow!("Failed to send to stream"))
181+
}
182+
}
183+
}
184+
}
185+
186+
impl BoundUdsListener {
187+
/// Accept connections until the teardown token is cancelled.
188+
pub async fn serve(self) -> Result<()> {
189+
let teardown_token = self.shutdown_state.teardown_token().clone();
190+
191+
loop {
192+
tokio::select! {
193+
accept_result = self.listener.accept() => {
194+
match accept_result {
195+
Ok((stream, _addr)) => {
196+
debug!("Accepted UDS connection");
197+
198+
let adapter = self.adapter.clone();
199+
let error_handler = self.error_handler.clone();
200+
let shutdown_state = self.shutdown_state.clone();
201+
202+
tokio::spawn(async move {
203+
if let Err(e) = UdsListener::handle_connection(
204+
stream,
205+
adapter,
206+
error_handler,
207+
shutdown_state,
208+
)
209+
.await
210+
{
211+
warn!("Error handling UDS connection: {}", e);
212+
}
213+
});
214+
}
215+
Err(e) => {
216+
error!("Failed to accept UDS connection: {}", e);
217+
}
218+
}
219+
}
220+
_ = teardown_token.cancelled() => {
221+
info!("UDS listener shutting down (teardown)");
222+
break;
223+
}
224+
}
225+
}
226+
227+
// Clean up socket file
228+
std::fs::remove_file(&self.socket_path).ok();
229+
230+
Ok(())
231+
}
232+
}
233+
234+
/// Builder for UdsListener
235+
pub struct UdsListenerBuilder {
236+
socket_path: Option<PathBuf>,
237+
adapter: Option<TransportAdapter>,
238+
error_handler: Option<Arc<dyn TransportErrorHandler>>,
239+
shutdown_state: Option<ShutdownState>,
240+
}
241+
242+
impl UdsListenerBuilder {
243+
/// Create a new builder
244+
pub fn new() -> Self {
245+
Self {
246+
socket_path: None,
247+
adapter: None,
248+
error_handler: None,
249+
shutdown_state: None,
250+
}
251+
}
252+
253+
/// Set the socket path
254+
pub fn socket_path(mut self, path: PathBuf) -> Self {
255+
self.socket_path = Some(path);
256+
self
257+
}
258+
259+
/// Set the transport adapter
260+
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
261+
self.adapter = Some(adapter);
262+
self
263+
}
264+
265+
/// Set the error handler
266+
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
267+
self.error_handler = Some(handler);
268+
self
269+
}
270+
271+
/// Set the shutdown state for graceful drain coordination
272+
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
273+
self.shutdown_state = Some(state);
274+
self
275+
}
276+
277+
/// Build the UdsListener
278+
pub fn build(self) -> Result<UdsListener> {
279+
let socket_path = self
280+
.socket_path
281+
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
282+
let adapter = self
283+
.adapter
284+
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
285+
let error_handler = self
286+
.error_handler
287+
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
288+
let shutdown_state = self.shutdown_state.unwrap_or_default();
289+
290+
Ok(UdsListener {
291+
socket_path,
292+
adapter,
293+
error_handler,
294+
shutdown_state,
295+
})
296+
}
297+
}
298+
299+
impl Default for UdsListenerBuilder {
300+
fn default() -> Self {
301+
Self::new()
302+
}
303+
}
304+
305+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::make_channels;

    /// Error handler for tests: prints the error text to stderr.
    struct TestErrorHandler;

    impl TransportErrorHandler for TestErrorHandler {
        fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
            eprintln!("Test error handler: {}", error);
        }
    }

    /// Building without any required field must fail.
    #[test]
    fn test_builder_requires_fields() {
        let built = UdsListener::builder().build();
        assert!(built.is_err());
    }

    /// Building with socket path, adapter and error handler must succeed,
    /// even when no shutdown state is supplied.
    #[tokio::test]
    async fn test_builder_with_all_fields() {
        let (adapter, _streams) = make_channels();
        let handler: Arc<dyn TransportErrorHandler> = Arc::new(TestErrorHandler);

        let builder = UdsListener::builder()
            .socket_path(PathBuf::from("/tmp/test-uds-listener.sock"))
            .adapter(adapter)
            .error_handler(handler);

        assert!(builder.build().is_ok());
    }
}

0 commit comments

Comments
 (0)