Skip to content

Commit 3eb8f16

Browse files
authored
der: extract reader::position::Position (#1880)
Extracts a state management struct for nested position tracking. The goal is to support nested message parsing keeping track of where we are in the message using only the call stack, i.e. without the need to keep such information on the heap, as this crate is intended to support no_std/no-alloc "heapless" targets. Extracting this logic into its struct not only makes it potentially reusable, but makes it significantly easier to test, and this commit also adds a number of initial tests.
1 parent db6172a commit 3eb8f16

File tree

3 files changed

+169
-35
lines changed

3 files changed

+169
-35
lines changed

der/src/reader.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
pub(crate) mod pem;
55
pub(crate) mod slice;
66

7+
#[cfg(feature = "pem")]
8+
mod position;
9+
710
use crate::{
811
Decode, DecodeValue, Encode, EncodingRules, Error, ErrorKind, FixedTag, Header, Length, Tag,
912
TagMode, TagNumber, asn1::ContextSpecific,

der/src/reader/pem.rs

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Streaming PEM reader.
22
3-
use super::Reader;
4-
use crate::{EncodingRules, Error, ErrorKind, Length};
3+
use super::{Reader, position::Position};
4+
use crate::{EncodingRules, Error, ErrorKind, Length, Result};
55
use pem_rfc7468::Decoder;
66

77
/// `Reader` type which decodes PEM on-the-fly.
@@ -14,27 +14,23 @@ pub struct PemReader<'i> {
1414
/// Encoding rules to apply when decoding the input.
1515
encoding_rules: EncodingRules,
1616

17-
/// Input length (in bytes after Base64 decoding).
18-
input_len: Length,
19-
20-
/// Position in the input buffer (in bytes after Base64 decoding).
21-
position: Length,
17+
/// Position tracker.
18+
position: Position,
2219
}
2320

2421
#[cfg(feature = "pem")]
2522
impl<'i> PemReader<'i> {
2623
/// Create a new PEM reader which decodes data on-the-fly.
2724
///
2825
/// Uses the default 64-character line wrapping.
29-
pub fn new(pem: &'i [u8]) -> crate::Result<Self> {
26+
pub fn new(pem: &'i [u8]) -> Result<Self> {
3027
let decoder = Decoder::new(pem)?;
3128
let input_len = Length::try_from(decoder.remaining_len())?;
3229

3330
Ok(Self {
3431
decoder,
3532
encoding_rules: EncodingRules::default(),
36-
input_len,
37-
position: Length::ZERO,
33+
position: Position::new(input_len),
3834
})
3935
}
4036

@@ -52,52 +48,37 @@ impl<'i> Reader<'i> for PemReader<'i> {
5248
}
5349

5450
fn input_len(&self) -> Length {
55-
self.input_len
51+
self.position.input_len()
5652
}
5753

58-
fn peek_into(&self, buf: &mut [u8]) -> crate::Result<()> {
54+
fn peek_into(&self, buf: &mut [u8]) -> Result<()> {
5955
self.clone().read_into(buf)?;
6056
Ok(())
6157
}
6258

6359
fn position(&self) -> Length {
64-
self.position
60+
self.position.current()
6561
}
6662

67-
fn read_nested<T, F, E>(&mut self, len: Length, f: F) -> Result<T, E>
63+
fn read_nested<T, F, E>(&mut self, len: Length, f: F) -> core::result::Result<T, E>
6864
where
69-
F: FnOnce(&mut Self) -> Result<T, E>,
65+
F: FnOnce(&mut Self) -> core::result::Result<T, E>,
7066
E: From<Error>,
7167
{
72-
let nested_input_len = (self.position + len)?;
73-
if nested_input_len > self.input_len {
74-
return Err(Error::incomplete(self.input_len).into());
75-
}
76-
77-
let orig_input_len = self.input_len;
78-
self.input_len = nested_input_len;
68+
let resumption = self.position.split_nested(len)?;
7969
let ret = f(self);
80-
self.input_len = orig_input_len;
70+
self.position.resume_nested(resumption);
8171
ret
8272
}
8373

84-
fn read_slice(&mut self, _len: Length) -> crate::Result<&'i [u8]> {
74+
fn read_slice(&mut self, _len: Length) -> Result<&'i [u8]> {
8575
// Can't borrow from PEM because it requires decoding
8676
Err(ErrorKind::Reader.into())
8777
}
8878

89-
fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> crate::Result<&'o [u8]> {
90-
let new_position = (self.position + buf.len())?;
91-
if new_position > self.input_len {
92-
return Err(ErrorKind::Incomplete {
93-
expected_len: new_position,
94-
actual_len: self.input_len,
95-
}
96-
.at(self.position));
97-
}
98-
79+
fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> Result<&'o [u8]> {
80+
self.position.advance(Length::try_from(buf.len())?)?;
9981
self.decoder.decode(buf)?;
100-
self.position = new_position;
10182
Ok(buf)
10283
}
10384
}

der/src/reader/position.rs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//! Position tracking for processing nested input messages using only the stack.
2+
3+
use crate::{Error, ErrorKind, Length, Result};
4+
5+
/// State tracker for the current position in the input.
6+
#[derive(Clone, Debug)]
7+
pub(super) struct Position {
8+
/// Input length (in bytes after Base64 decoding).
9+
input_len: Length,
10+
11+
/// Position in the input buffer (in bytes after Base64 decoding).
12+
position: Length,
13+
}
14+
15+
impl Position {
16+
/// Create a new position tracker with the given overall length.
17+
pub(super) fn new(input_len: Length) -> Self {
18+
Self {
19+
input_len,
20+
position: Length::ZERO,
21+
}
22+
}
23+
24+
/// Get the input length.
25+
pub(super) fn input_len(&self) -> Length {
26+
self.input_len
27+
}
28+
29+
/// Get the current position.
30+
pub(super) fn current(&self) -> Length {
31+
self.position
32+
}
33+
34+
/// Advance the current position by the given amount.
35+
///
36+
/// # Returns
37+
///
38+
/// The new current position.
39+
pub(super) fn advance(&mut self, amount: Length) -> Result<Length> {
40+
let new_position = (self.position + amount)?;
41+
42+
if new_position > self.input_len {
43+
return Err(ErrorKind::Incomplete {
44+
expected_len: new_position,
45+
actual_len: self.input_len,
46+
}
47+
.at(self.position));
48+
}
49+
50+
self.position = new_position;
51+
Ok(new_position)
52+
}
53+
54+
/// Split a nested position tracker of the given size.
55+
///
56+
/// # Returns
57+
///
58+
/// A [`Resumption`] value which can be used to continue parsing the outer message.
59+
pub(super) fn split_nested(&mut self, len: Length) -> Result<Resumption> {
60+
let nested_input_len = (self.position + len)?;
61+
62+
if nested_input_len > self.input_len {
63+
return Err(Error::incomplete(self.input_len));
64+
}
65+
66+
let resumption = Resumption {
67+
input_len: self.input_len,
68+
};
69+
self.input_len = nested_input_len;
70+
Ok(resumption)
71+
}
72+
73+
/// Resume processing the rest of a message after processing a nested inner portion.
74+
pub(super) fn resume_nested(&mut self, resumption: Resumption) {
75+
self.input_len = resumption.input_len;
76+
}
77+
}
78+
79+
/// Resumption state needed to continue processing a message after handling a nested inner portion.
80+
#[derive(Debug)]
81+
pub(super) struct Resumption {
82+
/// Outer input length.
83+
input_len: Length,
84+
}
85+
86+
#[cfg(test)]
87+
#[allow(clippy::unwrap_used)]
88+
mod tests {
89+
use super::Position;
90+
use crate::{ErrorKind, Length};
91+
92+
const EXAMPLE_LEN: Length = match Length::new_usize(42) {
93+
Ok(len) => len,
94+
Err(_) => panic!("invalid example len"),
95+
};
96+
97+
#[test]
98+
fn initial_state() {
99+
let pos = Position::new(EXAMPLE_LEN);
100+
assert_eq!(pos.input_len(), EXAMPLE_LEN);
101+
assert_eq!(pos.current(), Length::ZERO);
102+
}
103+
104+
#[test]
105+
fn advance() {
106+
let mut pos = Position::new(EXAMPLE_LEN);
107+
108+
// advance 1 byte: success
109+
let new_pos = pos.advance(Length::ONE).unwrap();
110+
assert_eq!(new_pos, Length::ONE);
111+
assert_eq!(pos.current(), Length::ONE);
112+
113+
// advance to end: success
114+
let end_pos = pos.advance((EXAMPLE_LEN - Length::ONE).unwrap()).unwrap();
115+
assert_eq!(end_pos, EXAMPLE_LEN);
116+
assert_eq!(pos.current(), EXAMPLE_LEN);
117+
118+
// advance one byte past end: error
119+
let err = pos.advance(Length::ONE).unwrap_err();
120+
assert!(matches!(err.kind(), ErrorKind::Incomplete { .. }));
121+
}
122+
123+
#[test]
124+
fn nested() {
125+
let mut pos = Position::new(EXAMPLE_LEN);
126+
127+
// split first byte
128+
let resumption = pos.split_nested(Length::ONE).unwrap();
129+
assert_eq!(pos.current(), Length::ZERO);
130+
assert_eq!(pos.input_len(), Length::ONE);
131+
132+
// advance one byte
133+
assert_eq!(pos.advance(Length::ONE).unwrap(), Length::ONE);
134+
135+
// can't advance two bytes
136+
let err = pos.advance(Length::ONE).unwrap_err();
137+
assert!(matches!(err.kind(), ErrorKind::Incomplete { .. }));
138+
139+
// resume processing the rest of the message
140+
// TODO(tarcieri): should we fail here if we previously failed reading a nested message?
141+
pos.resume_nested(resumption);
142+
143+
assert_eq!(pos.current(), Length::ONE);
144+
assert_eq!(pos.input_len(), EXAMPLE_LEN);
145+
146+
// try to split one byte past end: error
147+
let err = pos.split_nested(EXAMPLE_LEN).unwrap_err();
148+
assert!(matches!(err.kind(), ErrorKind::Incomplete { .. }));
149+
}
150+
}

0 commit comments

Comments
 (0)