@@ -463,6 +463,126 @@ fn get_max_log_level() -> u32 {
463463 LevelFilter :: from_str ( level) . unwrap_or ( LevelFilter :: Error ) as u32
464464}
465465
466+ #[ cfg( target_os = "windows" ) ]
467+ #[ derive( Debug ) ]
468+ pub ( super ) struct WindowsInterruptHandle {
469+ // `WHvCancelRunVirtualProcessor()` will return Ok even if the vcpu is not running, which is the reason we need this flag.
470+ running : AtomicBool ,
471+ cancel_requested : AtomicBool ,
472+ // This is used to signal the GDB thread to stop the vCPU
473+ #[ cfg( gdb) ]
474+ debug_interrupt : AtomicBool ,
475+ partition_handle : windows:: Win32 :: System :: Hypervisor :: WHV_PARTITION_HANDLE ,
476+ dropped : AtomicBool ,
477+ }
478+
479+ #[ cfg( target_os = "windows" ) ]
480+ impl InterruptHandleImpl for WindowsInterruptHandle {
481+ fn set_tid ( & self ) {
482+ // No-op on Windows - we don't need to track thread ID
483+ }
484+
485+ fn set_running ( & self ) {
486+ self . running . store ( true , Ordering :: Relaxed ) ;
487+ }
488+
489+ fn is_cancelled ( & self ) -> bool {
490+ self . cancel_requested . load ( Ordering :: Relaxed )
491+ }
492+
493+ fn clear_cancel ( & self ) {
494+ self . cancel_requested . store ( false , Ordering :: Relaxed ) ;
495+ }
496+
497+ fn clear_running ( & self ) {
498+ // On Windows, clear running, cancel_requested, and debug_interrupt together
499+ self . running . store ( false , Ordering :: Relaxed ) ;
500+ #[ cfg( gdb) ]
501+ self . debug_interrupt . store ( false , Ordering :: Relaxed ) ;
502+ }
503+
504+ fn is_debug_interrupted ( & self ) -> bool {
505+ #[ cfg( gdb) ]
506+ {
507+ self . debug_interrupt . load ( Ordering :: Relaxed )
508+ }
509+ #[ cfg( not( gdb) ) ]
510+ {
511+ false
512+ }
513+ }
514+
515+ #[ cfg( gdb) ]
516+ fn clear_debug_interrupt ( & self ) {
517+ #[ cfg( gdb) ]
518+ self . debug_interrupt . store ( false , Ordering :: Relaxed ) ;
519+ }
520+
521+ fn set_dropped ( & self ) {
522+ self . dropped . store ( true , Ordering :: Relaxed ) ;
523+ }
524+ }
525+
526+ #[ cfg( target_os = "windows" ) ]
527+ impl InterruptHandle for WindowsInterruptHandle {
528+ fn kill ( & self ) -> bool {
529+ use windows:: Win32 :: System :: Hypervisor :: WHvCancelRunVirtualProcessor ;
530+
531+ self . cancel_requested . store ( true , Ordering :: Relaxed ) ;
532+ self . running . load ( Ordering :: Relaxed )
533+ && unsafe { WHvCancelRunVirtualProcessor ( self . partition_handle , 0 , 0 ) . is_ok ( ) }
534+ }
535+ #[ cfg( gdb) ]
536+ fn kill_from_debugger ( & self ) -> bool {
537+ use windows:: Win32 :: System :: Hypervisor :: WHvCancelRunVirtualProcessor ;
538+
539+ self . debug_interrupt . store ( true , Ordering :: Relaxed ) ;
540+ self . running . load ( Ordering :: Relaxed )
541+ && unsafe { WHvCancelRunVirtualProcessor ( self . partition_handle , 0 , 0 ) . is_ok ( ) }
542+ }
543+
544+ fn dropped ( & self ) -> bool {
545+ self . dropped . load ( Ordering :: Relaxed )
546+ }
547+ }
548+
549+ /// Get the logging level to pass to the guest entrypoint
550+ fn get_max_log_level ( ) -> u32 {
551+ // Check to see if the RUST_LOG environment variable is set
552+ // and if so, parse it to get the log_level for hyperlight_guest
553+ // if that is not set get the log level for the hyperlight_host
554+
555+ // This is done as the guest will produce logs based on the log level returned here
556+ // producing those logs is expensive and we don't want to do it if the host is not
557+ // going to process them
558+
559+ let val = std:: env:: var ( "RUST_LOG" ) . unwrap_or_default ( ) ;
560+
561+ let level = if val. contains ( "hyperlight_guest" ) {
562+ val. split ( ',' )
563+ . find ( |s| s. contains ( "hyperlight_guest" ) )
564+ . unwrap_or ( "" )
565+ . split ( '=' )
566+ . nth ( 1 )
567+ . unwrap_or ( "" )
568+ } else if val. contains ( "hyperlight_host" ) {
569+ val. split ( ',' )
570+ . find ( |s| s. contains ( "hyperlight_host" ) )
571+ . unwrap_or ( "" )
572+ . split ( '=' )
573+ . nth ( 1 )
574+ . unwrap_or ( "" )
575+ } else {
576+ // look for a value string that does not contain "="
577+ val. split ( ',' ) . find ( |s| !s. contains ( "=" ) ) . unwrap_or ( "" )
578+ } ;
579+
580+ log:: info!( "Determined guest log level: {}" , level) ;
581+ // Convert the log level string to a LevelFilter
582+ // If no value is found, default to Error
583+ LevelFilter :: from_str ( level) . unwrap_or ( LevelFilter :: Error ) as u32
584+ }
585+
466586#[ cfg( all( test, any( target_os = "windows" , kvm) ) ) ]
467587pub ( crate ) mod tests {
468588 use std:: sync:: { Arc , Mutex } ;
0 commit comments