Skip to content

Commit 99488f6

Browse files
committed
feat(pq-key-encoder): phase 2 - ASN.1 primitives (ENG-1310, ENG-1311, ENG-1312, ENG-1313, ENG-1314, ENG-1315, ENG-1316)
1 parent a37eb78 commit 99488f6

File tree

7 files changed

+576
-0
lines changed

7 files changed

+576
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use alloc::vec::Vec;
2+
use pq_oid::Algorithm;
3+
4+
use crate::error::{Error, Result};
5+
6+
use super::decode::{decode_oid, read_tlv};
7+
use super::length::encode_length;
8+
use super::tags;
9+
10+
/// Encode an AlgorithmIdentifier SEQUENCE for the given algorithm.
11+
/// Writes: SEQUENCE { OID } (no NULL parameter for PQ algorithms).
12+
pub(crate) fn encode_algorithm_identifier(algorithm: Algorithm, out: &mut Vec<u8>) {
13+
// Build the OID TLV
14+
let mut oid_bytes = Vec::new();
15+
pq_oid::encode_oid_to(algorithm.oid(), &mut oid_bytes)
16+
.expect("known algorithm OID should always encode");
17+
18+
let mut oid_tlv = Vec::new();
19+
oid_tlv.push(tags::TAG_OBJECT_IDENTIFIER);
20+
encode_length(oid_bytes.len(), &mut oid_tlv);
21+
oid_tlv.extend_from_slice(&oid_bytes);
22+
23+
// Wrap in SEQUENCE
24+
out.push(tags::TAG_SEQUENCE);
25+
encode_length(oid_tlv.len(), out);
26+
out.extend_from_slice(&oid_tlv);
27+
}
28+
29+
/// Decode an AlgorithmIdentifier SEQUENCE.
30+
/// Returns `(Algorithm, bytes_read)`.
31+
/// Accepts both absent and NULL parameters for interoperability.
32+
pub(crate) fn decode_algorithm_identifier(
33+
input: &[u8],
34+
offset: usize,
35+
) -> Result<(Algorithm, usize)> {
36+
// Read outer SEQUENCE
37+
let outer = read_tlv(input, offset)?;
38+
if outer.tag != tags::TAG_SEQUENCE {
39+
return Err(Error::InvalidDer(
40+
"expected SEQUENCE for AlgorithmIdentifier",
41+
));
42+
}
43+
44+
// Read OID inside the sequence
45+
let oid_tlv = read_tlv(outer.value, 0)?;
46+
if oid_tlv.tag != tags::TAG_OBJECT_IDENTIFIER {
47+
return Err(Error::InvalidDer(
48+
"expected OBJECT IDENTIFIER in AlgorithmIdentifier",
49+
));
50+
}
51+
52+
let oid_string = decode_oid(oid_tlv.value)?;
53+
54+
// Check for optional parameters after the OID
55+
let consumed = oid_tlv.bytes_read;
56+
if consumed < outer.value.len() {
57+
let remaining = &outer.value[consumed..];
58+
// Accept NULL (0x05 0x00)
59+
if remaining.len() >= 2 && remaining[0] == tags::TAG_NULL && remaining[1] == 0x00 {
60+
// NULL parameter — accepted and skipped
61+
} else {
62+
return Err(Error::InvalidDer(
63+
"unsupported AlgorithmIdentifier parameters",
64+
));
65+
}
66+
}
67+
68+
let algorithm = Algorithm::from_oid(&oid_string).map_err(|_| Error::UnsupportedAlgorithm)?;
69+
70+
Ok((algorithm, outer.bytes_read))
71+
}
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use super::*;
76+
use pq_oid::{MlDsa, MlKem, SlhDsa};
77+
78+
#[test]
79+
fn test_roundtrip_ml_kem_512() {
80+
let alg = Algorithm::MlKem(MlKem::Kem512);
81+
let mut buf = Vec::new();
82+
encode_algorithm_identifier(alg, &mut buf);
83+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
84+
assert_eq!(decoded, alg);
85+
assert_eq!(bytes_read, buf.len());
86+
}
87+
88+
#[test]
89+
fn test_roundtrip_ml_dsa_44() {
90+
let alg = Algorithm::MlDsa(MlDsa::Dsa44);
91+
let mut buf = Vec::new();
92+
encode_algorithm_identifier(alg, &mut buf);
93+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
94+
assert_eq!(decoded, alg);
95+
assert_eq!(bytes_read, buf.len());
96+
}
97+
98+
#[test]
99+
fn test_roundtrip_slh_dsa_sha2_128s() {
100+
let alg = Algorithm::SlhDsa(SlhDsa::Sha2_128s);
101+
let mut buf = Vec::new();
102+
encode_algorithm_identifier(alg, &mut buf);
103+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
104+
assert_eq!(decoded, alg);
105+
assert_eq!(bytes_read, buf.len());
106+
}
107+
108+
#[test]
109+
fn test_roundtrip_all_algorithms() {
110+
for alg in Algorithm::all() {
111+
let mut buf = Vec::new();
112+
encode_algorithm_identifier(alg, &mut buf);
113+
let (decoded, bytes_read) = decode_algorithm_identifier(&buf, 0).unwrap();
114+
assert_eq!(decoded, alg, "failed for {}", alg);
115+
assert_eq!(bytes_read, buf.len());
116+
}
117+
}
118+
119+
#[test]
120+
fn test_decode_with_null_parameter() {
121+
let alg = Algorithm::MlKem(MlKem::Kem512);
122+
let mut buf = Vec::new();
123+
encode_algorithm_identifier(alg, &mut buf);
124+
125+
// Manually add NULL parameter (0x05 0x00) inside the SEQUENCE
126+
// We need to rebuild: SEQUENCE { OID, NULL }
127+
let mut oid_bytes = Vec::new();
128+
pq_oid::encode_oid_to(alg.oid(), &mut oid_bytes).unwrap();
129+
130+
let mut inner = Vec::new();
131+
inner.push(0x06); // OID tag
132+
super::super::length::encode_length(oid_bytes.len(), &mut inner);
133+
inner.extend_from_slice(&oid_bytes);
134+
inner.extend_from_slice(&[0x05, 0x00]); // NULL
135+
136+
let mut with_null = Vec::new();
137+
with_null.push(0x30); // SEQUENCE
138+
super::super::length::encode_length(inner.len(), &mut with_null);
139+
with_null.extend_from_slice(&inner);
140+
141+
let (decoded, _) = decode_algorithm_identifier(&with_null, 0).unwrap();
142+
assert_eq!(decoded, alg);
143+
}
144+
145+
#[test]
146+
fn test_decode_unsupported_parameters() {
147+
let alg = Algorithm::MlKem(MlKem::Kem512);
148+
149+
let mut oid_bytes = Vec::new();
150+
pq_oid::encode_oid_to(alg.oid(), &mut oid_bytes).unwrap();
151+
152+
let mut inner = Vec::new();
153+
inner.push(0x06);
154+
super::super::length::encode_length(oid_bytes.len(), &mut inner);
155+
inner.extend_from_slice(&oid_bytes);
156+
inner.extend_from_slice(&[0x04, 0x01, 0x00]); // OCTET STRING (unsupported)
157+
158+
let mut bad = Vec::new();
159+
bad.push(0x30);
160+
super::super::length::encode_length(inner.len(), &mut bad);
161+
bad.extend_from_slice(&inner);
162+
163+
assert!(decode_algorithm_identifier(&bad, 0).is_err());
164+
}
165+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use alloc::string::String;
2+
3+
use crate::error::{Error, Result};
4+
5+
use super::length::decode_length;
6+
7+
/// A parsed TLV element borrowing from the input buffer.
8+
#[derive(Debug)]
9+
pub(crate) struct Tlv<'a> {
10+
pub tag: u8,
11+
pub value: &'a [u8],
12+
pub bytes_read: usize,
13+
}
14+
15+
/// Read one TLV element from input starting at offset.
16+
pub(crate) fn read_tlv(input: &[u8], offset: usize) -> Result<Tlv<'_>> {
17+
if offset >= input.len() {
18+
return Err(Error::InvalidDer("unexpected end of input"));
19+
}
20+
21+
let tag = input[offset];
22+
let (length, len_bytes) = decode_length(input, offset + 1)?;
23+
let value_start = offset + 1 + len_bytes;
24+
25+
if value_start + length > input.len() {
26+
return Err(Error::InvalidDer("truncated value"));
27+
}
28+
29+
Ok(Tlv {
30+
tag,
31+
value: &input[value_start..value_start + length],
32+
bytes_read: 1 + len_bytes + length,
33+
})
34+
}
35+
36+
/// Decode an OID from raw DER value bytes (without tag/length).
37+
/// Returns dotted notation string like "2.16.840.1.101.3.4.4.1".
38+
pub(crate) fn decode_oid(bytes: &[u8]) -> Result<String> {
39+
Ok(pq_oid::decode_oid(bytes)?)
40+
}
41+
42+
#[cfg(test)]
43+
mod tests {
44+
use super::*;
45+
46+
#[test]
47+
fn test_read_tlv_simple() {
48+
// OCTET STRING containing [0xAA, 0xBB]
49+
let data = [0x04, 0x02, 0xAA, 0xBB];
50+
let tlv = read_tlv(&data, 0).unwrap();
51+
assert_eq!(tlv.tag, 0x04);
52+
assert_eq!(tlv.value, &[0xAA, 0xBB]);
53+
assert_eq!(tlv.bytes_read, 4);
54+
}
55+
56+
#[test]
57+
fn test_read_tlv_empty_value() {
58+
let data = [0x05, 0x00]; // NULL
59+
let tlv = read_tlv(&data, 0).unwrap();
60+
assert_eq!(tlv.tag, 0x05);
61+
assert_eq!(tlv.value, &[]);
62+
assert_eq!(tlv.bytes_read, 2);
63+
}
64+
65+
#[test]
66+
fn test_read_tlv_with_offset() {
67+
let data = [0xFF, 0xFF, 0x02, 0x01, 0x00];
68+
let tlv = read_tlv(&data, 2).unwrap();
69+
assert_eq!(tlv.tag, 0x02);
70+
assert_eq!(tlv.value, &[0x00]);
71+
assert_eq!(tlv.bytes_read, 3);
72+
}
73+
74+
#[test]
75+
fn test_read_tlv_sequence() {
76+
// SEQUENCE { INTEGER 0 }
77+
let data = [0x30, 0x03, 0x02, 0x01, 0x00];
78+
let tlv = read_tlv(&data, 0).unwrap();
79+
assert_eq!(tlv.tag, 0x30);
80+
assert_eq!(tlv.value, &[0x02, 0x01, 0x00]);
81+
assert_eq!(tlv.bytes_read, 5);
82+
}
83+
84+
#[test]
85+
fn test_read_tlv_error_empty() {
86+
assert!(read_tlv(&[], 0).is_err());
87+
}
88+
89+
#[test]
90+
fn test_read_tlv_error_truncated() {
91+
// Claims 3 bytes but only 2 available
92+
let data = [0x04, 0x03, 0xAA, 0xBB];
93+
assert!(read_tlv(&data, 0).is_err());
94+
}
95+
96+
#[test]
97+
fn test_read_tlv_error_offset_past_end() {
98+
let data = [0x04, 0x01, 0xAA];
99+
assert!(read_tlv(&data, 10).is_err());
100+
}
101+
102+
#[test]
103+
fn test_decode_oid_ml_kem_512() {
104+
let bytes = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x04, 0x01];
105+
let oid = decode_oid(&bytes).unwrap();
106+
assert_eq!(oid, "2.16.840.1.101.3.4.4.1");
107+
}
108+
109+
#[test]
110+
fn test_decode_oid_invalid() {
111+
assert!(decode_oid(&[]).is_err());
112+
}
113+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use alloc::vec::Vec;
2+
3+
use super::length::encode_length;
4+
use super::tags;
5+
6+
/// Write a complete TLV (tag + length + value) to the buffer.
7+
pub(crate) fn encode_tlv(tag: u8, value: &[u8], out: &mut Vec<u8>) {
8+
out.push(tag);
9+
encode_length(value.len(), out);
10+
out.extend_from_slice(value);
11+
}
12+
13+
/// Write a SEQUENCE containing the given pre-encoded elements.
14+
pub(crate) fn encode_sequence(elements: &[&[u8]], out: &mut Vec<u8>) {
15+
let total_len: usize = elements.iter().map(|e| e.len()).sum();
16+
out.push(tags::TAG_SEQUENCE);
17+
encode_length(total_len, out);
18+
for element in elements {
19+
out.extend_from_slice(element);
20+
}
21+
}
22+
23+
/// Write an OCTET STRING wrapping the given data.
24+
pub(crate) fn encode_octet_string(data: &[u8], out: &mut Vec<u8>) {
25+
encode_tlv(tags::TAG_OCTET_STRING, data, out);
26+
}
27+
28+
/// Write a BIT STRING with 0 unused bits wrapping the given data.
29+
pub(crate) fn encode_bit_string(data: &[u8], out: &mut Vec<u8>) {
30+
out.push(tags::TAG_BIT_STRING);
31+
encode_length(data.len() + 1, out);
32+
out.push(0x00); // unused bits
33+
out.extend_from_slice(data);
34+
}
35+
36+
/// Write INTEGER 0 (version field for PKCS8). Fixed bytes: 02 01 00.
37+
pub(crate) fn encode_integer_zero(out: &mut Vec<u8>) {
38+
out.extend_from_slice(&[0x02, 0x01, 0x00]);
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[test]
46+
fn test_encode_tlv() {
47+
let mut out = Vec::new();
48+
encode_tlv(0x04, &[0x01, 0x02, 0x03], &mut out);
49+
assert_eq!(out, [0x04, 0x03, 0x01, 0x02, 0x03]);
50+
}
51+
52+
#[test]
53+
fn test_encode_tlv_empty() {
54+
let mut out = Vec::new();
55+
encode_tlv(0x04, &[], &mut out);
56+
assert_eq!(out, [0x04, 0x00]);
57+
}
58+
59+
#[test]
60+
fn test_encode_sequence() {
61+
let elem1 = [0x02, 0x01, 0x00]; // INTEGER 0
62+
let elem2 = [0x04, 0x02, 0xAA, 0xBB]; // OCTET STRING
63+
let mut out = Vec::new();
64+
encode_sequence(&[&elem1, &elem2], &mut out);
65+
assert_eq!(out, [0x30, 0x07, 0x02, 0x01, 0x00, 0x04, 0x02, 0xAA, 0xBB]);
66+
}
67+
68+
#[test]
69+
fn test_encode_octet_string() {
70+
let mut out = Vec::new();
71+
encode_octet_string(&[0xDE, 0xAD], &mut out);
72+
assert_eq!(out, [0x04, 0x02, 0xDE, 0xAD]);
73+
}
74+
75+
#[test]
76+
fn test_encode_bit_string() {
77+
let mut out = Vec::new();
78+
encode_bit_string(&[0xCA, 0xFE], &mut out);
79+
assert_eq!(out, [0x03, 0x03, 0x00, 0xCA, 0xFE]);
80+
}
81+
82+
#[test]
83+
fn test_encode_integer_zero() {
84+
let mut out = Vec::new();
85+
encode_integer_zero(&mut out);
86+
assert_eq!(out, [0x02, 0x01, 0x00]);
87+
}
88+
}

0 commit comments

Comments
 (0)