Skip to content

Commit b1a2fd3

Browse files
authored
feat(pq-key-encoder): phase 2 - ASN.1 primitives (ENG-1310, ENG-1311, ENG-1312, ENG-1313, ENG-1314, ENG-1315, ENG-1316) (#8)
1 parent a37eb78 commit b1a2fd3

File tree

7 files changed

+615
-0
lines changed

7 files changed

+615
-0
lines changed
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
use alloc::vec::Vec;
2+
use pq_oid::Algorithm;
3+
4+
use crate::error::{Error, Result};
5+
6+
use super::decode::{decode_oid, read_tlv};
7+
use super::length::encode_length;
8+
use super::tags;
9+
10+
/// Encode an AlgorithmIdentifier SEQUENCE for the given algorithm.
11+
/// Writes: SEQUENCE { OID } (no NULL parameter for PQ algorithms).
12+
pub(crate) fn encode_algorithm_identifier(algorithm: Algorithm, out: &mut Vec<u8>) {
13+
// Build the OID TLV
14+
let mut oid_bytes = Vec::new();
15+
pq_oid::encode_oid_to(algorithm.oid(), &mut oid_bytes)
16+
.expect("known algorithm OID should always encode");
17+
18+
let mut oid_tlv = Vec::new();
19+
oid_tlv.push(tags::TAG_OBJECT_IDENTIFIER);
20+
encode_length(oid_bytes.len(), &mut oid_tlv);
21+
oid_tlv.extend_from_slice(&oid_bytes);
22+
23+
// Wrap in SEQUENCE
24+
out.push(tags::TAG_SEQUENCE);
25+
encode_length(oid_tlv.len(), out);
26+
out.extend_from_slice(&oid_tlv);
27+
}
28+
29+
/// Decode an AlgorithmIdentifier SEQUENCE.
30+
/// Returns `(Algorithm, bytes_read)`.
31+
/// Accepts both absent and NULL parameters for interoperability.
32+
pub(crate) fn decode_algorithm_identifier(
33+
input: &[u8],
34+
offset: usize,
35+
) -> Result<(Algorithm, usize)> {
36+
// Read outer SEQUENCE
37+
let outer = read_tlv(input, offset)?;
38+
if outer.tag != tags::TAG_SEQUENCE {
39+
return Err(Error::InvalidDer(
40+
"expected SEQUENCE for AlgorithmIdentifier",
41+
));
42+
}
43+
44+
// Read OID inside the sequence
45+
let oid_tlv = read_tlv(outer.value, 0)?;
46+
if oid_tlv.tag != tags::TAG_OBJECT_IDENTIFIER {
47+
return Err(Error::InvalidDer(
48+
"expected OBJECT IDENTIFIER in AlgorithmIdentifier",
49+
));
50+
}
51+
52+
let oid_string = decode_oid(oid_tlv.value)?;
53+
54+
// Check for optional parameters after the OID
55+
let consumed = oid_tlv.bytes_read;
56+
if consumed < outer.value.len() {
57+
let remaining = &outer.value[consumed..];
58+
// Accept NULL (0x05 0x00), but reject any trailing data after it
59+
if remaining.len() >= 2 && remaining[0] == tags::TAG_NULL && remaining[1] == 0x00 {
60+
if remaining.len() > 2 {
61+
return Err(Error::InvalidDer(
62+
"trailing data after AlgorithmIdentifier parameters",
63+
));
64+
}
65+
} else {
66+
return Err(Error::InvalidDer(
67+
"unsupported AlgorithmIdentifier parameters",
68+
));
69+
}
70+
}
71+
72+
let algorithm = Algorithm::from_oid(&oid_string).map_err(|_| Error::UnsupportedAlgorithm)?;
73+
74+
Ok((algorithm, outer.bytes_read))
75+
}
76+
77+
#[cfg(test)]
78+
mod tests {
79+
use super::*;
80+
use pq_oid::{MlDsa, MlKem, SlhDsa};
81+
82+
#[test]
83+
fn test_roundtrip_ml_kem_512() {
84+
let alg = Algorithm::MlKem(MlKem::Kem512);
85+
let mut buf = Vec::new();
86+
encode_algorithm_identifier(alg, &mut buf);
87+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
88+
assert_eq!(decoded, alg);
89+
assert_eq!(bytes_read, buf.len());
90+
}
91+
92+
#[test]
93+
fn test_roundtrip_ml_dsa_44() {
94+
let alg = Algorithm::MlDsa(MlDsa::Dsa44);
95+
let mut buf = Vec::new();
96+
encode_algorithm_identifier(alg, &mut buf);
97+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
98+
assert_eq!(decoded, alg);
99+
assert_eq!(bytes_read, buf.len());
100+
}
101+
102+
#[test]
103+
fn test_roundtrip_slh_dsa_sha2_128s() {
104+
let alg = Algorithm::SlhDsa(SlhDsa::Sha2_128s);
105+
let mut buf = Vec::new();
106+
encode_algorithm_identifier(alg, &mut buf);
107+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
108+
assert_eq!(decoded, alg);
109+
assert_eq!(bytes_read, buf.len());
110+
}
111+
112+
#[test]
113+
fn test_roundtrip_all_algorithms() {
114+
for alg in Algorithm::all() {
115+
let mut buf = Vec::new();
116+
encode_algorithm_identifier(alg, &mut buf);
117+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
118+
assert_eq!(decoded, alg, "failed for {}", alg);
119+
assert_eq!(bytes_read, buf.len());
120+
}
121+
}
122+
123+
#[test]
124+
fn test_decode_with_null_parameter() {
125+
let alg = Algorithm::MlKem(MlKem::Kem512);
126+
let mut buf = Vec::new();
127+
encode_algorithm_identifier(alg, &mut buf);
128+
129+
// Manually add NULL parameter (0x05 0x00) inside the SEQUENCE
130+
// We need to rebuild: SEQUENCE { OID, NULL }
131+
let mut oid_bytes = Vec::new();
132+
pq_oid::encode_oid_to(alg.oid(), &mut oid_bytes).unwrap();
133+
134+
let mut inner = Vec::new();
135+
inner.push(0x06); // OID tag
136+
super::super::length::encode_length(oid_bytes.len(), &mut inner);
137+
inner.extend_from_slice(&oid_bytes);
138+
inner.extend_from_slice(&[0x05, 0x00]); // NULL
139+
140+
let mut with_null = Vec::new();
141+
with_null.push(0x30); // SEQUENCE
142+
super::super::length::encode_length(inner.len(), &mut with_null);
143+
with_null.extend_from_slice(&inner);
144+
145+
let (decoded, _) = decode_algorithm_identifier(&with_null, 0).unwrap();
146+
assert_eq!(decoded, alg);
147+
}
148+
149+
#[test]
150+
fn test_decode_unsupported_parameters() {
151+
let alg = Algorithm::MlKem(MlKem::Kem512);
152+
153+
let mut oid_bytes = Vec::new();
154+
pq_oid::encode_oid_to(alg.oid(), &mut oid_bytes).unwrap();
155+
156+
let mut inner = Vec::new();
157+
inner.push(0x06);
158+
super::super::length::encode_length(oid_bytes.len(), &mut inner);
159+
inner.extend_from_slice(&oid_bytes);
160+
inner.extend_from_slice(&[0x04, 0x01, 0x00]); // OCTET STRING (unsupported)
161+
162+
let mut bad = Vec::new();
163+
bad.push(0x30);
164+
super::super::length::encode_length(inner.len(), &mut bad);
165+
bad.extend_from_slice(&inner);
166+
167+
assert!(decode_algorithm_identifier(&bad, 0).is_err());
168+
}
169+
170+
#[test]
171+
fn test_decode_trailing_data_after_null() {
172+
let alg = Algorithm::MlKem(MlKem::Kem512);
173+
174+
let mut oid_bytes = Vec::new();
175+
pq_oid::encode_oid_to(alg.oid(), &mut oid_bytes).unwrap();
176+
177+
let mut inner = Vec::new();
178+
inner.push(0x06);
179+
super::super::length::encode_length(oid_bytes.len(), &mut inner);
180+
inner.extend_from_slice(&oid_bytes);
181+
inner.extend_from_slice(&[0x05, 0x00]); // NULL
182+
inner.extend_from_slice(&[0x04, 0x01, 0x00]); // trailing OCTET STRING
183+
184+
let mut bad = Vec::new();
185+
bad.push(0x30);
186+
super::super::length::encode_length(inner.len(), &mut bad);
187+
bad.extend_from_slice(&inner);
188+
189+
assert!(decode_algorithm_identifier(&bad, 0).is_err());
190+
}
191+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use alloc::string::String;
2+
3+
use crate::error::{Error, Result};
4+
5+
use super::length::decode_length;
6+
7+
/// A parsed TLV element borrowing from the input buffer.
8+
#[derive(Debug)]
9+
pub(crate) struct Tlv<'a> {
10+
pub tag: u8,
11+
pub value: &'a [u8],
12+
pub bytes_read: usize,
13+
}
14+
15+
/// Read one TLV element from input starting at offset.
16+
pub(crate) fn read_tlv(input: &[u8], offset: usize) -> Result<Tlv<'_>> {
17+
if offset >= input.len() {
18+
return Err(Error::InvalidDer("unexpected end of input"));
19+
}
20+
21+
let tag = input[offset];
22+
let (length, len_bytes) = decode_length(input, offset + 1)?;
23+
let value_start = offset
24+
.checked_add(1)
25+
.and_then(|v| v.checked_add(len_bytes))
26+
.ok_or(Error::InvalidDer("length overflow"))?;
27+
let value_end = value_start
28+
.checked_add(length)
29+
.ok_or(Error::InvalidDer("length overflow"))?;
30+
31+
if value_end > input.len() {
32+
return Err(Error::InvalidDer("truncated value"));
33+
}
34+
35+
Ok(Tlv {
36+
tag,
37+
value: &input[value_start..value_end],
38+
bytes_read: 1 + len_bytes + length,
39+
})
40+
}
41+
42+
/// Decode an OID from raw DER value bytes (without tag/length).
43+
/// Returns dotted notation string like "2.16.840.1.101.3.4.4.1".
44+
pub(crate) fn decode_oid(bytes: &[u8]) -> Result<String> {
45+
Ok(pq_oid::decode_oid(bytes)?)
46+
}
47+
48+
#[cfg(test)]
49+
mod tests {
50+
use super::*;
51+
52+
#[test]
53+
fn test_read_tlv_simple() {
54+
// OCTET STRING containing [0xAA, 0xBB]
55+
let data = [0x04, 0x02, 0xAA, 0xBB];
56+
let tlv = read_tlv(&data, 0).unwrap();
57+
assert_eq!(tlv.tag, 0x04);
58+
assert_eq!(tlv.value, &[0xAA, 0xBB]);
59+
assert_eq!(tlv.bytes_read, 4);
60+
}
61+
62+
#[test]
63+
fn test_read_tlv_empty_value() {
64+
let data = [0x05, 0x00]; // NULL
65+
let tlv = read_tlv(&data, 0).unwrap();
66+
assert_eq!(tlv.tag, 0x05);
67+
assert_eq!(tlv.value, &[]);
68+
assert_eq!(tlv.bytes_read, 2);
69+
}
70+
71+
#[test]
72+
fn test_read_tlv_with_offset() {
73+
let data = [0xFF, 0xFF, 0x02, 0x01, 0x00];
74+
let tlv = read_tlv(&data, 2).unwrap();
75+
assert_eq!(tlv.tag, 0x02);
76+
assert_eq!(tlv.value, &[0x00]);
77+
assert_eq!(tlv.bytes_read, 3);
78+
}
79+
80+
#[test]
81+
fn test_read_tlv_sequence() {
82+
// SEQUENCE { INTEGER 0 }
83+
let data = [0x30, 0x03, 0x02, 0x01, 0x00];
84+
let tlv = read_tlv(&data, 0).unwrap();
85+
assert_eq!(tlv.tag, 0x30);
86+
assert_eq!(tlv.value, &[0x02, 0x01, 0x00]);
87+
assert_eq!(tlv.bytes_read, 5);
88+
}
89+
90+
#[test]
91+
fn test_read_tlv_error_empty() {
92+
assert!(read_tlv(&[], 0).is_err());
93+
}
94+
95+
#[test]
96+
fn test_read_tlv_error_truncated() {
97+
// Claims 3 bytes but only 2 available
98+
let data = [0x04, 0x03, 0xAA, 0xBB];
99+
assert!(read_tlv(&data, 0).is_err());
100+
}
101+
102+
#[test]
103+
fn test_read_tlv_error_offset_past_end() {
104+
let data = [0x04, 0x01, 0xAA];
105+
assert!(read_tlv(&data, 10).is_err());
106+
}
107+
108+
#[test]
109+
fn test_decode_oid_ml_kem_512() {
110+
let bytes = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x01];
111+
let oid = decode_oid(&bytes).unwrap();
112+
assert_eq!(oid, "2.16.840.1.101.3.4.4.1");
113+
}
114+
115+
#[test]
116+
fn test_decode_oid_invalid() {
117+
assert!(decode_oid(&[]).is_err());
118+
}
119+
}

0 commit comments

Comments
 (0)