@@ -81,14 +81,15 @@ pub async fn init(
8181 socket_address : Option < String > ,
8282) -> ( Context , SocketStream ) {
8383 let context = Context :: get ( nested) ;
84- let init_result = match context {
85- Context :: Pid1 => Pid1SystemRuntime { } . init ( verbose, socket_address) ,
86- Context :: Cell => CellSystemRuntime { } . init ( verbose, socket_address) ,
87- Context :: Container => {
88- ContainerSystemRuntime { } . init ( verbose, socket_address)
89- }
90- Context :: Daemon => DaemonSystemRuntime { } . init ( verbose, socket_address) ,
91- }
84+ let init_result = init_with_runtimes (
85+ context,
86+ verbose,
87+ socket_address,
88+ Pid1SystemRuntime { } ,
89+ CellSystemRuntime { } ,
90+ ContainerSystemRuntime { } ,
91+ DaemonSystemRuntime { } ,
92+ )
9293 . await ;
9394
9495 match init_result {
@@ -97,6 +98,31 @@ pub async fn init(
9798 }
9899}
99100
101+ async fn init_with_runtimes < RPid1 , RCell , RContainer , RDaemon > (
102+ context : Context ,
103+ verbose : bool ,
104+ socket_address : Option < String > ,
105+ pid1_runtime : RPid1 ,
106+ cell_runtime : RCell ,
107+ container_runtime : RContainer ,
108+ daemon_runtime : RDaemon ,
109+ ) -> Result < SocketStream , SystemRuntimeError >
110+ where
111+ RPid1 : SystemRuntime ,
112+ RCell : SystemRuntime ,
113+ RContainer : SystemRuntime ,
114+ RDaemon : SystemRuntime ,
115+ {
116+ match context {
117+ Context :: Pid1 => pid1_runtime. init ( verbose, socket_address) . await ,
118+ Context :: Cell => cell_runtime. init ( verbose, socket_address) . await ,
119+ Context :: Container => {
120+ container_runtime. init ( verbose, socket_address) . await
121+ }
122+ Context :: Daemon => daemon_runtime. init ( verbose, socket_address) . await ,
123+ }
124+ }
125+
100126#[ derive( Debug , PartialEq , Eq , Copy , Clone ) ]
101127pub enum Context {
102128 /// auraed is running as true PID 1
@@ -215,6 +241,16 @@ fn in_new_cgroup_namespace() -> bool {
215241#[ cfg( test) ]
216242mod tests {
217243 use super :: * ;
244+ use crate :: init:: system_runtimes:: {
245+ SocketStream , SystemRuntime , SystemRuntimeError ,
246+ } ;
247+ use anyhow:: anyhow;
248+ use std:: sync:: {
249+ Arc ,
250+ atomic:: { AtomicUsize , Ordering } ,
251+ } ;
252+ use tokio:: runtime:: Runtime ;
253+ use tonic:: async_trait;
218254
219255 fn pid_one ( ) -> u32 {
220256 1
@@ -282,4 +318,83 @@ mod tests {
282318 Context :: Cell
283319 ) ;
284320 }
321+
322+ #[ derive( Clone ) ]
323+ struct MockRuntime {
324+ calls : Arc < AtomicUsize > ,
325+ label : & ' static str ,
326+ }
327+
328+ impl MockRuntime {
329+ fn new ( label : & ' static str ) -> Self {
330+ Self { calls : Arc :: new ( AtomicUsize :: new ( 0 ) ) , label }
331+ }
332+ }
333+
334+ #[ async_trait]
335+ impl SystemRuntime for Arc < MockRuntime > {
336+ async fn init (
337+ self ,
338+ _verbose : bool ,
339+ _socket_address : Option < String > ,
340+ ) -> Result < SocketStream , SystemRuntimeError > {
341+ let _ = self . calls . fetch_add ( 1 , Ordering :: SeqCst ) ;
342+ Err ( SystemRuntimeError :: Other ( anyhow ! ( self . label) ) )
343+ }
344+ }
345+
346+ fn assert_called_once ( mock : & Arc < MockRuntime > ) {
347+ assert_eq ! (
348+ mock. calls. load( Ordering :: SeqCst ) ,
349+ 1 ,
350+ "expected {} to be called once" ,
351+ mock. label
352+ ) ;
353+ }
354+
355+ #[ test]
356+ fn init_should_call_matching_system_runtime ( ) {
357+ // This test ensures the `init` dispatcher chooses the correct runtime
358+ // implementation for each Context. We avoid spinning up real runtimes
359+ // by injecting cheap mocks that count how many times they're called.
360+ let rt = Runtime :: new ( ) . expect ( "tokio runtime" ) ;
361+
362+ let pid1 = Arc :: new ( MockRuntime :: new ( "pid1" ) ) ;
363+ let cell = Arc :: new ( MockRuntime :: new ( "cell" ) ) ;
364+ let container = Arc :: new ( MockRuntime :: new ( "container" ) ) ;
365+ let daemon = Arc :: new ( MockRuntime :: new ( "daemon" ) ) ;
366+
367+ rt. block_on ( async {
368+ // Each tuple represents (nested flag, pid, in_cgroup_namespace).
369+ // We exercise the four Context variants the same way Context::get does.
370+ let runtimes = [
371+ ( false , 1 , false ) ,
372+ ( true , 1 , false ) ,
373+ ( false , 42 , true ) ,
374+ ( false , 42 , false ) ,
375+ ] ;
376+
377+ for ( nested, pid, in_cgroup) in runtimes {
378+ let ctx = derive_context ( nested, pid, in_cgroup) ;
379+
380+ // Call the same routing code init() uses, but with our mocks.
381+ let _ = init_with_runtimes (
382+ ctx,
383+ false ,
384+ None ,
385+ pid1. clone ( ) ,
386+ cell. clone ( ) ,
387+ container. clone ( ) ,
388+ daemon. clone ( ) ,
389+ )
390+ . await ;
391+ }
392+ } ) ;
393+
394+ // Each mock should have been called exactly once by its matching Context.
395+ assert_called_once ( & pid1) ;
396+ assert_called_once ( & cell) ;
397+ assert_called_once ( & container) ;
398+ assert_called_once ( & daemon) ;
399+ }
285400}
0 commit comments