Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 123 additions & 8 deletions auraed/src/init/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,15 @@ pub async fn init(
socket_address: Option<String>,
) -> (Context, SocketStream) {
let context = Context::get(nested);
let init_result = match context {
Context::Pid1 => Pid1SystemRuntime {}.init(verbose, socket_address),
Context::Cell => CellSystemRuntime {}.init(verbose, socket_address),
Context::Container => {
ContainerSystemRuntime {}.init(verbose, socket_address)
}
Context::Daemon => DaemonSystemRuntime {}.init(verbose, socket_address),
}
let init_result = init_with_runtimes(
context,
verbose,
socket_address,
Pid1SystemRuntime {},
CellSystemRuntime {},
ContainerSystemRuntime {},
DaemonSystemRuntime {},
)
.await;

match init_result {
Expand All @@ -97,6 +98,31 @@ pub async fn init(
}
}

async fn init_with_runtimes<RPid1, RCell, RContainer, RDaemon>(
context: Context,
verbose: bool,
socket_address: Option<String>,
pid1_runtime: RPid1,
cell_runtime: RCell,
container_runtime: RContainer,
daemon_runtime: RDaemon,
) -> Result<SocketStream, SystemRuntimeError>
where
RPid1: SystemRuntime,
RCell: SystemRuntime,
RContainer: SystemRuntime,
RDaemon: SystemRuntime,
{
match context {
Context::Pid1 => pid1_runtime.init(verbose, socket_address).await,
Context::Cell => cell_runtime.init(verbose, socket_address).await,
Context::Container => {
container_runtime.init(verbose, socket_address).await
}
Context::Daemon => daemon_runtime.init(verbose, socket_address).await,
}
}

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum Context {
/// auraed is running as true PID 1
Expand Down Expand Up @@ -215,6 +241,16 @@ fn in_new_cgroup_namespace() -> bool {
#[cfg(test)]
mod tests {
use super::*;
use crate::init::system_runtimes::{
SocketStream, SystemRuntime, SystemRuntimeError,
};
use anyhow::anyhow;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio::runtime::Runtime;
use tonic::async_trait;

fn pid_one() -> u32 {
1
Expand Down Expand Up @@ -282,4 +318,83 @@ mod tests {
Context::Cell
);
}

#[derive(Clone)]
struct MockRuntime {
calls: Arc<AtomicUsize>,
label: &'static str,
}

impl MockRuntime {
fn new(label: &'static str) -> Self {
Self { calls: Arc::new(AtomicUsize::new(0)), label }
}
}

#[async_trait]
impl SystemRuntime for Arc<MockRuntime> {
async fn init(
self,
_verbose: bool,
_socket_address: Option<String>,
) -> Result<SocketStream, SystemRuntimeError> {
let _ = self.calls.fetch_add(1, Ordering::SeqCst);
Err(SystemRuntimeError::Other(anyhow!(self.label)))
}
}

fn assert_called_once(mock: &Arc<MockRuntime>) {
assert_eq!(
mock.calls.load(Ordering::SeqCst),
1,
"expected {} to be called once",
mock.label
);
}

#[test]
fn init_should_call_matching_system_runtime() {
// This test ensures the `init` dispatcher chooses the correct runtime
// implementation for each Context. We avoid spinning up real runtimes
// by injecting cheap mocks that count how many times they're called.
let rt = Runtime::new().expect("tokio runtime");

let pid1 = Arc::new(MockRuntime::new("pid1"));
let cell = Arc::new(MockRuntime::new("cell"));
let container = Arc::new(MockRuntime::new("container"));
let daemon = Arc::new(MockRuntime::new("daemon"));

rt.block_on(async {
// Each tuple represents (nested flag, pid, in_cgroup_namespace).
// We exercise the four Context variants the same way Context::get does.
let runtimes = [
(false, 1, false),
(true, 1, false),
(false, 42, true),
(false, 42, false),
];

for (nested, pid, in_cgroup) in runtimes {
let ctx = derive_context(nested, pid, in_cgroup);

// Call the same routing code init() uses, but with our mocks.
let _ = init_with_runtimes(
ctx,
false,
None,
pid1.clone(),
cell.clone(),
container.clone(),
daemon.clone(),
)
.await;
}
});

// Each mock should have been called exactly once by its matching Context.
assert_called_once(&pid1);
assert_called_once(&cell);
assert_called_once(&container);
assert_called_once(&daemon);
}
}