Skip to content

Commit 05575b8

Browse files
authored
mlkem: Move caching of encapsulation key (#465)
* mlkem: Move caching of encapsulation key up to individual DecapsulationKey newtypes to avoid storing it in its enitrety twice in memory * 0.17.10
1 parent 7f0d944 commit 05575b8

File tree

7 files changed

+224
-90
lines changed

7 files changed

+224
-90
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
### 0.17.10
22

3-
**Date:** TBD.
3+
**Date:** April 12, 2025.
44

55
**Changelog:**
66
- Add `encap_deterministic()` and `auth_encap_deterministic()` to `DhKem` in `hazardous::kem::x25519_hkdf_sha256::DhKem` [#458](https://github.com/orion-rs/orion/pull/458).

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "orion"
3-
version = "0.17.9"
3+
version = "0.17.10"
44
authors = ["brycx <brycx@protonmail.com>"]
55
description = "Usable, easy and safe pure-Rust crypto"
66
keywords = ["cryptography", "crypto", "aead", "pqc", "kem"]

src/hazardous/ecc/x25519.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ pub fn key_agreement(
630630

631631
#[cfg(test)]
632632
mod public {
633-
use super::FieldElement;
634633
use crate::hazardous::ecc::x25519::{
635634
key_agreement, PrivateKey, PublicKey, Scalar, SharedKey, BASEPOINT, PRIVATE_KEY_SIZE,
636635
PUBLIC_KEY_SIZE,
@@ -696,6 +695,7 @@ mod public {
696695
#[cfg(feature = "safe_api")]
697696
// format! is only available with std
698697
fn test_privatekey_debug_impl() {
698+
use super::FieldElement;
699699
let value = format!("{:?}", [1u64, 0u64, 0u64, 0u64, 0u64,].as_ref());
700700
let test_debug_contents = format!("{:?}", FieldElement::one());
701701
assert!(test_debug_contents.contains(&value));

src/hazardous/kem/ml_kem/internal/mod.rs

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ pub fn g(c: &[&[u8]]) -> ([u8; 32], Zeroizing<[u8; 32]>) {
8686
}
8787

8888
/// Internal PKE-related function, for generalizing over the three different PKE parameter-sets.
89-
pub(crate) trait PkeParameters: Clone {
89+
pub(crate) trait PkeParameters {
9090
const N: usize = 256;
9191
const K: usize;
9292
const ETA_1: usize;
@@ -523,7 +523,7 @@ pub(crate) struct DecapKey<
523523
> {
524524
pub(crate) bytes: [u8; ENCODED_SIZE_DK],
525525
s_hat: [RingElementNTT; K],
526-
ek: EncapKey<K, ENCODED_SIZE_EK, Pke>,
526+
_phantom: PhantomData<Pke>,
527527
}
528528

529529
impl<
@@ -617,15 +617,10 @@ impl<
617617
ByteSerialization::decode_12(dk_part, &mut s_hat_poly.coefficients);
618618
}
619619

620-
// Save the encapsulation key, such that it doesn't need to be re-computed in MLKEM-decap_internal().
621-
let ek = EncapKey::<K, ENCODED_SIZE_EK, Pke>::from_slice(
622-
&slice[ENCODE_SIZE_POLY * K..(768 * K) + 32],
623-
)?;
624-
625620
Ok(Self {
626621
bytes: slice.try_into().unwrap(), // NOTE: Should never panic if decapsulation_key_check() succeeds.
627622
s_hat,
628-
ek,
623+
_phantom: PhantomData,
629624
})
630625
}
631626

@@ -682,11 +677,13 @@ impl<
682677
/// - decapsulation key dk ∈ B^{768k+96}.
683678
/// - ciphertext c ∈ B^{32*(d_u*k+d_v)}.
684679
/// - shared secret K ∈ B^{32}.
685-
pub(crate) fn mlkem_decap_internal(
680+
pub(crate) fn mlkem_decap_internal_with_ek(
686681
&self,
687682
c: &[u8],
688683
c_prime: &mut [u8],
684+
ek: &EncapKey<K, ENCODED_SIZE_EK, Pke>,
689685
) -> Result<[u8; 32], UnknownCryptoError> {
686+
debug_assert_eq!(self.get_encapsulation_key_bytes(), ek.as_ref());
690687
debug_assert_eq!(c.len(), Pke::CIPHERTEXT_SIZE);
691688

692689
// Step 1:
@@ -712,9 +709,8 @@ impl<
712709
xof.squeeze(k_bar.as_mut())?;
713710

714711
// Step 8:
715-
debug_assert_eq!(self.get_encapsulation_key_bytes(), self.ek.as_ref());
716712
debug_assert_eq!(self.get_encapsulation_key_bytes(), ek_pke);
717-
self.ek.encrypt(&m, r.as_ref(), c_prime)?;
713+
ek.encrypt(&m, r.as_ref(), c_prime)?;
718714

719715
// Step 9:
720716
let ct_choice = c.ct_ne(c_prime);
@@ -727,6 +723,19 @@ impl<
727723
Ok(k)
728724
}
729725

726+
#[cfg(feature = "safe_api")] // used in from_keys which requires safe_api
727+
pub(crate) fn mlkem_decap_internal(
728+
&self,
729+
c: &[u8],
730+
c_prime: &mut [u8],
731+
) -> Result<[u8; 32], UnknownCryptoError> {
732+
// In this case we aren't provided a cached encapsulation key.
733+
let ek =
734+
EncapKey::<K, ENCODED_SIZE_EK, Pke>::from_slice(self.get_encapsulation_key_bytes())?;
735+
736+
self.mlkem_decap_internal_with_ek(c, c_prime, &ek)
737+
}
738+
730739
pub(crate) fn unprotected_as_bytes(&self) -> &[u8] {
731740
&self.bytes
732741
}
@@ -745,6 +754,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
745754
/// k \in [2, 3, 4]
746755
fn keygen<const K: usize, const ENCODED_SIZE_EK: usize, const ENCODED_SIZE_DK: usize>(
747756
d: &[u8],
757+
ek: &mut EncapKey<K, ENCODED_SIZE_EK, Pke>,
748758
dk: &mut DecapKey<K, ENCODED_SIZE_EK, ENCODED_SIZE_DK, Pke>,
749759
) -> Result<(), UnknownCryptoError> {
750760
let (rho, sigma) = g(&[d, &[Pke::K as u8]]);
@@ -753,7 +763,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
753763
// Steps 3..7
754764
for i in 0..Pke::K {
755765
for j in 0..Pke::K {
756-
dk.ek.mat_a[i][j] = sample_ntt(&rho, &[j as u8, i as u8])?;
766+
ek.mat_a[i][j] = sample_ntt(&rho, &[j as u8, i as u8])?;
757767
}
758768
}
759769

@@ -773,34 +783,33 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
773783

774784
for i in 0..Pke::K {
775785
dk.s_hat[i] = to_ntt(&s[i]);
776-
dk.ek.t_hat[i] = to_ntt(&e[i]);
786+
ek.t_hat[i] = to_ntt(&e[i]);
777787
}
778788

779789
s.zeroize();
780790

781791
// t ← A ∘ ŝ + ê
782792
for i in 0..Pke::K {
783793
for j in 0..Pke::K {
784-
dk.ek.t_hat[i] += dk.ek.mat_a[i][j] * dk.s_hat[j];
794+
ek.t_hat[i] += ek.mat_a[i][j] * dk.s_hat[j];
785795
}
786796
}
787797

788798
// Step 19
789-
for (re, ek_part) in dk
790-
.ek
799+
for (re, ek_part) in ek
791800
.t_hat
792801
.iter()
793-
.zip(dk.ek.bytes.chunks_exact_mut(ENCODE_SIZE_POLY))
802+
.zip(ek.bytes.chunks_exact_mut(ENCODE_SIZE_POLY))
794803
{
795804
ByteSerialization::encode_12(&re.coefficients, ek_part);
796805
}
797806

798807
let idx = ENCODED_SIZE_EK - rho.len();
799-
dk.ek.bytes[idx..].copy_from_slice(&rho);
808+
ek.bytes[idx..].copy_from_slice(&rho);
800809

801810
// Cache hash of ek so we don't need to re-compute for every encap().
802-
let h_ek = Sha3_256::digest(&dk.ek.bytes)?;
803-
dk.ek.h_ek = h_ek.value;
811+
let h_ek = Sha3_256::digest(&ek.bytes)?;
812+
ek.h_ek = h_ek.value;
804813

805814
// Step 20
806815
for (re, dk_part) in dk
@@ -827,7 +836,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
827836
),
828837
UnknownCryptoError,
829838
> {
830-
let ek = EncapKey::<K, ENCODED_SIZE_EK, Pke> {
839+
let mut encap_key = EncapKey::<K, ENCODED_SIZE_EK, Pke> {
831840
bytes: [0u8; ENCODED_SIZE_EK],
832841
h_ek: [0u8; SHA3_256_OUTSIZE],
833842
t_hat: [RingElementNTT::zero(); K],
@@ -839,28 +848,27 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
839848
let mut decap_key = DecapKey::<K, ENCODED_SIZE_EK, ENCODED_SIZE_DK, Pke> {
840849
bytes: [0u8; ENCODED_SIZE_DK],
841850
s_hat: [RingElementNTT::zero(); K],
842-
ek,
851+
_phantom: PhantomData,
843852
};
844853

845854
// Step 1 + 2. (ekPKE, dkPKE) ← K-PKE.KeyGen(d)
846-
Self::keygen(&seed.unprotected_as_bytes()[..32], &mut decap_key)?;
855+
Self::keygen(
856+
&seed.unprotected_as_bytes()[..32],
857+
&mut encap_key,
858+
&mut decap_key,
859+
)?;
847860

848861
// Step 3. dk ← (dkPKE‖ek‖H(ek)‖z)
849-
let ek_bytes = decap_key.ek.as_ref();
850862
decap_key.bytes[(ENCODE_SIZE_POLY * K)..(ENCODE_SIZE_POLY * K) + Pke::EK_SIZE]
851-
.copy_from_slice(ek_bytes);
863+
.copy_from_slice(&encap_key.bytes);
852864
decap_key.bytes
853865
[(ENCODE_SIZE_POLY * K) + Pke::EK_SIZE..(ENCODE_SIZE_POLY * K) + Pke::EK_SIZE + 32]
854-
.copy_from_slice(Sha3_256::digest(ek_bytes).unwrap().as_ref());
866+
.copy_from_slice(Sha3_256::digest(&encap_key.bytes).unwrap().as_ref());
855867
decap_key.bytes[(ENCODE_SIZE_POLY * K) + Pke::EK_SIZE + 32
856868
..(ENCODE_SIZE_POLY * K) + Pke::EK_SIZE + 32 + 32]
857869
.copy_from_slice(&seed.unprotected_as_bytes()[32..64]);
858870

859-
let encap_key = decap_key.ek.clone(); // TODO: Can we maybe get rid of this clone? Probably some internal API changes propagating upwards
860-
debug_assert_eq!(
861-
decap_key.get_encapsulation_key_bytes(),
862-
decap_key.ek.as_ref()
863-
);
871+
debug_assert_eq!(decap_key.get_encapsulation_key_bytes(), encap_key.as_ref());
864872

865873
Ok((encap_key, decap_key))
866874
}
@@ -927,20 +935,6 @@ mod tests {
927935
use crate::hazardous::kem::ml_kem::mlkem512::KeyPair as MlKem512KeyPair;
928936
use crate::hazardous::kem::ml_kem::mlkem768::KeyPair as MlKem768KeyPair;
929937

930-
#[test]
931-
fn test_keypair_dk_ek_match_internal() {
932-
let seed = Seed::from_slice(&[128u8; 64]).unwrap();
933-
934-
let kp = MlKem512KeyPair::try_from(&seed).unwrap();
935-
assert_eq!(kp.public().value, kp.private().value.ek);
936-
937-
let kp = MlKem768KeyPair::try_from(&seed).unwrap();
938-
assert_eq!(kp.public().value, kp.private().value.ek);
939-
940-
let kp = MlKem1024KeyPair::try_from(&seed).unwrap();
941-
assert_eq!(kp.public().value, kp.private().value.ek);
942-
}
943-
944938
#[test]
945939
#[cfg(feature = "safe_api")]
946940
fn test_seed_and_dk_mismatch() {

src/hazardous/kem/ml_kem/mlkem1024.rs

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ impl_from_trait!(Ciphertext, MlKem1024Internal::CIPHERTEXT_SIZE);
110110
/// A keypair of ML-KEM-1024 keys, that are derived from a given seed.
111111
pub struct KeyPair {
112112
seed: Seed,
113-
ek: EncapsulationKey,
114113
dk: DecapsulationKey,
115114
}
116115

@@ -124,8 +123,10 @@ impl KeyPair {
124123

125124
Ok(Self {
126125
seed,
127-
ek: EncapsulationKey { value: ek },
128-
dk: DecapsulationKey { value: dk },
126+
dk: DecapsulationKey {
127+
value: dk,
128+
cached_ek: EncapsulationKey { value: ek },
129+
},
129130
})
130131
}
131132

@@ -147,8 +148,10 @@ impl KeyPair {
147148

148149
Ok(Self {
149150
seed: Seed::from_slice(seed.unprotected_as_bytes()).unwrap(),
150-
ek: EncapsulationKey { value: ek },
151-
dk: DecapsulationKey { value: dk },
151+
dk: DecapsulationKey {
152+
value: dk,
153+
cached_ek: EncapsulationKey { value: ek },
154+
},
152155
})
153156
}
154157

@@ -160,7 +163,7 @@ impl KeyPair {
160163

161164
/// Get the public [EncapsulationKey] corresponding to this keypair.
162165
pub fn public(&self) -> &EncapsulationKey {
163-
&self.ek
166+
&self.dk.cached_ek
164167
}
165168

166169
/// Get the private [DecapsulationKey] used to generate this keypair. In order to store the private
@@ -178,8 +181,10 @@ impl TryFrom<&Seed> for KeyPair {
178181

179182
Ok(Self {
180183
seed: Seed::from_slice(value.unprotected_as_bytes()).unwrap(),
181-
ek: EncapsulationKey { value: ek },
182-
dk: DecapsulationKey { value: dk },
184+
dk: DecapsulationKey {
185+
value: dk,
186+
cached_ek: EncapsulationKey { value: ek },
187+
},
183188
})
184189
}
185190
}
@@ -188,6 +193,10 @@ impl TryFrom<&Seed> for KeyPair {
188193
/// A type to represent the `DecapsulationKey` that ML-KEM-1024 produces.
189194
pub struct DecapsulationKey {
190195
pub(crate) value: DecapKey<4, 1568, 3168, MlKem1024Internal>,
196+
// NOTE(brycx): This is simply a cache of the encapsulation key, so we avoid recomputing it
197+
// on decap() operations. This is not a part of PartialEq, AsRef<> implementations or other logic
198+
// pertaining to the `DecapsulationKey`, serving a purely internal purpose.
199+
pub(crate) cached_ek: EncapsulationKey,
191200
}
192201

193202
impl PartialEq<&[u8]> for DecapsulationKey {
@@ -200,17 +209,25 @@ impl PartialEq<&[u8]> for DecapsulationKey {
200209
impl DecapsulationKey {
201210
/// Instantiate a [DecapsulationKey] with only key-checks from FIPS-203, section 7.3. Not MAL-BIND-K-CT secure.
202211
pub fn unchecked_from_slice(slice: &[u8]) -> Result<Self, UnknownCryptoError> {
212+
let dk_unchecked =
213+
DecapKey::<4, 1568, 3168, MlKem1024Internal>::unchecked_from_slice(slice)?;
214+
let ek_unchecked =
215+
EncapsulationKey::from_slice(dk_unchecked.get_encapsulation_key_bytes())?;
216+
203217
Ok(Self {
204-
value: DecapKey::<4, 1568, 3168, MlKem1024Internal>::unchecked_from_slice(slice)?,
218+
value: dk_unchecked,
219+
cached_ek: ek_unchecked,
205220
})
206221
}
207222

208223
/// Perform decapsulation of a [Ciphertext].
209224
pub fn decap(&self, c: &Ciphertext) -> Result<SharedSecret, UnknownCryptoError> {
210225
let mut c_prime_buf = [0u8; MlKem1024Internal::CIPHERTEXT_SIZE];
211-
let mut k_internal = self
212-
.value
213-
.mlkem_decap_internal(c.as_ref(), &mut c_prime_buf)?;
226+
let mut k_internal = self.value.mlkem_decap_internal_with_ek(
227+
c.as_ref(),
228+
&mut c_prime_buf,
229+
&self.cached_ek.value,
230+
)?;
214231
let k = SharedSecret::from_slice(&k_internal)?;
215232
k_internal.zeroize();
216233

@@ -329,7 +346,7 @@ mod tests {
329346
let kp = KeyPair::try_from(&Seed::from_slice(seed).unwrap()).unwrap();
330347

331348
Ok((
332-
kp.ek.as_ref().to_vec(),
349+
kp.dk.cached_ek.as_ref().to_vec(),
333350
kp.dk.value.unprotected_as_bytes().to_vec(),
334351
))
335352
}
@@ -349,11 +366,41 @@ mod tests {
349366
}
350367
}
351368

369+
#[test]
370+
fn test_keypair_dk_ek_match_internal() {
371+
let seed = Seed::from_slice(&[128u8; 64]).unwrap();
372+
let kp = KeyPair::try_from(&seed).unwrap();
373+
assert_eq!(kp.public(), &kp.private().cached_ek);
374+
}
375+
376+
#[test]
377+
#[cfg(feature = "safe_api")]
378+
fn test_dk_cached_ek() {
379+
let seed = Seed::from_slice(&[128u8; 64]).unwrap();
380+
let kp = KeyPair::try_from(&seed).unwrap();
381+
let (ss_pubapi, ct_pubapi) = kp.public().encap_deterministic(&[125u8; 32]).unwrap();
382+
let mut c_prime = [0u8; MlKem1024Internal::CIPHERTEXT_SIZE];
383+
// This call re-computes encap key internally from the bytes a decapkey would store.
384+
let ss_privapi = kp
385+
.private()
386+
.value
387+
.mlkem_decap_internal(ct_pubapi.as_ref(), &mut c_prime)
388+
.unwrap();
389+
assert_eq!(ss_privapi.as_ref(), ss_pubapi.unprotected_as_bytes());
390+
assert_eq!(
391+
MlKem1024::decap(kp.private(), &ct_pubapi).unwrap(),
392+
ss_pubapi
393+
);
394+
}
395+
352396
#[cfg(feature = "safe_api")]
353397
#[test]
354398
fn test_dk_to_ek_conversions() {
355399
let kp = KeyPair::generate().unwrap();
356-
assert_eq!(kp.ek, EncapsulationKey::try_from(kp.private()).unwrap());
400+
assert_eq!(
401+
kp.dk.cached_ek,
402+
EncapsulationKey::try_from(kp.private()).unwrap()
403+
);
357404
}
358405

359406
#[cfg(feature = "safe_api")]

0 commit comments

Comments
 (0)