diff --git a/src/vmm/src/snapshot/crc.rs b/src/vmm/src/snapshot/crc.rs index df97da26b13..46189bcae50 100644 --- a/src/vmm/src/snapshot/crc.rs +++ b/src/vmm/src/snapshot/crc.rs @@ -4,60 +4,10 @@ //! Implements readers and writers that compute the CRC64 checksum of the bytes //! read/written. -use std::io::{Read, Write}; +use std::io::Write; use crc64::crc64; -/// Computes the CRC64 checksum of the read bytes. -/// -/// ``` -/// use std::io::Read; -/// -/// use vmm::snapshot::crc::CRC64Reader; -/// -/// let buf = vec![1, 2, 3, 4, 5]; -/// let mut read_buf = Vec::new(); -/// let mut slice = buf.as_slice(); -/// -/// // Create a reader from a slice. -/// let mut crc_reader = CRC64Reader::new(&mut slice); -/// -/// let count = crc_reader.read_to_end(&mut read_buf).unwrap(); -/// assert_eq!(crc_reader.checksum(), 0xFB04_60DE_0638_3654); -/// assert_eq!(read_buf, buf); -/// ``` -#[derive(Debug)] -pub struct CRC64Reader { - /// The underlying raw reader. Using this directly will bypass CRC computation! - pub reader: T, - crc64: u64, -} - -impl CRC64Reader -where - T: Read, -{ - /// Create a new reader. - pub fn new(reader: T) -> Self { - CRC64Reader { crc64: 0, reader } - } - /// Returns the current checksum value. - pub fn checksum(&self) -> u64 { - self.crc64 - } -} - -impl Read for CRC64Reader -where - T: Read, -{ - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let bytes_read = self.reader.read(buf)?; - self.crc64 = crc64(self.crc64, &buf[..bytes_read]); - Ok(bytes_read) - } -} - /// Computes the CRC64 checksum of the written bytes. /// /// ``` @@ -115,17 +65,10 @@ where #[cfg(test)] mod tests { - use super::{CRC64Reader, CRC64Writer, Read, Write}; + use super::{CRC64Writer, Write}; #[test] fn test_crc_new() { - let buf = vec![1; 5]; - let mut slice = buf.as_slice(); - let crc_reader = CRC64Reader::new(&mut slice); - assert_eq!(crc_reader.crc64, 0); - assert_eq!(crc_reader.reader, &[1; 5]); - assert_eq!(crc_reader.checksum(), 0); - let mut buf = vec![0; 5]; let mut slice = buf.as_mut_slice(); let crc_writer = CRC64Writer::new(&mut slice); @@ -134,18 +77,6 @@ mod tests { assert_eq!(crc_writer.checksum(), 0); } - #[test] - fn test_crc_read() { - let buf = vec![1, 2, 3, 4, 5]; - let mut read_buf = vec![0; 16]; - - let mut slice = buf.as_slice(); - let mut crc_reader = CRC64Reader::new(&mut slice); - crc_reader.read_to_end(&mut read_buf).unwrap(); - assert_eq!(crc_reader.checksum(), 0xFB04_60DE_0638_3654); - assert_eq!(crc_reader.checksum(), crc_reader.crc64); - } - #[test] fn test_crc_write() { let mut buf = vec![0; 16]; diff --git a/src/vmm/src/snapshot/mod.rs b/src/vmm/src/snapshot/mod.rs index c1452e978e4..76b5203298d 100644 --- a/src/vmm/src/snapshot/mod.rs +++ b/src/vmm/src/snapshot/mod.rs @@ -26,19 +26,20 @@ pub mod crc; mod persist; use std::fmt::Debug; -use std::io::{Read, Seek, SeekFrom, Write}; +use std::io::{Read, Write}; use bincode::config; use bincode::config::{Configuration, Fixint, Limit, LittleEndian}; use bincode::error::{DecodeError, EncodeError}; +use crc64::crc64; use semver::Version; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use crate::persist::SNAPSHOT_VERSION; -use crate::snapshot::crc::{CRC64Reader, CRC64Writer}; +use crate::snapshot::crc::CRC64Writer; pub use crate::snapshot::persist::Persist; -use crate::utils::{mib_to_bytes, u64_to_usize}; +use crate::utils::mib_to_bytes; #[cfg(target_arch = "x86_64")] const SNAPSHOT_MAGIC_ID: u64 = 0x0710_1984_8664_0000u64; @@ -58,14 +59,12 @@ const SNAPSHOT_MAGIC_ID: u64 = 0x0710_1984_AAAA_0000u64; /// Error definitions for the Snapshot API. #[derive(Debug, thiserror::Error, displaydoc::Display)] pub enum SnapshotError { - /// CRC64 validation failed: {0} - Crc64(u64), + /// CRC64 validation failed + Crc64, /// Invalid data version: {0} InvalidFormatVersion(Version), /// Magic value does not match arch: {0} InvalidMagic(u64), - /// Snapshot file is not long enough to even contain the CRC - TooShort, /// An error occured during bincode encoding: {0} Encode(#[from] EncodeError), /// An error occured during bincode decoding: {0} @@ -152,24 +151,18 @@ impl Snapshot { /// Loads a snapshot from the given [`Read`] instance, performing all validations /// (CRC, snapshot magic value, snapshot version). - pub fn load(reader: &mut R) -> Result { - let snapshot_size = reader.seek(SeekFrom::End(0))?; - reader.seek(SeekFrom::Start(0))?; - // dont read the CRC yet. - let mut buf = vec![ - 0; - u64_to_usize(snapshot_size) - .checked_sub(size_of::()) - .ok_or(SnapshotError::TooShort)? - ]; - let mut crc_reader = CRC64Reader::new(reader); - crc_reader.read_exact(buf.as_mut_slice())?; + pub fn load(reader: &mut R) -> Result { + // read_to_end internally right-sizes the buffer, so no reallocations due to growing buffers + // will happen. + let mut buf = Vec::new(); + reader.read_to_end(&mut buf)?; let snapshot = Self::load_without_crc_check(buf.as_slice())?; - let computed_checksum = crc_reader.checksum(); - let stored_checksum: u64 = - bincode::serde::decode_from_std_read(&mut crc_reader.reader, BINCODE_CONFIG)?; - if computed_checksum != stored_checksum { - return Err(SnapshotError::Crc64(computed_checksum)); + let computed_checksum = crc64(0, buf.as_slice()); + // When we read the entire file, we also read the checksum into the buffer. The CRC has the + // property that crc(0, buf.as_slice()) == 0 iff the last 8 bytes of buf are the checksum + // of all the preceeding bytes, and this is the property we are using here. + if computed_checksum != 0 { + return Err(SnapshotError::Crc64); } Ok(snapshot) } @@ -187,19 +180,16 @@ impl Snapshot { #[cfg(test)] mod tests { - use vmm_sys_util::tempfile::TempFile; - use super::*; use crate::persist::MicrovmState; #[test] fn test_snapshot_restore() { let state = MicrovmState::default(); - let file = TempFile::new().unwrap(); + let mut buf = Vec::new(); - Snapshot::new(state).save(&mut file.as_file()).unwrap(); - file.as_file().seek(SeekFrom::Start(0)).unwrap(); - Snapshot::::load(&mut file.as_file()).unwrap(); + Snapshot::new(state).save(&mut buf).unwrap(); + Snapshot::::load(&mut buf.as_slice()).unwrap(); } #[test] @@ -228,12 +218,6 @@ mod tests { } } - impl Seek for BadReader { - fn seek(&mut self, _pos: SeekFrom) -> std::io::Result { - Ok(9) // needs to be long enough to prevent to have a CRC - } - } - let mut reader = BadReader {}; assert!( @@ -275,7 +259,7 @@ mod tests { assert!(matches!( Snapshot::<()>::load(&mut std::io::Cursor::new(data.as_slice())), - Err(SnapshotError::Crc64(_)) + Err(SnapshotError::Crc64) )); } diff --git a/tests/integration_tests/functional/test_snapshot_basic.py b/tests/integration_tests/functional/test_snapshot_basic.py index 73028eab40b..bd9f1ec0d9b 100644 --- a/tests/integration_tests/functional/test_snapshot_basic.py +++ b/tests/integration_tests/functional/test_snapshot_basic.py @@ -237,7 +237,7 @@ def test_load_snapshot_failure_handling(uvm_plain): # Load the snapshot with pytest.raises( - RuntimeError, match="Snapshot file is not long enough to even contain the CRC" + RuntimeError, match="An error occured during bincode decoding: UnexpectedEnd" ): vm.api.snapshot_load.put(mem_file_path=jailed_mem, snapshot_path=jailed_vmstate)