Skip to content

Commit 33e8556

Browse files
committed
Separate ipma validation into separate function
Also, add compile-time checking against potential overflow.
1 parent aaf14d5 commit 33e8556

File tree

1 file changed

+157
-27
lines changed

1 file changed

+157
-27
lines changed

mp4parse/src/lib.rs

Lines changed: 157 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,21 +1629,136 @@ pub struct AssociatedProperty {
16291629
pub property: ItemProperty,
16301630
}
16311631

1632-
const MAX_IPMA_ASSOCIATION_COUNT: u64 = u8::MAX as u64;
1632+
/// An upper bound which can be used to check overflow at compile time
1633+
trait UpperBounded {
1634+
const MAX: u64;
1635+
}
16331636

1634-
/// Parse an ItemPropertyAssociation box
1635-
/// See HEIF (ISO 23008-12:2017) § 9.3.1
1636-
fn read_ipma<T: Read>(
1637-
src: &mut BMFFBox<T>,
1638-
(version, flags): (u8, u32),
1639-
) -> Result<TryVec<Association>> {
1640-
// Static analysis shows that none of this unchecked arithmetic can fail
1641-
let entry_count: u64 = be_u32(src)?.into();
1642-
let min_entry_bytes: u64 = 1 /* association_count */ + if version == 0 { 2 } else { 4 };
1643-
let total_non_association_bytes = entry_count * min_entry_bytes;
1644-
let total_association_bytes;
1637+
/// Implement type $name as a newtype wrapper around an unsigned int which
1638+
/// implements the UpperBounded trait.
1639+
macro_rules! impl_bounded {
1640+
( $name:ident, $inner:ty ) => {
1641+
#[derive(Clone, Copy)]
1642+
pub struct $name($inner);
1643+
1644+
impl $name {
1645+
pub const fn new(n: $inner) -> Self {
1646+
Self(n)
1647+
}
1648+
1649+
#[allow(dead_code)]
1650+
pub fn get(self) -> $inner {
1651+
self.0
1652+
}
1653+
}
1654+
1655+
impl UpperBounded for $name {
1656+
const MAX: u64 = <$inner>::MAX as u64;
1657+
}
1658+
};
1659+
}
1660+
1661+
/// Implement type $name as a type representing the product of two unsigned ints
1662+
/// which implements the UpperBounded trait.
1663+
macro_rules! impl_bounded_product {
1664+
( $name:ident, $multiplier:ty, $multiplicand:ty, $inner:ty) => {
1665+
#[derive(Clone, Copy)]
1666+
pub struct $name($inner);
1667+
1668+
impl $name {
1669+
pub fn new(value: $inner) -> Self {
1670+
assert!(<$inner>::from(value) <= Self::MAX);
1671+
Self(value)
1672+
}
1673+
1674+
pub fn get(self) -> $inner {
1675+
self.0
1676+
}
1677+
}
1678+
1679+
impl UpperBounded for $name {
1680+
const MAX: u64 = <$multiplier>::MAX * <$multiplicand>::MAX;
1681+
}
1682+
};
1683+
}
1684+
1685+
mod bounded_uints {
1686+
use UpperBounded;
1687+
1688+
impl_bounded!(U8, u8);
1689+
impl_bounded!(U16, u16);
1690+
impl_bounded!(U32, u32);
1691+
impl_bounded!(U64, u64);
1692+
1693+
impl_bounded_product!(U32MulU8, U32, U8, u64);
1694+
impl_bounded_product!(U32MulU16, U32, U16, u64);
1695+
1696+
impl UpperBounded for std::num::NonZeroU8 {
1697+
const MAX: u64 = u8::MAX as u64;
1698+
}
1699+
}
16451700

1646-
if let Some(difference) = src.bytes_left().checked_sub(total_non_association_bytes) {
1701+
use bounded_uints::*;
1702+
1703+
/// Implement the multiplication operator for $lhs * $rhs giving $output, which
1704+
/// is internally represented as $inner. The operation is statically checked
1705+
/// to ensure the product won't overflow $inner, nor exceed <$output>::MAX.
1706+
macro_rules! impl_mul {
1707+
( ($lhs:ty , $rhs:ty) => ($output:ty, $inner:ty) ) => {
1708+
impl std::ops::Mul<$rhs> for $lhs {
1709+
type Output = $output;
1710+
1711+
fn mul(self, rhs: $rhs) -> Self::Output {
1712+
static_assertions::const_assert!(<$output>::MAX <= <$inner>::MAX as u64);
1713+
static_assertions::const_assert!(<$lhs>::MAX * <$rhs>::MAX <= <$output>::MAX);
1714+
1715+
let lhs: $inner = self.get().into();
1716+
let rhs: $inner = rhs.get().into();
1717+
Self::Output::new(lhs.checked_mul(rhs).expect("infallible"))
1718+
}
1719+
}
1720+
};
1721+
}
1722+
1723+
impl_mul!((U8, std::num::NonZeroU8) => (U16, u16));
1724+
impl_mul!((U32, std::num::NonZeroU8) => (U32MulU8, u64));
1725+
impl_mul!((U32, U16) => (U32MulU16, u64));
1726+
1727+
impl std::ops::Add<U32MulU16> for U32MulU8 {
1728+
type Output = U64;
1729+
1730+
fn add(self, rhs: U32MulU16) -> Self::Output {
1731+
static_assertions::const_assert!(U32MulU8::MAX + U32MulU16::MAX < U64::MAX);
1732+
let lhs: u64 = self.get().into();
1733+
let rhs: u64 = rhs.get().into();
1734+
Self::Output::new(lhs.checked_add(rhs).expect("infallible"))
1735+
}
1736+
}
1737+
1738+
const MAX_IPMA_ASSOCIATION_COUNT: U8 = U8::new(u8::MAX);
1739+
1740+
/// After reading only the `entry_count` field of an ipma box, we can check its
1741+
/// basic validity and calculate (assuming validity) the number of associations
1742+
/// which will be contained (allowing preallocation of the storage).
1743+
/// All the arithmetic is compile-time verified to not overflow via supporting
1744+
/// types implementing the UpperBounded trait. Types are declared explicitly to
1745+
/// show there isn't any accidental inference to primitive types.
1746+
///
1747+
/// See HEIF (ISO 23008-12:2017) § 9.3.1
1748+
fn calculate_ipma_total_associations(
1749+
version: u8,
1750+
bytes_left: u64,
1751+
entry_count: U32,
1752+
num_association_bytes: std::num::NonZeroU8,
1753+
) -> Result<usize> {
1754+
let min_entry_bytes =
1755+
std::num::NonZeroU8::new(1 /* association_count */ + if version == 0 { 2 } else { 4 })
1756+
.unwrap();
1757+
1758+
let total_non_association_bytes: U32MulU8 = entry_count * min_entry_bytes;
1759+
let total_association_bytes: u64;
1760+
1761+
if let Some(difference) = bytes_left.checked_sub(total_non_association_bytes.get()) {
16471762
// All the storage for the `essential` and `property_index` parts (assuming a valid ipma box size)
16481763
total_association_bytes = difference;
16491764
} else {
@@ -1652,25 +1767,38 @@ fn read_ipma<T: Read>(
16521767
));
16531768
}
16541769

1655-
let num_association_bytes = if flags & 1 == 1 { 2 } else { 1 };
1656-
1657-
// total_association_bytes must be a multiple of num_association_bytes
1658-
if total_association_bytes % num_association_bytes != 0 {
1659-
return Err(Error::InvalidData("ipma box has invalid size"));
1660-
}
1661-
1662-
let max_association_bytes_per_entry = MAX_IPMA_ASSOCIATION_COUNT * num_association_bytes;
1663-
let max_total_association_bytes = entry_count * max_association_bytes_per_entry;
1664-
let max_bytes_left = total_non_association_bytes + max_total_association_bytes;
1770+
let max_association_bytes_per_entry: U16 = MAX_IPMA_ASSOCIATION_COUNT * num_association_bytes;
1771+
let max_total_association_bytes: U32MulU16 = entry_count * max_association_bytes_per_entry;
1772+
let max_bytes_left: U64 = total_non_association_bytes + max_total_association_bytes;
16651773

1666-
if src.bytes_left() > max_bytes_left {
1774+
if bytes_left > max_bytes_left.get() {
16671775
return Err(Error::InvalidData(
16681776
"ipma box exceeds maximum size for entry_count",
16691777
));
16701778
}
16711779

1672-
let total_associations = total_association_bytes / num_association_bytes;
1673-
let mut associations = TryVec::with_capacity(total_associations.try_into()?)?;
1780+
let total_associations: u64 = total_association_bytes / u64::from(num_association_bytes.get());
1781+
1782+
Ok(total_associations.try_into()?)
1783+
}
1784+
1785+
/// Parse an ItemPropertyAssociation box
1786+
/// See HEIF (ISO 23008-12:2017) § 9.3.1
1787+
fn read_ipma<T: Read>(
1788+
src: &mut BMFFBox<T>,
1789+
(version, flags): (u8, u32),
1790+
) -> Result<TryVec<Association>> {
1791+
let entry_count = be_u32(src)?;
1792+
let num_association_bytes =
1793+
std::num::NonZeroU8::new(if flags & 1 == 1 { 2 } else { 1 }).unwrap();
1794+
1795+
let total_associations = calculate_ipma_total_associations(
1796+
version,
1797+
src.bytes_left(),
1798+
U32::new(entry_count),
1799+
num_association_bytes,
1800+
)?;
1801+
let mut associations = TryVec::with_capacity(total_associations)?;
16741802

16751803
for _ in 0..entry_count {
16761804
let item_id = if version == 0 {
@@ -1680,7 +1808,9 @@ fn read_ipma<T: Read>(
16801808
};
16811809
let association_count = src.read_u8()?;
16821810
for _ in 0..association_count {
1683-
let association = src.take(num_association_bytes).read_into_try_vec()?;
1811+
let association = src
1812+
.take(num_association_bytes.get().into())
1813+
.read_into_try_vec()?;
16841814
let mut association = BitReader::new(association.as_slice());
16851815
let essential = association.read_bool()?;
16861816
let property_index = association.read_u16(association.remaining().try_into()?)?;

0 commit comments

Comments
 (0)