Skip to content

Commit 6f7fce7

Browse files
committed
fix(bitfield): prevent construction of out-of-range bitfields
1. Produce an error when collecting a bitfield with an invalid bit, or attempting to set an invalid bit. 2. Make sensitive functions crate-private.
1 parent 292b7ff commit 6f7fce7

File tree

3 files changed

+170
-44
lines changed

3 files changed

+170
-44
lines changed

ipld/bitfield/src/iter/mod.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ where
133133
I: Iterator<Item = Range<u64>>,
134134
{
135135
/// Creates a new `Ranges` instance.
136+
///
137+
/// WARNING: This is asserting that the underlying iterator obeys the `RangeIterator`
138+
/// constraints. Using this incorrectly could lead to panics, etc.
136139
pub fn new<II>(iter: II) -> Self
137140
where
138141
II: IntoIterator<IntoIter = I, Item = Range<u64>>,
@@ -155,16 +158,21 @@ where
155158
impl<I> RangeIterator for Ranges<I> where I: Iterator<Item = Range<u64>> {}
156159

157160
/// Returns a `RangeIterator` whose ranges contain the values from the provided iterator.
158-
/// The values need to be in ascending order — if not, the returned iterator may not satisfy
159-
/// all `RangeIterator` requirements.
160-
pub fn ranges_from_bits(bits: impl IntoIterator<Item = u64>) -> impl RangeIterator {
161+
/// The values need to be in strictly ascending order (no duplicates) and may not include u64::MAX. Otherwise, the iterator
162+
/// will panic.
163+
pub(crate) fn ranges_from_bits(bits: impl IntoIterator<Item = u64>) -> impl RangeIterator {
161164
let mut iter = bits.into_iter().peekable();
162165

163166
Ranges::new(iter::from_fn(move || {
164167
let start = iter.next()?;
165-
let mut end = start + 1;
166-
while iter.peek() == Some(&end) {
167-
end += 1;
168+
let mut end = start.checked_add(1).expect("bitfield overflow");
169+
while let Some(&next) = iter.peek() {
170+
if next < end {
171+
panic!("out of order bitfield")
172+
} else if next > end {
173+
break;
174+
}
175+
end = end.checked_add(1).expect("bitfield overflow");
168176
iter.next();
169177
}
170178
Some(start..end)

ipld/bitfield/src/lib.rs

Lines changed: 108 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,36 @@
11
// Copyright 2019-2022 ChainSafe Systems
22
// SPDX-License-Identifier: Apache-2.0, MIT
33

4+
// disable this lint because it can actually cause performance regressions, and usually leads to
5+
// hard to read code.
6+
#![allow(clippy::comparison_chain)]
7+
48
pub mod iter;
59
mod range;
610
mod rleplus;
711
mod unvalidated;
812

913
use std::collections::BTreeSet;
10-
use std::iter::FromIterator;
1114
use std::ops::{
1215
BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Range, Sub, SubAssign,
1316
};
1417

1518
use iter::{ranges_from_bits, RangeIterator};
1619
pub(crate) use range::RangeSize;
1720
pub use rleplus::Error;
21+
use thiserror::Error;
1822
pub use unvalidated::{UnvalidatedBitField, Validate};
1923

24+
#[derive(Clone, Error, Debug)]
25+
#[error("bitfields may not include u64::MAX")]
26+
pub struct OutOfRangeError;
27+
28+
impl From<OutOfRangeError> for Error {
29+
fn from(_: OutOfRangeError) -> Self {
30+
Error::RLEOverflow
31+
}
32+
}
33+
2034
/// A bit field with buffered insertion/removal that serializes to/from RLE+. Similar to
2135
/// `HashSet<u64>`, but more memory-efficient when long runs of 1s and 0s are present.
2236
#[derive(Debug, Default, Clone)]
@@ -35,22 +49,79 @@ impl PartialEq for BitField {
3549
}
3650
}
3751

38-
impl FromIterator<u64> for BitField {
39-
fn from_iter<I: IntoIterator<Item = u64>>(iter: I) -> Self {
40-
let mut vec: Vec<_> = iter.into_iter().collect();
41-
if vec.is_empty() {
42-
Self::new()
52+
/// Possibly a valid bitfield, or an out of bounds error. Ideally we'd just use a result, but we
53+
/// can't implement [`FromIterator`] on a [`Result`] due to coherence.
54+
///
55+
/// You probably want to call [`BitField::try_from_bits`] instead of using this directly.
56+
#[doc(hidden)]
57+
pub enum MaybeBitField {
58+
/// A valid bitfield.
59+
Ok(BitField),
60+
/// Out of bounds.
61+
OutOfBounds,
62+
}
63+
64+
impl MaybeBitField {
65+
pub fn unwrap(self) -> BitField {
66+
use MaybeBitField::*;
67+
match self {
68+
Ok(bf) => bf,
69+
OutOfBounds => panic!("bitfield bit out of bounds"),
70+
}
71+
}
72+
73+
pub fn expect(self, message: &str) -> BitField {
74+
use MaybeBitField::*;
75+
match self {
76+
Ok(bf) => bf,
77+
OutOfBounds => panic!("{}", message),
78+
}
79+
}
80+
}
81+
82+
impl TryFrom<MaybeBitField> for BitField {
83+
type Error = OutOfRangeError;
84+
85+
fn try_from(value: MaybeBitField) -> Result<Self, Self::Error> {
86+
match value {
87+
MaybeBitField::Ok(bf) => Ok(bf),
88+
MaybeBitField::OutOfBounds => Err(OutOfRangeError),
89+
}
90+
}
91+
}
92+
93+
impl FromIterator<bool> for MaybeBitField {
94+
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> MaybeBitField {
95+
let mut iter = iter.into_iter().fuse();
96+
let bits = (0u64..u64::MAX)
97+
.zip(&mut iter)
98+
.filter(|&(_, b)| b)
99+
.map(|(i, _)| i);
100+
let bf = BitField::from_ranges(ranges_from_bits(bits));
101+
102+
// Now, if we have remaining bits, raise an error. Otherwise, we're good.
103+
if iter.next().is_some() {
104+
MaybeBitField::OutOfBounds
43105
} else {
44-
vec.sort_unstable();
45-
Self::from_ranges(ranges_from_bits(vec))
106+
MaybeBitField::Ok(bf)
46107
}
47108
}
48109
}
49110

50-
impl FromIterator<bool> for BitField {
51-
fn from_iter<I: IntoIterator<Item = bool>>(iter: I) -> Self {
52-
let bits = (0u64..).zip(iter).filter(|&(_, b)| b).map(|(i, _)| i);
53-
Self::from_ranges(ranges_from_bits(bits))
111+
impl FromIterator<u64> for MaybeBitField {
112+
fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> MaybeBitField {
113+
let mut vec: Vec<_> = iter.into_iter().collect();
114+
if vec.is_empty() {
115+
MaybeBitField::Ok(BitField::new())
116+
} else {
117+
vec.sort_unstable();
118+
vec.dedup();
119+
if vec.last() == Some(&u64::MAX) {
120+
MaybeBitField::OutOfBounds
121+
} else {
122+
MaybeBitField::Ok(BitField::from_ranges(ranges_from_bits(vec)))
123+
}
124+
}
54125
}
55126
}
56127

@@ -68,10 +139,33 @@ impl BitField {
68139
}
69140
}
70141

71-
/// Adds the bit at a given index to the bit field.
142+
/// Tries to create a new bitfield from a bit iterator. It fails if the resulting bitfield would
143+
/// contain values not in the range `0..u64::MAX` (non-inclusive).
144+
pub fn try_from_bits<I>(iter: I) -> Result<Self, OutOfRangeError>
145+
where
146+
I: IntoIterator,
147+
MaybeBitField: FromIterator<I::Item>,
148+
{
149+
iter.into_iter().collect::<MaybeBitField>().try_into()
150+
}
151+
152+
/// Adds the bit at a given index to the bit field, panicking if it's out of range.
153+
///
154+
/// # Panics
155+
///
156+
/// Panics if `bit` is `u64::MAX`.
72157
pub fn set(&mut self, bit: u64) {
158+
self.try_set(bit).unwrap()
159+
}
160+
161+
/// Adds the bit at a given index to the bit field, returning an error if it's out of range.
162+
pub fn try_set(&mut self, bit: u64) -> Result<(), OutOfRangeError> {
163+
if bit == u64::MAX {
164+
return Err(OutOfRangeError);
165+
}
73166
self.unset.remove(&bit);
74167
self.set.insert(bit);
168+
Ok(())
75169
}
76170

77171
/// Removes the bit at a given index from the bit field.
@@ -328,7 +422,7 @@ macro_rules! bitfield {
328422
std::iter::once($head != 0_u32).chain(bitfield!(@iter $($tail),*))
329423
};
330424
($($val:literal),* $(,)?) => {
331-
bitfield!(@iter $($val),*).collect::<$crate::BitField>()
425+
bitfield!(@iter $($val),*).collect::<$crate::MaybeBitField>().unwrap()
332426
};
333427
}
334428

ipld/bitfield/tests/bitfield_tests.rs

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// SPDX-License-Identifier: Apache-2.0, MIT
33

44
use std::collections::HashSet;
5-
use std::iter::FromIterator;
65

76
use fvm_ipld_bitfield::{bitfield, BitField};
87
use fvm_shared::encoding;
@@ -17,7 +16,7 @@ fn random_indices(range: u64, seed: u64) -> Vec<u64> {
1716
#[test]
1817
fn bitfield_slice() {
1918
let vals = random_indices(10000, 2);
20-
let bf: BitField = vals.iter().copied().collect();
19+
let bf = BitField::try_from_bits(vals.iter().copied()).unwrap();
2120

2221
let slice = bf.slice(600, 500).unwrap();
2322
let out_vals: Vec<_> = slice.iter().collect();
@@ -38,7 +37,7 @@ fn bitfield_slice_small() {
3837
let vals = [1, 5, 6, 7, 10, 11, 12, 15];
3938

4039
let test_permutations = |start: usize, count: usize| {
41-
let bf: BitField = vals.iter().copied().collect();
40+
let bf = BitField::try_from_bits(vals.iter().copied()).unwrap();
4241
let sl = bf.slice(start as u64, count as u64).unwrap();
4342
let exp = &vals[start..start + count];
4443
let out: Vec<_> = sl.iter().collect();
@@ -56,8 +55,8 @@ fn set_up_test_bitfields() -> (Vec<u64>, Vec<u64>, BitField, BitField) {
5655
let a = random_indices(100, 1);
5756
let b = random_indices(100, 2);
5857

59-
let bf_a: BitField = a.iter().copied().collect();
60-
let bf_b: BitField = b.iter().copied().collect();
58+
let bf_a = BitField::try_from_bits(a.iter().copied()).unwrap();
59+
let bf_b = BitField::try_from_bits(b.iter().copied()).unwrap();
6160

6261
(a, b, bf_a, bf_b)
6362
}
@@ -101,50 +100,62 @@ fn bitfield_difference() {
101100
// Ported test from go impl (specs-actors)
102101
#[test]
103102
fn subtract_more() {
104-
let have = BitField::from_iter(vec![5, 6, 8, 10, 11, 13, 14, 17]);
105-
let s1 = &BitField::from_iter(vec![5, 6]) - &have;
106-
let s2 = &BitField::from_iter(vec![8, 10]) - &have;
107-
let s3 = &BitField::from_iter(vec![11, 13]) - &have;
108-
let s4 = &BitField::from_iter(vec![14, 17]) - &have;
103+
let have = BitField::try_from_bits(vec![5, 6, 8, 10, 11, 13, 14, 17]).unwrap();
104+
let s1 = &BitField::try_from_bits(vec![5, 6]).unwrap() - &have;
105+
let s2 = &BitField::try_from_bits(vec![8, 10]).unwrap() - &have;
106+
let s3 = &BitField::try_from_bits(vec![11, 13]).unwrap() - &have;
107+
let s4 = &BitField::try_from_bits(vec![14, 17]).unwrap() - &have;
109108

110109
let u = BitField::union(&[s1, s2, s3, s4]);
111110
assert_eq!(u.len(), 0);
112111
}
113112

114113
#[test]
115114
fn contains_any() {
116-
assert!(!BitField::from_iter(vec![0, 4]).contains_any(&BitField::from_iter(vec![1, 3, 5])));
115+
assert!(!BitField::try_from_bits(vec![0, 4])
116+
.unwrap()
117+
.contains_any(&BitField::try_from_bits(vec![1, 3, 5]).unwrap()));
117118

118-
assert!(BitField::from_iter(vec![0, 2, 5, 6]).contains_any(&BitField::from_iter(vec![1, 3, 5])));
119+
assert!(BitField::try_from_bits(vec![0, 2, 5, 6])
120+
.unwrap()
121+
.contains_any(&BitField::try_from_bits(vec![1, 3, 5]).unwrap()));
119122

120-
assert!(BitField::from_iter(vec![1, 2, 3]).contains_any(&BitField::from_iter(vec![1, 2, 3])));
123+
assert!(BitField::try_from_bits(vec![1, 2, 3])
124+
.unwrap()
125+
.contains_any(&BitField::try_from_bits(vec![1, 2, 3]).unwrap()));
121126
}
122127

123128
#[test]
124129
fn contains_all() {
125-
assert!(
126-
!BitField::from_iter(vec![0, 2, 4]).contains_all(&BitField::from_iter(vec![0, 2, 4, 5]))
127-
);
130+
assert!(!BitField::try_from_bits(vec![0, 2, 4])
131+
.unwrap()
132+
.contains_all(&BitField::try_from_bits(vec![0, 2, 4, 5]).unwrap()));
128133

129-
assert!(BitField::from_iter(vec![0, 2, 4, 5]).contains_all(&BitField::from_iter(vec![0, 2, 4])));
134+
assert!(BitField::try_from_bits(vec![0, 2, 4, 5])
135+
.unwrap()
136+
.contains_all(&BitField::try_from_bits(vec![0, 2, 4]).unwrap()));
130137

131-
assert!(BitField::from_iter(vec![1, 2, 3]).contains_all(&BitField::from_iter(vec![1, 2, 3])));
138+
assert!(BitField::try_from_bits(vec![1, 2, 3])
139+
.unwrap()
140+
.contains_all(&BitField::try_from_bits(vec![1, 2, 3]).unwrap()));
132141
}
133142

134143
#[test]
135144
fn bit_ops() {
136-
let a = &BitField::from_iter(vec![1, 2, 3]) & &BitField::from_iter(vec![1, 3, 4]);
145+
let a = &BitField::try_from_bits(vec![1, 2, 3]).unwrap()
146+
& &BitField::try_from_bits(vec![1, 3, 4]).unwrap();
137147
assert_eq!(a.iter().collect::<Vec<_>>(), &[1, 3]);
138148

139-
let mut a = BitField::from_iter(vec![1, 2, 3]);
140-
a &= &BitField::from_iter(vec![1, 3, 4]);
149+
let mut a = BitField::try_from_bits(vec![1, 2, 3]).unwrap();
150+
a &= &BitField::try_from_bits(vec![1, 3, 4]).unwrap();
141151
assert_eq!(a.iter().collect::<Vec<_>>(), &[1, 3]);
142152

143-
let a = &BitField::from_iter(vec![1, 2, 3]) | &BitField::from_iter(vec![1, 3, 4]);
153+
let a = &BitField::try_from_bits(vec![1, 2, 3]).unwrap()
154+
| &BitField::try_from_bits(vec![1, 3, 4]).unwrap();
144155
assert_eq!(a.iter().collect::<Vec<_>>(), &[1, 2, 3, 4]);
145156

146-
let mut a = BitField::from_iter(vec![1, 2, 3]);
147-
a |= &BitField::from_iter(vec![1, 3, 4]);
157+
let mut a = BitField::try_from_bits(vec![1, 2, 3]).unwrap();
158+
a |= &BitField::try_from_bits(vec![1, 3, 4]).unwrap();
148159
assert_eq!(a.iter().collect::<Vec<_>>(), &[1, 2, 3, 4]);
149160
}
150161

@@ -209,3 +220,16 @@ fn padding() {
209220
let deserialized: BitField = encoding::from_slice(&cbor).unwrap();
210221
assert_eq!(deserialized, bf);
211222
}
223+
224+
#[test]
225+
fn exceeds_bitfield_range() {
226+
let mut bf = BitField::new();
227+
bf.try_set(u64::MAX)
228+
.expect_err("expected setting u64::MAX to fail");
229+
bf.try_set(u64::MAX - 1)
230+
.expect("expected setting u64::MAX-1 to succeed");
231+
BitField::try_from_bits([0, 1, 4, 99, u64::MAX])
232+
.expect_err("expected setting u64::MAX to fail");
233+
BitField::try_from_bits([0, 1, 4, 99, u64::MAX - 1])
234+
.expect("expected setting u64::MAX-1 to succeed");
235+
}

0 commit comments

Comments
 (0)