From 846a97d5625902c616cef0de859d104861a4cdc7 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Tue, 8 Jul 2025 11:18:21 +0100 Subject: [PATCH 1/2] fix(pci): do not panic on invalid BDF during deserialization Correctly handle invalid Bdf by returning an error to the deserializer. This bug was caught by the fuzzer. Signed-off-by: Riccardo Mancini --- src/pci/src/lib.rs | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/pci/src/lib.rs b/src/pci/src/lib.rs index f1dec5b126a..f2cf82f81fa 100644 --- a/src/pci/src/lib.rs +++ b/src/pci/src/lib.rs @@ -71,7 +71,7 @@ impl Visitor<'_> for PciBdfVisitor { where E: serde::de::Error, { - Ok(v.into()) + PciBdf::from_str(v).map_err(serde::de::Error::custom) } } @@ -176,24 +176,31 @@ impl Display for PciBdf { } } +/// Errors associated with parsing a BDF string. +#[derive(Debug, thiserror::Error, displaydoc::Display)] +pub enum PciBdfParseError { + /// Unable to parse bus/device/function number hex: {0} + InvalidHex(#[from] ParseIntError), + /// Invalid format: {0} (expected format: 0000:00:00.0) + InvalidFormat(String), +} + impl FromStr for PciBdf { - type Err = ParseIntError; + type Err = PciBdfParseError; fn from_str(s: &str) -> Result { let items: Vec<&str> = s.split('.').collect(); - assert_eq!(items.len(), 2); + if items.len() != 2 { + return Err(PciBdfParseError::InvalidFormat(s.to_string())); + } let function = u8::from_str_radix(items[1], 16)?; let items: Vec<&str> = items[0].split(':').collect(); - assert_eq!(items.len(), 3); + if items.len() != 3 { + return Err(PciBdfParseError::InvalidFormat(s.to_string())); + } let segment = u16::from_str_radix(items[0], 16)?; let bus = u8::from_str_radix(items[1], 16)?; let device = u8::from_str_radix(items[2], 16)?; Ok(PciBdf::new(segment, bus, device, function)) } } - -impl From<&str> for PciBdf { - fn from(bdf: &str) -> Self { - Self::from_str(bdf).unwrap() - } -} From 2187037dcd7306d42ff2ac5ec395ae5f5a26a0b7 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Tue, 8 Jul 2025 12:15:22 +0100 Subject: [PATCH 2/2] test(pci): add unit tests for Bdf Add some unit tests to cover PciBdf parsing, conversion, and (de)serialization. Signed-off-by: Riccardo Mancini --- Cargo.lock | 10 ++++ src/pci/Cargo.toml | 3 + src/pci/src/lib.rs | 145 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index e1b6f10897c..59b2c726a69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1057,6 +1057,7 @@ dependencies = [ "libc", "log", "serde", + "serde_test", "thiserror 2.0.12", "vm-allocator", "vm-device", @@ -1338,6 +1339,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_test" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f901ee573cab6b3060453d2d5f0bae4e6d628c23c0a962ff9b5f1d7c8d4f1ed" +dependencies = [ + "serde", +] + [[package]] name = "shlex" version = "1.3.0" diff --git a/src/pci/Cargo.toml b/src/pci/Cargo.toml index d179854f391..a7ef102acfb 100644 --- a/src/pci/Cargo.toml +++ b/src/pci/Cargo.toml @@ -24,3 +24,6 @@ vm-memory = { version = "0.16.1", features = [ "backend-mmap", "backend-bitmap", ] } + +[dev-dependencies] +serde_test = "1.0.177" diff --git a/src/pci/src/lib.rs b/src/pci/src/lib.rs index f2cf82f81fa..1b9a3a99f76 100644 --- a/src/pci/src/lib.rs +++ b/src/pci/src/lib.rs @@ -204,3 +204,148 @@ impl FromStr for PciBdf { Ok(PciBdf::new(segment, bus, device, function)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pci_bdf_new() { + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + assert_eq!(bdf.segment(), 0x1234); + assert_eq!(bdf.bus(), 0x56); + assert_eq!(bdf.device(), 0x1f); + assert_eq!(bdf.function(), 0x7); + } + + #[test] + fn test_pci_bdf_from_u32() { + let bdf = PciBdf::from(0x12345678); + assert_eq!(bdf.segment(), 0x1234); + assert_eq!(bdf.bus(), 0x56); + assert_eq!(bdf.device(), 0x0f); + assert_eq!(bdf.function(), 0x0); + } + + #[test] + fn test_pci_bdf_to_u32() { + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + let val: u32 = bdf.into(); + assert_eq!(val, 0x123456ff); + } + + #[test] + fn test_pci_bdf_to_u16() { + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + let val: u16 = bdf.into(); + assert_eq!(val, 0x56ff); + } + + #[test] + fn test_pci_bdf_from_str_valid() { + let bdf = PciBdf::from_str("1234:56:1f.7").unwrap(); + assert_eq!(bdf.segment(), 0x1234); + assert_eq!(bdf.bus(), 0x56); + assert_eq!(bdf.device(), 0x1f); + assert_eq!(bdf.function(), 0x7); + } + + #[test] + fn test_pci_bdf_from_str_zero() { + let bdf = PciBdf::from_str("0000:00:00.0").unwrap(); + assert_eq!(bdf.segment(), 0); + assert_eq!(bdf.bus(), 0); + assert_eq!(bdf.device(), 0); + assert_eq!(bdf.function(), 0); + } + + #[test] + fn test_pci_bdf_from_str_invalid_format() { + assert!(matches!( + PciBdf::from_str("invalid"), + Err(PciBdfParseError::InvalidFormat(_)) + )); + assert!(matches!( + PciBdf::from_str("1234:56"), + Err(PciBdfParseError::InvalidFormat(_)) + )); + assert!(matches!( + PciBdf::from_str("1234:56:78:9a.b"), + Err(PciBdfParseError::InvalidFormat(_)) + )); + } + + #[test] + fn test_pci_bdf_from_str_invalid_hex() { + assert!(matches!( + PciBdf::from_str("xxxx:00:00.0"), + Err(PciBdfParseError::InvalidHex(_)) + )); + assert!(matches!( + PciBdf::from_str("0000:xx:00.0"), + Err(PciBdfParseError::InvalidHex(_)) + )); + assert!(matches!( + PciBdf::from_str("0000:00:xx.0"), + Err(PciBdfParseError::InvalidHex(_)) + )); + assert!(matches!( + PciBdf::from_str("0000:00:00.x"), + Err(PciBdfParseError::InvalidHex(_)) + )); + } + + #[test] + fn test_pci_bdf_display() { + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + assert_eq!(format!("{}", bdf), "1234:56:1f.7"); + } + + #[test] + fn test_pci_bdf_debug() { + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + assert_eq!(format!("{:?}", bdf), "1234:56:1f.7"); + } + + #[test] + fn test_pci_bdf_partial_eq() { + let bdf1 = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + let bdf2 = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + let bdf3 = PciBdf::new(0x1234, 0x56, 0x1f, 0x6); + assert_eq!(bdf1, bdf2); + assert_ne!(bdf1, bdf3); + } + + #[test] + fn test_pci_bdf_partial_ord() { + let bdf1 = PciBdf::new(0x1234, 0x56, 0x1f, 0x6); + let bdf2 = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + assert!(bdf1 < bdf2); + } + + #[test] + fn test_pci_bdf_deserialize_ok() { + // Test deserializer + let visitor = PciBdfVisitor; + let result = visitor + .visit_str::("1234:56:1f.7") + .unwrap(); + assert_eq!(result, PciBdf::new(0x1234, 0x56, 0x1f, 0x7)); + } + + #[test] + fn test_pci_bdf_deserialize_invalid() { + // Test deserializer with invalid input returns error + let visitor = PciBdfVisitor; + assert!(visitor + .visit_str::("invalid") + .is_err()); + } + + #[test] + fn test_pci_bdf_serialize() { + // Test serializer using serde_test + let bdf = PciBdf::new(0x1234, 0x56, 0x1f, 0x7); + serde_test::assert_tokens(&bdf, &[serde_test::Token::Str("1234:56:1f.7")]); + } +}