Skip to content

Commit 96aeea7

Browse files
authored
ml-kem: factor DecapsulationKey/EncapsulationKey into modules (#225)
Factors them into `decapsulation_key` and `encapsulation_key` modules, which makes it easier to find impls related to either, and also frees up the `kem` name so we can re-export the `kem` crate.
1 parent 69a5030 commit 96aeea7

File tree

5 files changed

+467
-466
lines changed

5 files changed

+467
-466
lines changed

ml-kem/src/decapsulation_key.rs

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
use crate::{
2+
B32, EncapsulationKey, Encoded, EncodedSizeUser, ExpandedDecapsulationKey, Seed, SharedKey,
3+
crypto::{G, J},
4+
kem::{Generate, InvalidKey, Kem, KeyInit, KeySizeUser},
5+
param::{DecapsulationKeySize, KemParams},
6+
pke::{DecryptionKey, EncryptionKey},
7+
};
8+
use array::sizes::{U32, U64};
9+
use kem::{Ciphertext, Decapsulate};
10+
use rand_core::{TryCryptoRng, TryRng};
11+
use subtle::{ConditionallySelectable, ConstantTimeEq};
12+
13+
#[cfg(feature = "zeroize")]
14+
use zeroize::{Zeroize, ZeroizeOnDrop};
15+
16+
/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
17+
/// encapsulated shared key.
18+
#[derive(Clone, Debug)]
19+
pub struct DecapsulationKey<P>
20+
where
21+
P: KemParams,
22+
{
23+
dk_pke: DecryptionKey<P>,
24+
ek: EncapsulationKey<P>,
25+
d: Option<B32>,
26+
z: B32,
27+
}
28+
29+
impl<P> DecapsulationKey<P>
30+
where
31+
P: KemParams,
32+
{
33+
/// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
34+
#[inline]
35+
#[must_use]
36+
pub fn from_seed(seed: Seed) -> Self {
37+
let (d, z) = seed.split();
38+
Self::generate_deterministic(d, z)
39+
}
40+
41+
/// Initialize a [`DecapsulationKey`] from the serialized expanded key form.
42+
///
43+
/// Note that this form is deprecated in practice; prefer to use
44+
/// [`DecapsulationKey::from_seed`].
45+
///
46+
/// # Errors
47+
/// - Returns [`InvalidKey`] in the event the expanded key failed validation
48+
#[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
49+
pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
50+
let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
51+
let dk_pke = DecryptionKey::from_bytes(dk_pke);
52+
let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
53+
54+
let ek = EncapsulationKey::from_encryption_key(ek_pke);
55+
if ek.h() != *h {
56+
return Err(InvalidKey);
57+
}
58+
59+
Ok(Self {
60+
dk_pke,
61+
ek,
62+
d: None,
63+
z: z.clone(),
64+
})
65+
}
66+
67+
/// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
68+
/// [`DecapsulationKey`].
69+
///
70+
/// <div class="warning">
71+
/// <b>Warning!</B>
72+
///
73+
/// This value is key material. Please treat it with care.
74+
/// </div>
75+
///
76+
/// # Returns
77+
/// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
78+
/// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
79+
#[inline]
80+
pub fn to_seed(&self) -> Option<Seed> {
81+
self.d.map(|d| d.concat(self.z))
82+
}
83+
84+
/// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
85+
pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
86+
&self.ek
87+
}
88+
89+
#[inline]
90+
pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
91+
where
92+
R: TryCryptoRng + ?Sized,
93+
{
94+
let d = B32::try_generate_from_rng(rng)?;
95+
let z = B32::try_generate_from_rng(rng)?;
96+
Ok(Self::generate_deterministic(d, z))
97+
}
98+
99+
#[inline]
100+
#[must_use]
101+
#[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
102+
pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
103+
let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
104+
let ek = EncapsulationKey::from_encryption_key(ek_pke);
105+
let d = Some(d);
106+
Self { dk_pke, ek, d, z }
107+
}
108+
}
109+
110+
// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
111+
// keys initialized from the expanded form
112+
impl<P> PartialEq for DecapsulationKey<P>
113+
where
114+
P: KemParams,
115+
{
116+
fn eq(&self, other: &Self) -> bool {
117+
self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
118+
}
119+
}
120+
121+
#[cfg(feature = "zeroize")]
122+
impl<P> Drop for DecapsulationKey<P>
123+
where
124+
P: KemParams,
125+
{
126+
fn drop(&mut self) {
127+
self.dk_pke.zeroize();
128+
self.d.zeroize();
129+
self.z.zeroize();
130+
}
131+
}
132+
133+
#[cfg(feature = "zeroize")]
134+
impl<P> ZeroizeOnDrop for DecapsulationKey<P> where P: KemParams {}
135+
136+
impl<P> From<Seed> for DecapsulationKey<P>
137+
where
138+
P: KemParams,
139+
{
140+
fn from(seed: Seed) -> Self {
141+
Self::from_seed(seed)
142+
}
143+
}
144+
145+
impl<P> Decapsulate<P> for DecapsulationKey<P>
146+
where
147+
P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
148+
{
149+
fn decapsulate(&self, encapsulated_key: &Ciphertext<P>) -> SharedKey {
150+
let mp = self.dk_pke.decrypt(encapsulated_key);
151+
let (Kp, rp) = G(&[&mp, &self.ek.h()]);
152+
let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
153+
let cp = self.ek.ek_pke().encrypt(&mp, &rp);
154+
B32::conditional_select(&Kbar, &Kp, cp.ct_eq(encapsulated_key))
155+
}
156+
}
157+
158+
impl<P> AsRef<EncapsulationKey<P>> for DecapsulationKey<P>
159+
where
160+
P: KemParams,
161+
{
162+
fn as_ref(&self) -> &EncapsulationKey<P> {
163+
&self.ek
164+
}
165+
}
166+
167+
impl<P> EncodedSizeUser for DecapsulationKey<P>
168+
where
169+
P: KemParams,
170+
{
171+
type EncodedSize = DecapsulationKeySize<P>;
172+
173+
fn from_encoded_bytes(expanded: &Encoded<Self>) -> Result<Self, InvalidKey> {
174+
#[allow(deprecated)]
175+
Self::from_expanded(expanded)
176+
}
177+
178+
fn to_encoded_bytes(&self) -> Encoded<Self> {
179+
let dk_pke = self.dk_pke.to_bytes();
180+
let ek = self.ek.to_encoded_bytes();
181+
P::concat_dk(dk_pke, ek, self.ek.h(), self.z.clone())
182+
}
183+
}
184+
185+
impl<P> Generate for DecapsulationKey<P>
186+
where
187+
P: KemParams,
188+
{
189+
fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
190+
where
191+
R: TryCryptoRng + ?Sized,
192+
{
193+
Self::try_generate_from_rng(rng)
194+
}
195+
}
196+
197+
impl<P> KeySizeUser for DecapsulationKey<P>
198+
where
199+
P: KemParams,
200+
{
201+
type KeySize = U64;
202+
}
203+
204+
impl<P> KeyInit for DecapsulationKey<P>
205+
where
206+
P: KemParams,
207+
{
208+
#[inline]
209+
fn new(seed: &Seed) -> Self {
210+
Self::from_seed(*seed)
211+
}
212+
}

ml-kem/src/encapsulation_key.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
use crate::{
2+
B32, Encoded, EncodedSizeUser, SharedKey,
3+
crypto::{G, H},
4+
kem::{InvalidKey, Kem, Key, KeyExport, KeySizeUser, TryKeyInit},
5+
param::{EncapsulationKeySize, KemParams},
6+
pke::EncryptionKey,
7+
};
8+
use array::sizes::U32;
9+
use kem::{Ciphertext, Encapsulate, Generate};
10+
use rand_core::CryptoRng;
11+
12+
/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
13+
/// decapsulated by the holder of the corresponding decapsulation key.
14+
#[derive(Clone, Debug)]
15+
pub struct EncapsulationKey<P>
16+
where
17+
P: KemParams,
18+
{
19+
ek_pke: EncryptionKey<P>,
20+
h: B32,
21+
}
22+
23+
impl<P> EncapsulationKey<P>
24+
where
25+
P: Kem<SharedKeySize = U32> + KemParams,
26+
{
27+
/// Encapsulates with the given randomness. This is useful for testing against known vectors.
28+
///
29+
/// # Warning
30+
/// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
31+
/// random bytes even once, you can have catastrophic security failure.
32+
#[cfg_attr(not(feature = "hazmat"), doc(hidden))]
33+
pub fn encapsulate_deterministic(&self, m: &B32) -> (Ciphertext<P>, SharedKey) {
34+
let (K, r) = G(&[m, &self.h]);
35+
let c = self.ek_pke.encrypt(m, &r);
36+
(c, K)
37+
}
38+
39+
/// Convert from an `EncryptionKey`.
40+
pub(crate) fn from_encryption_key(ek_pke: EncryptionKey<P>) -> Self {
41+
let h = H(ek_pke.to_bytes());
42+
Self { ek_pke, h }
43+
}
44+
45+
/// Borrow the encryption key.
46+
pub(crate) fn ek_pke(&self) -> &EncryptionKey<P> {
47+
&self.ek_pke
48+
}
49+
50+
/// Retrieve the hash of the encryption key.
51+
pub(crate) fn h(&self) -> B32 {
52+
self.h
53+
}
54+
}
55+
56+
impl<P> Encapsulate<P> for EncapsulationKey<P>
57+
where
58+
P: Kem + KemParams,
59+
{
60+
fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext<P>, SharedKey)
61+
where
62+
R: CryptoRng + ?Sized,
63+
{
64+
let m = B32::generate_from_rng(rng);
65+
self.encapsulate_deterministic(&m)
66+
}
67+
}
68+
69+
impl<P> EncodedSizeUser for EncapsulationKey<P>
70+
where
71+
P: KemParams,
72+
{
73+
type EncodedSize = EncapsulationKeySize<P>;
74+
75+
fn from_encoded_bytes(enc: &Encoded<Self>) -> Result<Self, InvalidKey> {
76+
Ok(Self::from_encryption_key(EncryptionKey::from_bytes(enc)?))
77+
}
78+
79+
fn to_encoded_bytes(&self) -> Encoded<Self> {
80+
self.ek_pke.to_bytes()
81+
}
82+
}
83+
84+
impl<P> KeyExport for EncapsulationKey<P>
85+
where
86+
P: KemParams,
87+
{
88+
fn to_bytes(&self) -> Key<Self> {
89+
self.ek_pke.to_bytes()
90+
}
91+
}
92+
93+
impl<P> KeySizeUser for EncapsulationKey<P>
94+
where
95+
P: KemParams,
96+
{
97+
type KeySize = EncapsulationKeySize<P>;
98+
}
99+
100+
impl<P> TryKeyInit for EncapsulationKey<P>
101+
where
102+
P: KemParams,
103+
{
104+
fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
105+
EncryptionKey::from_bytes(encapsulation_key)
106+
.map(Self::from_encryption_key)
107+
.map_err(|_| InvalidKey)
108+
}
109+
}
110+
111+
impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
112+
impl<P> PartialEq for EncapsulationKey<P>
113+
where
114+
P: KemParams,
115+
{
116+
fn eq(&self, other: &Self) -> bool {
117+
// Handwritten to avoid derive putting `Eq` bounds on `KemParams`
118+
self.ek_pke == other.ek_pke && self.h == other.h
119+
}
120+
}

0 commit comments

Comments
 (0)