Skip to content

Commit ea371ad

Browse files
authored
Merge branch 'main' into fmt-check-workflow
2 parents c281682 + 27d24bf commit ea371ad

File tree

4 files changed

+338
-8
lines changed

4 files changed

+338
-8
lines changed

auraed/src/init/mod.rs

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,15 @@ pub async fn init(
8181
socket_address: Option<String>,
8282
) -> (Context, SocketStream) {
8383
let context = Context::get(nested);
84-
let init_result = match context {
85-
Context::Pid1 => Pid1SystemRuntime {}.init(verbose, socket_address),
86-
Context::Cell => CellSystemRuntime {}.init(verbose, socket_address),
87-
Context::Container => {
88-
ContainerSystemRuntime {}.init(verbose, socket_address)
89-
}
90-
Context::Daemon => DaemonSystemRuntime {}.init(verbose, socket_address),
91-
}
84+
let init_result = init_with_runtimes(
85+
context,
86+
verbose,
87+
socket_address,
88+
Pid1SystemRuntime {},
89+
CellSystemRuntime {},
90+
ContainerSystemRuntime {},
91+
DaemonSystemRuntime {},
92+
)
9293
.await;
9394

9495
match init_result {
@@ -97,6 +98,31 @@ pub async fn init(
9798
}
9899
}
99100

101+
async fn init_with_runtimes<RPid1, RCell, RContainer, RDaemon>(
102+
context: Context,
103+
verbose: bool,
104+
socket_address: Option<String>,
105+
pid1_runtime: RPid1,
106+
cell_runtime: RCell,
107+
container_runtime: RContainer,
108+
daemon_runtime: RDaemon,
109+
) -> Result<SocketStream, SystemRuntimeError>
110+
where
111+
RPid1: SystemRuntime,
112+
RCell: SystemRuntime,
113+
RContainer: SystemRuntime,
114+
RDaemon: SystemRuntime,
115+
{
116+
match context {
117+
Context::Pid1 => pid1_runtime.init(verbose, socket_address).await,
118+
Context::Cell => cell_runtime.init(verbose, socket_address).await,
119+
Context::Container => {
120+
container_runtime.init(verbose, socket_address).await
121+
}
122+
Context::Daemon => daemon_runtime.init(verbose, socket_address).await,
123+
}
124+
}
125+
100126
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
101127
pub enum Context {
102128
/// auraed is running as true PID 1
@@ -215,6 +241,16 @@ fn in_new_cgroup_namespace() -> bool {
215241
#[cfg(test)]
216242
mod tests {
217243
use super::*;
244+
use crate::init::system_runtimes::{
245+
SocketStream, SystemRuntime, SystemRuntimeError,
246+
};
247+
use anyhow::anyhow;
248+
use std::sync::{
249+
Arc,
250+
atomic::{AtomicUsize, Ordering},
251+
};
252+
use tokio::runtime::Runtime;
253+
use tonic::async_trait;
218254

219255
fn pid_one() -> u32 {
220256
1
@@ -282,4 +318,83 @@ mod tests {
282318
Context::Cell
283319
);
284320
}
321+
322+
#[derive(Clone)]
323+
struct MockRuntime {
324+
calls: Arc<AtomicUsize>,
325+
label: &'static str,
326+
}
327+
328+
impl MockRuntime {
329+
fn new(label: &'static str) -> Self {
330+
Self { calls: Arc::new(AtomicUsize::new(0)), label }
331+
}
332+
}
333+
334+
#[async_trait]
335+
impl SystemRuntime for Arc<MockRuntime> {
336+
async fn init(
337+
self,
338+
_verbose: bool,
339+
_socket_address: Option<String>,
340+
) -> Result<SocketStream, SystemRuntimeError> {
341+
let _ = self.calls.fetch_add(1, Ordering::SeqCst);
342+
Err(SystemRuntimeError::Other(anyhow!(self.label)))
343+
}
344+
}
345+
346+
fn assert_called_once(mock: &Arc<MockRuntime>) {
347+
assert_eq!(
348+
mock.calls.load(Ordering::SeqCst),
349+
1,
350+
"expected {} to be called once",
351+
mock.label
352+
);
353+
}
354+
355+
#[test]
356+
fn init_should_call_matching_system_runtime() {
357+
// This test ensures the `init` dispatcher chooses the correct runtime
358+
// implementation for each Context. We avoid spinning up real runtimes
359+
// by injecting cheap mocks that count how many times they're called.
360+
let rt = Runtime::new().expect("tokio runtime");
361+
362+
let pid1 = Arc::new(MockRuntime::new("pid1"));
363+
let cell = Arc::new(MockRuntime::new("cell"));
364+
let container = Arc::new(MockRuntime::new("container"));
365+
let daemon = Arc::new(MockRuntime::new("daemon"));
366+
367+
rt.block_on(async {
368+
// Each tuple represents (nested flag, pid, in_cgroup_namespace).
369+
// We exercise the four Context variants the same way Context::get does.
370+
let runtimes = [
371+
(false, 1, false),
372+
(true, 1, false),
373+
(false, 42, true),
374+
(false, 42, false),
375+
];
376+
377+
for (nested, pid, in_cgroup) in runtimes {
378+
let ctx = derive_context(nested, pid, in_cgroup);
379+
380+
// Call the same routing code init() uses, but with our mocks.
381+
let _ = init_with_runtimes(
382+
ctx,
383+
false,
384+
None,
385+
pid1.clone(),
386+
cell.clone(),
387+
container.clone(),
388+
daemon.clone(),
389+
)
390+
.await;
391+
}
392+
});
393+
394+
// Each mock should have been called exactly once by its matching Context.
395+
assert_called_once(&pid1);
396+
assert_called_once(&cell);
397+
assert_called_once(&container);
398+
assert_called_once(&daemon);
399+
}
285400
}

auraed/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,25 @@ pub fn prep_oci_spec_for_spawn(output: &str) -> Result<(), anyhow::Error> {
355355
.map_err(|e| anyhow!("building default oci spec: {e}"))?;
356356
spawn_auraed_oci_to(PathBuf::from(output), spec)
357357
}
358+
359+
#[cfg(test)]
360+
mod tests {
361+
use super::*;
362+
363+
#[test]
364+
fn auraed_runtime_default_socket_address_should_use_runtime_dir() {
365+
let default_runtime = AuraedRuntime::default();
366+
assert_eq!(
367+
default_runtime.default_socket_address(),
368+
PathBuf::from("/var/run/aurae/aurae.sock")
369+
);
370+
371+
let custom_runtime_dir = PathBuf::from("/tmp/aurae-test-runtime");
372+
let mut runtime = AuraedRuntime::default();
373+
runtime.runtime_dir = custom_runtime_dir.clone();
374+
assert_eq!(
375+
runtime.default_socket_address(),
376+
custom_runtime_dir.join("aurae.sock")
377+
);
378+
}
379+
}

auraed/tests/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@ macro_rules! retry {
7575
}};
7676
}
7777

78+
#[allow(dead_code)]
7879
static RT: Lazy<tokio::runtime::Runtime> =
7980
Lazy::new(|| tokio::runtime::Runtime::new().unwrap());
8081

82+
#[allow(dead_code)]
8183
pub fn test<F>(f: F) -> F::Output
8284
where
8385
F: Future + Send + 'static,
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/* -------------------------------------------------------------------------- *\
2+
* | █████╗ ██╗ ██╗██████╗ █████╗ ███████╗ | *
3+
* | ██╔══██╗██║ ██║██╔══██╗██╔══██╗██╔════╝ | *
4+
* | ███████║██║ ██║██████╔╝███████║█████╗ | *
5+
* | ██╔══██║██║ ██║██╔══██╗██╔══██║██╔══╝ | *
6+
* | ██║ ██║╚██████╔╝██║ ██║██║ ██║███████╗ | *
7+
* | ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ | *
8+
* +--------------------------------------------+ *
9+
* *
10+
* Distributed Systems Runtime *
11+
* -------------------------------------------------------------------------- *
12+
* Copyright 2022 - 2024, the aurae contributors *
13+
* SPDX-License-Identifier: Apache-2.0 *
14+
\* -------------------------------------------------------------------------- */
15+
16+
//! Ensure auraed defaults to a Unix socket in daemon mode and does not listen on TCP.
17+
18+
mod common;
19+
20+
use std::{
21+
io,
22+
net::{SocketAddr, TcpListener},
23+
os::unix::fs::FileTypeExt,
24+
path::{Path, PathBuf},
25+
process::{Command, Stdio},
26+
thread,
27+
time::{Duration, Instant},
28+
};
29+
use test_helpers::*;
30+
31+
fn tcp_addrs_available_before_spawn() -> Vec<SocketAddr> {
32+
["127.0.0.1:8080", "[::1]:8080"]
33+
.into_iter()
34+
.filter_map(|addr| addr.parse::<SocketAddr>().ok())
35+
.filter_map(|addr| match TcpListener::bind(addr) {
36+
Ok(listener) => {
37+
drop(listener);
38+
Some(addr)
39+
}
40+
Err(e) if e.kind() == io::ErrorKind::AddrInUse => None,
41+
Err(e) if e.kind() == io::ErrorKind::AddrNotAvailable => None,
42+
Err(e) => panic!("unexpected error probing {addr}: {e}"),
43+
})
44+
.collect()
45+
}
46+
47+
#[test]
48+
fn auraed_daemon_mode_should_bind_only_unix_socket() {
49+
skip_if_not_root!("auraed_daemon_mode_should_bind_only_unix_socket");
50+
skip_if_seccomp!("auraed_daemon_mode_should_bind_only_unix_socket");
51+
52+
let tempdir = tempfile::tempdir().expect("tempdir");
53+
let runtime_dir = tempdir.path().join("runtime");
54+
let library_dir = tempdir.path().join("library");
55+
std::fs::create_dir_all(&runtime_dir).expect("runtime dir");
56+
std::fs::create_dir_all(&library_dir).expect("library dir");
57+
58+
let tls = generate_tls_material(tempdir.path());
59+
60+
let tcp_addrs = tcp_addrs_available_before_spawn();
61+
62+
let child = Command::new(env!("CARGO_BIN_EXE_auraed"))
63+
.arg("--runtime-dir")
64+
.arg(runtime_dir.to_str().expect("runtime dir"))
65+
.arg("--library-dir")
66+
.arg(library_dir.to_str().expect("library dir"))
67+
.arg("--ca-crt")
68+
.arg(tls.ca_crt.to_str().expect("ca crt"))
69+
.arg("--server-crt")
70+
.arg(tls.server_crt.to_str().expect("server crt"))
71+
.arg("--server-key")
72+
.arg(tls.server_key.to_str().expect("server key"))
73+
.stdout(Stdio::null())
74+
.stderr(Stdio::null())
75+
.spawn()
76+
.expect("spawn auraed");
77+
let _guard = common::ChildGuard::new(child);
78+
79+
let socket_path = runtime_dir.join("aurae.sock");
80+
wait_for_socket(&socket_path, Duration::from_secs(5));
81+
82+
let meta =
83+
std::fs::symlink_metadata(&socket_path).expect("metadata for socket");
84+
assert!(
85+
meta.file_type().is_socket(),
86+
"expected {:?} to be a Unix socket",
87+
socket_path
88+
);
89+
90+
// Default daemon mode should not open the documented TCP endpoint ([::1]:8080 or 127.0.0.1:8080).
91+
// Only check addresses that were free before spawning auraed to avoid false positives from other services.
92+
for addr in tcp_addrs {
93+
let tcp_result = TcpListener::bind(addr).map(|listener| drop(listener));
94+
assert!(
95+
tcp_result.is_ok(),
96+
"expected no TCP listener at {addr}, but binding failed after starting auraed"
97+
);
98+
}
99+
}
100+
101+
fn wait_for_socket(path: &Path, timeout: Duration) {
102+
let start = Instant::now();
103+
while start.elapsed() < timeout {
104+
if path.exists() {
105+
return;
106+
}
107+
thread::sleep(Duration::from_millis(50));
108+
}
109+
panic!("socket {path:?} not created within {:?}", timeout);
110+
}
111+
112+
struct TlsMaterial {
113+
ca_crt: PathBuf,
114+
server_crt: PathBuf,
115+
server_key: PathBuf,
116+
}
117+
118+
fn generate_tls_material(dir: &Path) -> TlsMaterial {
119+
let ca_crt = dir.join("ca.crt");
120+
let ca_key = dir.join("ca.key");
121+
let server_csr = dir.join("server.csr");
122+
let server_crt = dir.join("server.crt");
123+
let server_key = dir.join("server.key");
124+
125+
Command::new("openssl")
126+
.args([
127+
"req",
128+
"-x509",
129+
"-nodes",
130+
"-newkey",
131+
"rsa:2048",
132+
"-sha256",
133+
"-days",
134+
"365",
135+
"-keyout",
136+
ca_key.to_str().unwrap(),
137+
"-out",
138+
ca_crt.to_str().unwrap(),
139+
"-subj",
140+
"/CN=AuraeTestCA",
141+
])
142+
.status()
143+
.expect("run openssl for CA")
144+
.success()
145+
.then_some(())
146+
.expect("openssl CA generation failed");
147+
148+
Command::new("openssl")
149+
.args([
150+
"req",
151+
"-nodes",
152+
"-newkey",
153+
"rsa:2048",
154+
"-keyout",
155+
server_key.to_str().unwrap(),
156+
"-out",
157+
server_csr.to_str().unwrap(),
158+
"-subj",
159+
"/CN=server.unsafe.aurae.io",
160+
])
161+
.status()
162+
.expect("run openssl for server csr")
163+
.success()
164+
.then_some(())
165+
.expect("openssl server csr failed");
166+
167+
Command::new("openssl")
168+
.args([
169+
"x509",
170+
"-req",
171+
"-in",
172+
server_csr.to_str().unwrap(),
173+
"-CA",
174+
ca_crt.to_str().unwrap(),
175+
"-CAkey",
176+
ca_key.to_str().unwrap(),
177+
"-CAcreateserial",
178+
"-out",
179+
server_crt.to_str().unwrap(),
180+
"-days",
181+
"365",
182+
"-sha256",
183+
])
184+
.status()
185+
.expect("run openssl to sign server cert")
186+
.success()
187+
.then_some(())
188+
.expect("openssl sign server cert failed");
189+
190+
TlsMaterial { ca_crt, server_crt, server_key }
191+
}

0 commit comments

Comments
 (0)