Skip to content

Commit 9aeb92a

Browse files
authored
mlkem: Cache hash of encapsulation key (#464)
* mlkem: Cache hash of encapsulation key to avoid re-computing for multiple encap() calls * Update CHANGELOG * mlkem: Test KeyPair internal ek caches * nit
1 parent 9412af3 commit 9aeb92a

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
**Changelog:**
66
- Add `encap_deterministic()` and `auth_encap_deterministic()` to `DhKem` in `hazardous::kem::x25519_hkdf_sha256::DhKem` [#458](https://github.com/orion-rs/orion/pull/458).
7-
- Make `hazardous::kem::x25519_hkdf_sha256::DhKem` available in `[no_std]` context [#458](https://github.com/orion-rs/orion/pull/458).
7+
- Make `hazardous::kem::x25519_hkdf_sha256::DhKem` available in `#![no_std]` context [#458](https://github.com/orion-rs/orion/pull/458).
88
- Add support for HPKE (RFC 9180) [#458](https://github.com/orion-rs/orion/pull/458).
99
- Switch to source-based code coverage [#462](https://github.com/orion-rs/orion/pull/462).
10+
- ML-KEM (internal): Cache hash of encapsulation key to save computation on multiple `encap()` operations [#464](https://github.com/orion-rs/orion/pull/464).
11+
- ML-KEM (internal): Cache encapsulation key within decapsulation key, to avoid re-computation after generation of decapsulation key [#447](https://github.com/orion-rs/orion/pull/447).
1012

1113
### 0.17.9
1214

src/hazardous/kem/ml_kem/internal/mod.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ pub(crate) mod serialization;
3232
/// Sampling ring elements from seeds.
3333
pub(crate) mod sampling;
3434

35-
use crate::errors::UnknownCryptoError;
3635
use crate::hazardous::hash::sha3::sha3_256::Sha3_256;
3736
use crate::hazardous::hash::sha3::sha3_512::Sha3_512;
3837
use crate::hazardous::hash::sha3::shake256;
3938
use crate::hazardous::kem::ml_kem::Seed;
39+
use crate::{errors::UnknownCryptoError, hazardous::hash::sha3::sha3_256::SHA3_256_OUTSIZE};
4040
use core::marker::PhantomData;
4141
use fe::*;
4242
use re::*;
@@ -323,6 +323,7 @@ impl PkeParameters for MlKem1024Internal {
323323
/// ML-KEM encapsulation key.
324324
pub(crate) struct EncapKey<const K: usize, const ENCODED_SIZE: usize, Pke: PkeParameters> {
325325
pub(crate) bytes: [u8; ENCODED_SIZE],
326+
h_ek: [u8; SHA3_256_OUTSIZE],
326327
t_hat: [RingElementNTT; K],
327328
mat_a: [[RingElementNTT; K]; K],
328329
_phantom: PhantomData<Pke>,
@@ -371,8 +372,12 @@ impl<const K: usize, const ENCODED_SIZE: usize, Pke: PkeParameters> EncapKey<K,
371372
}
372373
}
373374

375+
// Cache hash of the bytes so we don't need to re-compute for every encap().
376+
let h_ek = Sha3_256::digest(slice)?;
377+
374378
Ok(Self {
375379
bytes: slice.try_into().unwrap(), // NOTE: Should never panic if encapsulation_key_check() succeeds.
380+
h_ek: h_ek.value,
376381
t_hat,
377382
mat_a,
378383
_phantom: PhantomData,
@@ -500,7 +505,7 @@ impl<const K: usize, const ENCODED_SIZE: usize, Pke: PkeParameters> EncapKey<K,
500505
}
501506

502507
// Step 1: (K, r) ← G(m‖H(ek))
503-
let (k, r) = g(&[m, Sha3_256::digest(&self.bytes).unwrap().as_ref()]);
508+
let (k, r) = g(&[m, self.h_ek.as_ref()]);
504509

505510
// Step 2: c ← K-PKE.Encrypt(ek, m, r)
506511
self.encrypt(m, r.as_ref(), c)?;
@@ -793,6 +798,10 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
793798
let idx = ENCODED_SIZE_EK - rho.len();
794799
dk.ek.bytes[idx..].copy_from_slice(&rho);
795800

801+
// Cache hash of ek so we don't need to re-compute for every encap().
802+
let h_ek = Sha3_256::digest(&dk.ek.bytes)?;
803+
dk.ek.h_ek = h_ek.value;
804+
796805
// Step 20
797806
for (re, dk_part) in dk
798807
.s_hat
@@ -820,6 +829,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
820829
> {
821830
let ek = EncapKey::<K, ENCODED_SIZE_EK, Pke> {
822831
bytes: [0u8; ENCODED_SIZE_EK],
832+
h_ek: [0u8; SHA3_256_OUTSIZE],
823833
t_hat: [RingElementNTT::zero(); K],
824834
mat_a: [[RingElementNTT::zero(); K]; K],
825835
_phantom: PhantomData,
@@ -917,6 +927,20 @@ mod tests {
917927
use crate::hazardous::kem::ml_kem::mlkem512::KeyPair as MlKem512KeyPair;
918928
use crate::hazardous::kem::ml_kem::mlkem768::KeyPair as MlKem768KeyPair;
919929

930+
#[test]
931+
fn test_keypair_dk_ek_match_internal() {
932+
let seed = Seed::from_slice(&[128u8; 64]).unwrap();
933+
934+
let kp = MlKem512KeyPair::try_from(&seed).unwrap();
935+
assert_eq!(kp.public().value, kp.private().value.ek);
936+
937+
let kp = MlKem768KeyPair::try_from(&seed).unwrap();
938+
assert_eq!(kp.public().value, kp.private().value.ek);
939+
940+
let kp = MlKem1024KeyPair::try_from(&seed).unwrap();
941+
assert_eq!(kp.public().value, kp.private().value.ek);
942+
}
943+
920944
#[test]
921945
#[cfg(feature = "safe_api")]
922946
fn test_seed_and_dk_mismatch() {

0 commit comments

Comments
 (0)