@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
1414limitations under the License.
1515*/
1616
17+ use std:: collections:: HashSet ;
1718#[ cfg( unix) ]
1819use std:: os:: fd:: AsRawFd ;
1920#[ cfg( unix) ]
@@ -95,18 +96,35 @@ impl MultiUseSandbox {
9596 /// Create a snapshot of the current state of the sandbox's memory.
9697 #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) ) ]
9798 pub fn snapshot ( & mut self ) -> Result < Snapshot > {
98- let snapshot = self . mem_mgr . unwrap_mgr_mut ( ) . snapshot ( ) ?;
99- Ok ( Snapshot { inner : snapshot } )
99+ let mapped_regions_iter = self . vm . get_mapped_regions ( ) ;
100+ let mapped_regions_vec: Vec < MemoryRegion > = mapped_regions_iter. cloned ( ) . collect ( ) ;
101+ let memory_snapshot = self . mem_mgr . unwrap_mgr_mut ( ) . snapshot ( mapped_regions_vec) ?;
102+ Ok ( Snapshot {
103+ inner : memory_snapshot,
104+ } )
100105 }
101106
102107 /// Restore the sandbox's memory to the state captured in the given snapshot.
103108 #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) ) ]
104109 pub fn restore ( & mut self , snapshot : & Snapshot ) -> Result < ( ) > {
105- let rgns_to_unmap = self
106- . mem_mgr
110+ self . mem_mgr
107111 . unwrap_mgr_mut ( )
108112 . restore_snapshot ( & snapshot. inner ) ?;
109- unsafe { self . vm . unmap_regions ( rgns_to_unmap) ? } ;
113+
114+ let current_regions: HashSet < _ > = self . vm . get_mapped_regions ( ) . cloned ( ) . collect ( ) ;
115+ let snapshot_regions: HashSet < _ > = snapshot. inner . regions ( ) . iter ( ) . cloned ( ) . collect ( ) ;
116+
117+ let regions_to_unmap = current_regions. difference ( & snapshot_regions) ;
118+ let regions_to_map = snapshot_regions. difference ( & current_regions) ;
119+
120+ for region in regions_to_unmap {
121+ unsafe { self . vm . unmap_region ( region) ? } ;
122+ }
123+
124+ for region in regions_to_map {
125+ unsafe { self . vm . map_region ( region) ? } ;
126+ }
127+
110128 Ok ( ( ) )
111129 }
112130
@@ -645,4 +663,57 @@ mod tests {
645663 region_type : MemoryRegionType :: Heap ,
646664 }
647665 }
666+
667+ #[ cfg( target_os = "linux" ) ]
668+ fn allocate_guest_memory ( ) -> GuestSharedMemory {
669+ page_aligned_memory ( b"test data for snapshot" )
670+ }
671+
672+ #[ test]
673+ #[ cfg( target_os = "linux" ) ]
674+ fn snapshot_restore_handles_remapping_correctly ( ) {
675+ let mut sbox: MultiUseSandbox = {
676+ let path = simple_guest_as_string ( ) . unwrap ( ) ;
677+ let u_sbox = UninitializedSandbox :: new ( GuestBinary :: FilePath ( path) , None ) . unwrap ( ) ;
678+ u_sbox. evolve ( ) . unwrap ( )
679+ } ;
680+
681+ // 1. Take snapshot 1 with no additional regions mapped
682+ let snapshot1 = sbox. snapshot ( ) . unwrap ( ) ;
683+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 0 ) ;
684+
685+ // 2. Map a memory region
686+ let map_mem = allocate_guest_memory ( ) ;
687+ let guest_base = 0x200000000_usize ;
688+ let region = region_for_memory ( & map_mem, guest_base) ;
689+
690+ unsafe { sbox. map_region ( & region) . unwrap ( ) } ;
691+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
692+
693+ // 3. Take snapshot 2 with 1 region mapped
694+ let snapshot2 = sbox. snapshot ( ) . unwrap ( ) ;
695+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
696+
697+ // 4. Restore to snapshot 1 (should unmap the region)
698+ sbox. restore ( & snapshot1) . unwrap ( ) ;
699+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 0 ) ;
700+
701+ // 5. Restore forward to snapshot 2 (should remap the region)
702+ sbox. restore ( & snapshot2) . unwrap ( ) ;
703+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
704+
705+ // Verify the region is the same
706+ let mut restored_regions = sbox. vm . get_mapped_regions ( ) ;
707+ assert_eq ! ( * restored_regions. next( ) . unwrap( ) , region) ;
708+ assert ! ( restored_regions. next( ) . is_none( ) ) ;
709+ drop ( restored_regions) ;
710+
711+ // 6. Try map the region again (should fail since already mapped)
712+ let err = unsafe { sbox. map_region ( & region) } ;
713+ assert ! (
714+ err. is_err( ) ,
715+ "Expected error when remapping existing region: {:?}" ,
716+ err
717+ ) ;
718+ }
648719}
0 commit comments