Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 2 additions & 71 deletions src/vmm/src/snapshot/crc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,10 @@
//! Implements readers and writers that compute the CRC64 checksum of the bytes
//! read/written.

use std::io::{Read, Write};
use std::io::Write;

use crc64::crc64;

/// Computes the CRC64 checksum of the read bytes.
///
/// ```
/// use std::io::Read;
///
/// use vmm::snapshot::crc::CRC64Reader;
///
/// let buf = vec![1, 2, 3, 4, 5];
/// let mut read_buf = Vec::new();
/// let mut slice = buf.as_slice();
///
/// // Create a reader from a slice.
/// let mut crc_reader = CRC64Reader::new(&mut slice);
///
/// let count = crc_reader.read_to_end(&mut read_buf).unwrap();
/// assert_eq!(crc_reader.checksum(), 0xFB04_60DE_0638_3654);
/// assert_eq!(read_buf, buf);
/// ```
#[derive(Debug)]
pub struct CRC64Reader<T> {
/// The underlying raw reader. Using this directly will bypass CRC computation!
pub reader: T,
crc64: u64,
}

impl<T> CRC64Reader<T>
where
T: Read,
{
/// Create a new reader.
pub fn new(reader: T) -> Self {
CRC64Reader { crc64: 0, reader }
}
/// Returns the current checksum value.
pub fn checksum(&self) -> u64 {
self.crc64
}
}

impl<T> Read for CRC64Reader<T>
where
T: Read,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let bytes_read = self.reader.read(buf)?;
self.crc64 = crc64(self.crc64, &buf[..bytes_read]);
Ok(bytes_read)
}
}

/// Computes the CRC64 checksum of the written bytes.
///
/// ```
Expand Down Expand Up @@ -115,17 +65,10 @@ where

#[cfg(test)]
mod tests {
use super::{CRC64Reader, CRC64Writer, Read, Write};
use super::{CRC64Writer, Write};

#[test]
fn test_crc_new() {
let buf = vec![1; 5];
let mut slice = buf.as_slice();
let crc_reader = CRC64Reader::new(&mut slice);
assert_eq!(crc_reader.crc64, 0);
assert_eq!(crc_reader.reader, &[1; 5]);
assert_eq!(crc_reader.checksum(), 0);

let mut buf = vec![0; 5];
let mut slice = buf.as_mut_slice();
let crc_writer = CRC64Writer::new(&mut slice);
Expand All @@ -134,18 +77,6 @@ mod tests {
assert_eq!(crc_writer.checksum(), 0);
}

#[test]
fn test_crc_read() {
let buf = vec![1, 2, 3, 4, 5];
let mut read_buf = vec![0; 16];

let mut slice = buf.as_slice();
let mut crc_reader = CRC64Reader::new(&mut slice);
crc_reader.read_to_end(&mut read_buf).unwrap();
assert_eq!(crc_reader.checksum(), 0xFB04_60DE_0638_3654);
assert_eq!(crc_reader.checksum(), crc_reader.crc64);
}

#[test]
fn test_crc_write() {
let mut buf = vec![0; 16];
Expand Down
58 changes: 21 additions & 37 deletions src/vmm/src/snapshot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@
pub mod crc;
mod persist;
use std::fmt::Debug;
use std::io::{Read, Seek, SeekFrom, Write};
use std::io::{Read, Write};

use bincode::config;
use bincode::config::{Configuration, Fixint, Limit, LittleEndian};
use bincode::error::{DecodeError, EncodeError};
use crc64::crc64;
use semver::Version;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use crate::persist::SNAPSHOT_VERSION;
use crate::snapshot::crc::{CRC64Reader, CRC64Writer};
use crate::snapshot::crc::CRC64Writer;
pub use crate::snapshot::persist::Persist;
use crate::utils::{mib_to_bytes, u64_to_usize};
use crate::utils::mib_to_bytes;

#[cfg(target_arch = "x86_64")]
const SNAPSHOT_MAGIC_ID: u64 = 0x0710_1984_8664_0000u64;
Expand All @@ -58,14 +59,12 @@ const SNAPSHOT_MAGIC_ID: u64 = 0x0710_1984_AAAA_0000u64;
/// Error definitions for the Snapshot API.
#[derive(Debug, thiserror::Error, displaydoc::Display)]
pub enum SnapshotError {
/// CRC64 validation failed: {0}
Crc64(u64),
/// CRC64 validation failed
Crc64,
/// Invalid data version: {0}
InvalidFormatVersion(Version),
/// Magic value does not match arch: {0}
InvalidMagic(u64),
/// Snapshot file is not long enough to even contain the CRC
TooShort,
/// An error occured during bincode encoding: {0}
Encode(#[from] EncodeError),
/// An error occured during bincode decoding: {0}
Expand Down Expand Up @@ -152,24 +151,18 @@ impl<Data: DeserializeOwned> Snapshot<Data> {

/// Loads a snapshot from the given [`Read`] instance, performing all validations
/// (CRC, snapshot magic value, snapshot version).
pub fn load<R: Read + Seek>(reader: &mut R) -> Result<Self, SnapshotError> {
let snapshot_size = reader.seek(SeekFrom::End(0))?;
reader.seek(SeekFrom::Start(0))?;
// dont read the CRC yet.
let mut buf = vec![
0;
u64_to_usize(snapshot_size)
.checked_sub(size_of::<u64>())
.ok_or(SnapshotError::TooShort)?
];
let mut crc_reader = CRC64Reader::new(reader);
crc_reader.read_exact(buf.as_mut_slice())?;
pub fn load<R: Read>(reader: &mut R) -> Result<Self, SnapshotError> {
// read_to_end internally right-sizes the buffer, so no reallocations due to growing buffers
// will happen.
let mut buf = Vec::new();
reader.read_to_end(&mut buf)?;
let snapshot = Self::load_without_crc_check(buf.as_slice())?;
let computed_checksum = crc_reader.checksum();
let stored_checksum: u64 =
bincode::serde::decode_from_std_read(&mut crc_reader.reader, BINCODE_CONFIG)?;
if computed_checksum != stored_checksum {
return Err(SnapshotError::Crc64(computed_checksum));
let computed_checksum = crc64(0, buf.as_slice());
// When we read the entire file, we also read the checksum into the buffer. The CRC has the
// property that crc(0, buf.as_slice()) == 0 iff the last 8 bytes of buf are the checksum
// of all the preceeding bytes, and this is the property we are using here.
if computed_checksum != 0 {
return Err(SnapshotError::Crc64);
}
Ok(snapshot)
}
Expand All @@ -187,19 +180,16 @@ impl<Data: Serialize> Snapshot<Data> {

#[cfg(test)]
mod tests {
use vmm_sys_util::tempfile::TempFile;

use super::*;
use crate::persist::MicrovmState;

#[test]
fn test_snapshot_restore() {
let state = MicrovmState::default();
let file = TempFile::new().unwrap();
let mut buf = Vec::new();

Snapshot::new(state).save(&mut file.as_file()).unwrap();
file.as_file().seek(SeekFrom::Start(0)).unwrap();
Snapshot::<MicrovmState>::load(&mut file.as_file()).unwrap();
Snapshot::new(state).save(&mut buf).unwrap();
Snapshot::<MicrovmState>::load(&mut buf.as_slice()).unwrap();
}

#[test]
Expand Down Expand Up @@ -228,12 +218,6 @@ mod tests {
}
}

impl Seek for BadReader {
fn seek(&mut self, _pos: SeekFrom) -> std::io::Result<u64> {
Ok(9) // needs to be long enough to prevent to have a CRC
}
}

let mut reader = BadReader {};

assert!(
Expand Down Expand Up @@ -275,7 +259,7 @@ mod tests {

assert!(matches!(
Snapshot::<()>::load(&mut std::io::Cursor::new(data.as_slice())),
Err(SnapshotError::Crc64(_))
Err(SnapshotError::Crc64)
));
}

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/functional/test_snapshot_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_load_snapshot_failure_handling(uvm_plain):

# Load the snapshot
with pytest.raises(
RuntimeError, match="Snapshot file is not long enough to even contain the CRC"
RuntimeError, match="An error occured during bincode decoding: UnexpectedEnd"
):
vm.api.snapshot_load.put(mem_file_path=jailed_mem, snapshot_path=jailed_vmstate)

Expand Down