Skip to content

Commit cc1ffe8

Browse files
authored
const-oid: fix (and simplify) base 128 encoder (#1600)
This changes the base 128 decoder to calculate the length of a base 128-encoded arc and then iterates over each byte, computing the value for that byte, without any mutable state other than the position. It also refactors the unit tests and adds an example extracted from proptest failures. The new implementation passes that test.
1 parent 72b4894 commit cc1ffe8

File tree

6 files changed

+123
-109
lines changed

6 files changed

+123
-109
lines changed

const-oid/src/arcs.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ pub(crate) const ARC_MAX_SECOND: Arc = 39;
2626

2727
/// Maximum number of bytes supported in an arc.
2828
///
29-
/// Note that OIDs are LEB128 encoded (i.e. base 128), so we must consider how many bytes are
30-
/// required when each byte can only represent 7-bits of the input.
29+
/// Note that OIDs are base 128 encoded (with continuation bits), so we must consider how many bytes
30+
/// are required when each byte can only represent 7-bits of the input.
3131
const ARC_MAX_BYTES: usize = (Arc::BITS as usize).div_ceil(7);
3232

3333
/// Maximum value of the last byte in an arc.

const-oid/src/checked.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@ macro_rules! checked_add {
55
($a:expr, $b:expr) => {
66
match $a.checked_add($b) {
77
Some(n) => n,
8-
None => return Err(Error::Length),
8+
None => return Err(Error::Overflow),
9+
}
10+
};
11+
}
12+
13+
/// `const fn`-friendly checked addition helper.
14+
macro_rules! checked_sub {
15+
($a:expr, $b:expr) => {
16+
match $a.checked_sub($b) {
17+
Some(n) => n,
18+
None => return Err(Error::Overflow),
919
}
1020
};
1121
}

const-oid/src/encoder.rs

Lines changed: 40 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ enum State {
2424
/// Initial state - no arcs yet encoded.
2525
Initial,
2626

27-
/// First arc parsed.
27+
/// First arc has been supplied and stored as the wrapped [`Arc`].
2828
FirstArc(Arc),
2929

3030
/// Encoding base 128 body of the OID.
@@ -83,10 +83,7 @@ impl<const MAX_SIZE: usize> Encoder<MAX_SIZE> {
8383
self.cursor = 1;
8484
Ok(self)
8585
}
86-
State::Body => {
87-
let nbytes = base128_len(arc);
88-
self.encode_base128(arc, nbytes)
89-
}
86+
State::Body => self.encode_base128(arc),
9087
}
9188
}
9289

@@ -104,64 +101,48 @@ impl<const MAX_SIZE: usize> Encoder<MAX_SIZE> {
104101
Ok(ObjectIdentifier { ber })
105102
}
106103

107-
/// Encode a single byte of a Base 128 value.
108-
const fn encode_base128(mut self, n: u32, remaining_len: usize) -> Result<Self> {
109-
if self.cursor >= MAX_SIZE {
104+
/// Encode base 128.
105+
const fn encode_base128(mut self, arc: Arc) -> Result<Self> {
106+
let nbytes = base128_len(arc);
107+
let end_pos = checked_add!(self.cursor, nbytes);
108+
109+
if end_pos > MAX_SIZE {
110110
return Err(Error::Length);
111111
}
112112

113-
let mask = if remaining_len > 0 { 0b10000000 } else { 0 };
114-
let (hi, lo) = split_hi_bits(n);
115-
self.bytes[self.cursor] = hi | mask;
116-
self.cursor = checked_add!(self.cursor, 1);
117-
118-
match remaining_len.checked_sub(1) {
119-
Some(len) => self.encode_base128(lo, len),
120-
None => Ok(self),
113+
let mut i = 0;
114+
while i < nbytes {
115+
// TODO(tarcieri): use `?` when stable in `const fn`
116+
self.bytes[self.cursor] = match base128_byte(arc, i, nbytes) {
117+
Ok(byte) => byte,
118+
Err(e) => return Err(e),
119+
};
120+
self.cursor = checked_add!(self.cursor, 1);
121+
i = checked_add!(i, 1);
121122
}
123+
124+
Ok(self)
122125
}
123126
}
124127

125-
/// Compute the length - 1 of an arc when encoded in base 128.
128+
/// Compute the length of an arc when encoded in base 128.
126129
const fn base128_len(arc: Arc) -> usize {
127130
match arc {
128-
0..=0x7f => 0,
129-
0x80..=0x3fff => 1,
130-
0x4000..=0x1fffff => 2,
131-
0x200000..=0x1fffffff => 3,
132-
_ => 4,
131+
0..=0x7f => 1,
132+
0x80..=0x3fff => 2,
133+
0x4000..=0x1fffff => 3,
134+
0x200000..=0x1fffffff => 4,
135+
_ => 5,
133136
}
134137
}
135138

136-
/// Split the highest 7-bits of an [`Arc`] from the rest of an arc.
137-
///
138-
/// Returns: `(hi, lo)`
139-
#[inline]
140-
const fn split_hi_bits(arc: Arc) -> (u8, Arc) {
141-
if arc < 0x80 {
142-
return (arc as u8, 0);
143-
}
144-
145-
let hi_bit = match 32u32.checked_sub(arc.leading_zeros()) {
146-
Some(bit) => bit,
147-
None => unreachable!(),
148-
};
149-
150-
let hi_bit_mod7 = hi_bit % 7;
151-
let upper_bit_offset = if hi_bit > 0 && hi_bit_mod7 == 0 {
152-
7
153-
} else {
154-
hi_bit_mod7
155-
};
156-
157-
let upper_bit_pos = match hi_bit.checked_sub(upper_bit_offset) {
158-
Some(bit) => bit,
159-
None => unreachable!(),
160-
};
161-
162-
let upper_bits = arc >> upper_bit_pos;
163-
let lower_bits = arc ^ (upper_bits << upper_bit_pos);
164-
(upper_bits as u8, lower_bits)
139+
/// Compute the big endian base 128 encoding of the given [`Arc`] at the given byte.
140+
const fn base128_byte(arc: Arc, pos: usize, total: usize) -> Result<u8> {
141+
debug_assert!(pos < total);
142+
let last_byte = checked_add!(pos, 1) == total;
143+
let mask = if last_byte { 0 } else { 0b10000000 };
144+
let shift = checked_sub!(checked_sub!(total, pos), 1) * 7;
145+
Ok(((arc >> shift) & 0b1111111) as u8 | mask)
165146
}
166147

167148
#[cfg(test)]
@@ -174,9 +155,14 @@ mod tests {
174155
const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201");
175156

176157
#[test]
177-
fn split_hi_bits_with_gaps() {
178-
assert_eq!(super::split_hi_bits(0x3a00002), (0x1d, 0x2));
179-
assert_eq!(super::split_hi_bits(0x3a08000), (0x1d, 0x8000));
158+
fn base128_byte() {
159+
let example_arc = 0x44332211;
160+
assert_eq!(super::base128_len(example_arc), 5);
161+
assert_eq!(super::base128_byte(example_arc, 0, 5).unwrap(), 0b10000100);
162+
assert_eq!(super::base128_byte(example_arc, 1, 5).unwrap(), 0b10100001);
163+
assert_eq!(super::base128_byte(example_arc, 2, 5).unwrap(), 0b11001100);
164+
assert_eq!(super::base128_byte(example_arc, 3, 5).unwrap(), 0b11000100);
165+
assert_eq!(super::base128_byte(example_arc, 4, 5).unwrap(), 0b10001);
180166
}
181167

182168
#[test]

const-oid/src/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ pub enum Error {
3737
/// OID length is invalid (too short or too long).
3838
Length,
3939

40+
/// Arithmetic overflow (or underflow) errors.
41+
Overflow,
42+
4043
/// Repeated `..` characters in input data.
4144
RepeatedDot,
4245

@@ -56,6 +59,7 @@ impl Error {
5659
Error::DigitExpected { .. } => panic!("OID expected to start with digit"),
5760
Error::Empty => panic!("OID value is empty"),
5861
Error::Length => panic!("OID length invalid"),
62+
Error::Overflow => panic!("arithmetic calculation overflowed"),
5963
Error::RepeatedDot => panic!("repeated consecutive '..' characters in OID"),
6064
Error::TrailingDot => panic!("OID ends with invalid trailing '.'"),
6165
}
@@ -73,6 +77,7 @@ impl fmt::Display for Error {
7377
}
7478
Error::Empty => f.write_str("OID value is empty"),
7579
Error::Length => f.write_str("OID length invalid"),
80+
Error::Overflow => f.write_str("arithmetic calculation overflowed"),
7681
Error::RepeatedDot => f.write_str("repeated consecutive '..' characters in OID"),
7782
Error::TrailingDot => f.write_str("OID ends with invalid trailing '.'"),
7883
}

const-oid/src/parser.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl Parser {
6363
self.current_arc = match arc.checked_mul(10) {
6464
Some(arc) => match arc.checked_add(digit as Arc) {
6565
None => return Err(Error::ArcTooBig),
66-
arc => arc,
66+
Some(arc) => Some(arc),
6767
},
6868
None => return Err(Error::ArcTooBig),
6969
};

const-oid/tests/oid.rs

Lines changed: 64 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ const EXAMPLE_OID_LARGE_ARC_0: ObjectIdentifier =
2929
ObjectIdentifier::new_unwrap(crate::EXAMPLE_OID_LARGE_ARC_0_STR);
3030

3131
/// Example OID value with a large arc
32-
const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "0.9.2342.19200300.100.1.1";
33-
const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("0992268993F22C640101");
32+
const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "1.1.1.60817410.1";
33+
const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("29019D80800201");
3434
const EXAMPLE_OID_LARGE_ARC_1: ObjectIdentifier =
3535
ObjectIdentifier::new_unwrap(EXAMPLE_OID_LARGE_ARC_1_STR);
3636

@@ -45,54 +45,69 @@ pub fn oid(s: &str) -> ObjectIdentifier {
4545
ObjectIdentifier::new(s).unwrap()
4646
}
4747

48+
/// 0.9.2342.19200300.100.1.1
4849
#[test]
49-
fn from_bytes() {
50-
// 0.9.2342.19200300.100.1.1
51-
let oid0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap();
52-
assert_eq!(oid0.arc(0).unwrap(), 0);
53-
assert_eq!(oid0.arc(1).unwrap(), 9);
54-
assert_eq!(oid0.arc(2).unwrap(), 2342);
55-
assert_eq!(oid0, EXAMPLE_OID_0);
50+
fn from_bytes_oid_0() {
51+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap();
52+
assert_eq!(oid, EXAMPLE_OID_0);
53+
assert_eq!(oid.arc(0).unwrap(), 0);
54+
assert_eq!(oid.arc(1).unwrap(), 9);
55+
assert_eq!(oid.arc(2).unwrap(), 2342);
56+
}
5657

57-
// 1.2.840.10045.2.1
58-
let oid1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap();
59-
assert_eq!(oid1.arc(0).unwrap(), 1);
60-
assert_eq!(oid1.arc(1).unwrap(), 2);
61-
assert_eq!(oid1.arc(2).unwrap(), 840);
62-
assert_eq!(oid1, EXAMPLE_OID_1);
58+
/// 1.2.840.10045.2.1
59+
#[test]
60+
fn from_bytes_oid_1() {
61+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap();
62+
assert_eq!(oid, EXAMPLE_OID_1);
63+
assert_eq!(oid.arc(0).unwrap(), 1);
64+
assert_eq!(oid.arc(1).unwrap(), 2);
65+
assert_eq!(oid.arc(2).unwrap(), 840);
66+
}
6367

64-
// 2.16.840.1.101.3.4.1.42
65-
let oid2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap();
66-
assert_eq!(oid2.arc(0).unwrap(), 2);
67-
assert_eq!(oid2.arc(1).unwrap(), 16);
68-
assert_eq!(oid2.arc(2).unwrap(), 840);
69-
assert_eq!(oid2, EXAMPLE_OID_2);
68+
/// 2.16.840.1.101.3.4.1.42
69+
#[test]
70+
fn from_bytes_oid_2() {
71+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap();
72+
assert_eq!(oid, EXAMPLE_OID_2);
73+
assert_eq!(oid.arc(0).unwrap(), 2);
74+
assert_eq!(oid.arc(1).unwrap(), 16);
75+
assert_eq!(oid.arc(2).unwrap(), 840);
76+
}
7077

71-
// 1.2.16384
72-
let oid_largearc0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap();
73-
assert_eq!(oid_largearc0.arc(0).unwrap(), 1);
74-
assert_eq!(oid_largearc0.arc(1).unwrap(), 2);
75-
assert_eq!(oid_largearc0.arc(2).unwrap(), 16384);
76-
assert_eq!(oid_largearc0.arc(3), None);
77-
assert_eq!(oid_largearc0, EXAMPLE_OID_LARGE_ARC_0);
78+
/// 1.2.16384
79+
#[test]
80+
fn from_bytes_oid_largearc_0() {
81+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap();
82+
assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_0);
83+
assert_eq!(oid.arc(0).unwrap(), 1);
84+
assert_eq!(oid.arc(1).unwrap(), 2);
85+
assert_eq!(oid.arc(2).unwrap(), 16384);
86+
assert_eq!(oid.arc(3), None);
87+
}
7888

79-
// 0.9.2342.19200300.100.1.1
80-
let oid_largearc1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap();
81-
assert_eq!(oid_largearc1.arc(0).unwrap(), 0);
82-
assert_eq!(oid_largearc1.arc(1).unwrap(), 9);
83-
assert_eq!(oid_largearc1.arc(2).unwrap(), 2342);
84-
assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300);
85-
assert_eq!(oid_largearc1.arc(4).unwrap(), 100);
86-
assert_eq!(oid_largearc1.arc(5).unwrap(), 1);
87-
assert_eq!(oid_largearc1.arc(6).unwrap(), 1);
88-
assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1);
89+
/// 1.1.1.60817410.1
90+
#[test]
91+
fn from_bytes_oid_largearc_1() {
92+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap();
93+
assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_1);
94+
assert_eq!(oid.arc(0).unwrap(), 1);
95+
assert_eq!(oid.arc(1).unwrap(), 1);
96+
assert_eq!(oid.arc(2).unwrap(), 1);
97+
assert_eq!(oid.arc(3).unwrap(), 60817410);
98+
assert_eq!(oid.arc(4).unwrap(), 1);
99+
assert_eq!(oid.arc(5), None);
100+
}
89101

90-
// 1.2.4294967295
91-
let oid_largearc2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap();
92-
assert_eq!(oid_largearc2.arc(0).unwrap(), 1);
93-
assert_eq!(oid_largearc2.arc(1).unwrap(), 2);
94-
assert_eq!(oid_largearc2.arc(2).unwrap(), 4294967295);
95-
assert_eq!(oid_largearc2, EXAMPLE_OID_LARGE_ARC_2);
102+
/// 1.2.4294967295
103+
#[test]
104+
fn from_bytes_oid_largearc_2() {
105+
let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap();
106+
assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_2);
107+
assert_eq!(oid.arc(0).unwrap(), 1);
108+
assert_eq!(oid.arc(1).unwrap(), 2);
109+
assert_eq!(oid.arc(2).unwrap(), 4294967295);
110+
assert_eq!(oid.arc(3), None);
96111

97112
// Empty
98113
assert_eq!(ObjectIdentifier::from_bytes(&[]), Err(Error::Empty));
@@ -126,13 +141,11 @@ fn from_str() {
126141
let oid_largearc1 = EXAMPLE_OID_LARGE_ARC_1_STR
127142
.parse::<ObjectIdentifier>()
128143
.unwrap();
129-
assert_eq!(oid_largearc1.arc(0).unwrap(), 0);
130-
assert_eq!(oid_largearc1.arc(1).unwrap(), 9);
131-
assert_eq!(oid_largearc1.arc(2).unwrap(), 2342);
132-
assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300);
133-
assert_eq!(oid_largearc1.arc(4).unwrap(), 100);
134-
assert_eq!(oid_largearc1.arc(5).unwrap(), 1);
135-
assert_eq!(oid_largearc1.arc(6).unwrap(), 1);
144+
assert_eq!(oid_largearc1.arc(0).unwrap(), 1);
145+
assert_eq!(oid_largearc1.arc(1).unwrap(), 1);
146+
assert_eq!(oid_largearc1.arc(2).unwrap(), 1);
147+
assert_eq!(oid_largearc1.arc(3).unwrap(), 60817410);
148+
assert_eq!(oid_largearc1.arc(4).unwrap(), 1);
136149
assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1);
137150

138151
let oid_largearc2 = EXAMPLE_OID_LARGE_ARC_2_STR

0 commit comments

Comments
 (0)