diff --git a/auraed/src/init/mod.rs b/auraed/src/init/mod.rs index cbab76f9c..f8300cbb3 100644 --- a/auraed/src/init/mod.rs +++ b/auraed/src/init/mod.rs @@ -81,14 +81,15 @@ pub async fn init( socket_address: Option, ) -> (Context, SocketStream) { let context = Context::get(nested); - let init_result = match context { - Context::Pid1 => Pid1SystemRuntime {}.init(verbose, socket_address), - Context::Cell => CellSystemRuntime {}.init(verbose, socket_address), - Context::Container => { - ContainerSystemRuntime {}.init(verbose, socket_address) - } - Context::Daemon => DaemonSystemRuntime {}.init(verbose, socket_address), - } + let init_result = init_with_runtimes( + context, + verbose, + socket_address, + Pid1SystemRuntime {}, + CellSystemRuntime {}, + ContainerSystemRuntime {}, + DaemonSystemRuntime {}, + ) .await; match init_result { @@ -97,6 +98,31 @@ pub async fn init( } } +async fn init_with_runtimes( + context: Context, + verbose: bool, + socket_address: Option, + pid1_runtime: RPid1, + cell_runtime: RCell, + container_runtime: RContainer, + daemon_runtime: RDaemon, +) -> Result +where + RPid1: SystemRuntime, + RCell: SystemRuntime, + RContainer: SystemRuntime, + RDaemon: SystemRuntime, +{ + match context { + Context::Pid1 => pid1_runtime.init(verbose, socket_address).await, + Context::Cell => cell_runtime.init(verbose, socket_address).await, + Context::Container => { + container_runtime.init(verbose, socket_address).await + } + Context::Daemon => daemon_runtime.init(verbose, socket_address).await, + } +} + #[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum Context { /// auraed is running as true PID 1 @@ -215,6 +241,16 @@ fn in_new_cgroup_namespace() -> bool { #[cfg(test)] mod tests { use super::*; + use crate::init::system_runtimes::{ + SocketStream, SystemRuntime, SystemRuntimeError, + }; + use anyhow::anyhow; + use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }; + use tokio::runtime::Runtime; + use tonic::async_trait; fn pid_one() -> u32 { 1 @@ -282,4 +318,83 @@ mod tests { Context::Cell ); } + + #[derive(Clone)] + struct MockRuntime { + calls: Arc, + label: &'static str, + } + + impl MockRuntime { + fn new(label: &'static str) -> Self { + Self { calls: Arc::new(AtomicUsize::new(0)), label } + } + } + + #[async_trait] + impl SystemRuntime for Arc { + async fn init( + self, + _verbose: bool, + _socket_address: Option, + ) -> Result { + let _ = self.calls.fetch_add(1, Ordering::SeqCst); + Err(SystemRuntimeError::Other(anyhow!(self.label))) + } + } + + fn assert_called_once(mock: &Arc) { + assert_eq!( + mock.calls.load(Ordering::SeqCst), + 1, + "expected {} to be called once", + mock.label + ); + } + + #[test] + fn init_should_call_matching_system_runtime() { + // This test ensures the `init` dispatcher chooses the correct runtime + // implementation for each Context. We avoid spinning up real runtimes + // by injecting cheap mocks that count how many times they're called. + let rt = Runtime::new().expect("tokio runtime"); + + let pid1 = Arc::new(MockRuntime::new("pid1")); + let cell = Arc::new(MockRuntime::new("cell")); + let container = Arc::new(MockRuntime::new("container")); + let daemon = Arc::new(MockRuntime::new("daemon")); + + rt.block_on(async { + // Each tuple represents (nested flag, pid, in_cgroup_namespace). + // We exercise the four Context variants the same way Context::get does. + let runtimes = [ + (false, 1, false), + (true, 1, false), + (false, 42, true), + (false, 42, false), + ]; + + for (nested, pid, in_cgroup) in runtimes { + let ctx = derive_context(nested, pid, in_cgroup); + + // Call the same routing code init() uses, but with our mocks. + let _ = init_with_runtimes( + ctx, + false, + None, + pid1.clone(), + cell.clone(), + container.clone(), + daemon.clone(), + ) + .await; + } + }); + + // Each mock should have been called exactly once by its matching Context. + assert_called_once(&pid1); + assert_called_once(&cell); + assert_called_once(&container); + assert_called_once(&daemon); + } }