Skip to content

Commit 6424004

Browse files
committed
readability and Cargo feature improvements
1 parent eae94ad commit 6424004

File tree

4 files changed

+109
-70
lines changed

4 files changed

+109
-70
lines changed

cryprot-ot/Cargo.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ authors.workspace = true
1010
repository.workspace = true
1111

1212
[features]
13-
# Use ML-KEM-based OT for base OT.
14-
ml-kem-base-ot = []
13+
# ML-KEM-based base OT. Pickm only one variant.
14+
ml-kem-base-ot-512 = ["_ml-kem-base-ot"]
15+
ml-kem-base-ot-768 = ["_ml-kem-base-ot"]
16+
ml-kem-base-ot-1024 = ["_ml-kem-base-ot"]
17+
# Internal feature — do not enable directly.
18+
_ml-kem-base-ot = []
1519

1620
[lints]
1721
workspace = true

cryprot-ot/docs/mlkem-ot-protocol.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ In libOTe, this corresponds to `randomPK`, where it instead generates `A_hat` an
6969

7070
Output: `(t_hat, rho)`. The `rho` is passed through unchanged.
7171

72-
**`H(ek) -> (h, ek.rho)`**
72+
**`HashToKey(ek) -> (h, ek.rho)`**
7373

74-
Hash-to-key (corresponds to libOTe's `pkHash`). Maps an encapsulation key to another
74+
HashToKey corresponds to libOTe's `pkHash`. Maps an encapsulation key to another
7575
encapsulation key. Takes an element of `T_q^k`, hashes it to a 32-byte seed, and uses that seed to sample a new element of `T_q^k`.
7676

7777
Given an encapsulation key `ek = (t_hat, rho)`:

cryprot-ot/src/lib.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//!
1919
//! ## ML-KEM Base OT
2020
//!
21-
//! Enable the `ml-kem-base-ot` feature to use ML-KEM-based OT for the base OT
21+
//! Enable one of the `ml-kem-base-ot-{512,768,1024}` features to use ML-KEM-based OT for the base OT
2222
//! protocol, providing post-quantum security:
2323
//!
2424
//! This replaces the classical "Simplest OT" with an ML-KEM-based construction
@@ -68,6 +68,7 @@ use subtle::Choice;
6868

6969
pub mod adapter;
7070
pub mod extension;
71+
#[cfg(feature = "_ml-kem-base-ot")]
7172
pub mod mlkem_ot;
7273
pub mod noisy_vole;
7374
pub mod phase;
@@ -77,22 +78,22 @@ pub mod simplest_ot;
7778
/// Base OT implementation used by extension protocols.
7879
///
7980
/// When the `ml-kem-base-ot` feature is enabled, use [`mlkem_ot::MlKemOt`]
80-
#[cfg(feature = "ml-kem-base-ot")]
81+
#[cfg(feature = "_ml-kem-base-ot")]
8182
pub type BaseOt = mlkem_ot::MlKemOt;
8283

8384
/// Base OT implementation used by extension protocols.
8485
///
8586
/// When the `ml-kem-base-ot` feature is not enabled, use
8687
/// [`simplest_ot::SimplestOt`].
87-
#[cfg(not(feature = "ml-kem-base-ot"))]
88+
#[cfg(not(feature = "_ml-kem-base-ot"))]
8889
pub type BaseOt = simplest_ot::SimplestOt;
8990

9091
/// Error type for base OT operations.
91-
#[cfg(feature = "ml-kem-base-ot")]
92+
#[cfg(feature = "_ml-kem-base-ot")]
9293
pub type BaseOtError = mlkem_ot::Error;
9394

9495
/// Error type for base OT operations.
95-
#[cfg(not(feature = "ml-kem-base-ot"))]
96+
#[cfg(not(feature = "_ml-kem-base-ot"))]
9697
pub type BaseOtError = simplest_ot::Error;
9798

9899
/// Trait for OT receivers/senders which hold a [`Connection`].

cryprot-ot/src/mlkem_ot.rs

Lines changed: 95 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@ use cryprot_core::{Block, buf::Buf, rand_compat::RngCompat, random_oracle::Rando
1010
use cryprot_net::{Connection, ConnectionError};
1111
use futures::{SinkExt, StreamExt};
1212
use hybrid_array::typenum::Unsigned;
13-
// ML-KEM variant: change to MlKem512/MlKem512Params or MlKem768/MlKem768Params
14-
// for different security levels.
1513
use ml_kem::{
16-
Ciphertext as MlKemCiphertext, EncodedSizeUser, KemCore, MlKem1024 as MlKem,
17-
MlKem1024Params as MlKemParams, ParameterSet, SharedKey,
14+
Ciphertext as MlKemCiphertext, EncodedSizeUser, KemCore, ParameterSet, SharedKey,
1815
kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
1916
};
17+
// ML-KEM parameter set selection.
18+
#[cfg(feature = "ml-kem-base-ot-512")]
19+
use ml_kem::{MlKem512 as MlKem, MlKem512Params as MlKemParams};
20+
#[cfg(feature = "ml-kem-base-ot-768")]
21+
use ml_kem::{MlKem768 as MlKem, MlKem768Params as MlKemParams};
22+
#[cfg(feature = "ml-kem-base-ot-1024")]
23+
use ml_kem::{MlKem1024 as MlKem, MlKem1024Params as MlKemParams};
2024
use module_lattice::{Encode, Field, NttPolynomial};
2125
use rand::{RngExt, rngs::StdRng};
2226
use serde::{Deserialize, Serialize};
2327
use sha3::{
24-
Shake128,
28+
Digest, Shake128,
2529
digest::{ExtendableOutput, Update, XofReader},
2630
};
2731
use subtle::{Choice, ConditionallySelectable};
@@ -49,43 +53,69 @@ const NUM_COEFFICIENTS: usize = 256;
4953

5054
type Seed = [u8; 32];
5155

52-
/// rho is the seed used to derive the public matrix A_hat (FIPS 203).
53-
type Rho = Seed;
54-
5556
// Serialized t_hat is the encapsulation key minus the rho suffix.
56-
const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
57+
const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Seed>();
5758

5859
// ---------------------------------------------------------------------------
5960
// Protocol helper functions (see docs/mlkem-ot-protocol.md)
6061
// ---------------------------------------------------------------------------
6162

62-
/// Parse serialized encapsulation key bytes into (NttVector, rho).
63-
/// The input is a fixed-size array so the slicing is infallible.
64-
fn parse_ek(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> (NttVector, Rho) {
65-
let enc = bytes[..T_HAT_BYTES_LEN]
66-
.try_into()
67-
.expect("t_hat length mismatch");
68-
let t_hat = <NttVector as Encode<U12>>::decode(enc);
69-
let rho = bytes[T_HAT_BYTES_LEN..]
70-
.try_into()
71-
.expect("rho length mismatch");
72-
(t_hat, rho)
63+
/// Parsed encapsulation key: ek = (t_hat, rho).
64+
struct EncapsulationKey {
65+
t_hat: NttVector,
66+
rho: Seed,
67+
}
68+
69+
impl EncapsulationKey {
70+
/// Parse from serialized bytes.
71+
fn from_bytes(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> Self {
72+
let enc = bytes[..T_HAT_BYTES_LEN]
73+
.try_into()
74+
.expect("t_hat length mismatch");
75+
let t_hat = <NttVector as Encode<U12>>::decode(enc);
76+
let rho = bytes[T_HAT_BYTES_LEN..]
77+
.try_into()
78+
.expect("rho length mismatch");
79+
Self { t_hat, rho }
80+
}
81+
82+
/// Serialize to bytes.
83+
fn to_bytes(&self) -> [u8; ENCAPSULATION_KEY_LEN] {
84+
let encoded = <NttVector as Encode<U12>>::encode(&self.t_hat);
85+
let mut out = [0u8; ENCAPSULATION_KEY_LEN];
86+
out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
87+
out[T_HAT_BYTES_LEN..].copy_from_slice(&self.rho);
88+
out
89+
}
90+
}
91+
92+
impl std::ops::Sub<&NttVector> for &EncapsulationKey {
93+
type Output = EncapsulationKey;
94+
95+
fn sub(self, rhs: &NttVector) -> EncapsulationKey {
96+
EncapsulationKey {
97+
t_hat: &self.t_hat - rhs,
98+
rho: self.rho,
99+
}
100+
}
73101
}
74102

75-
/// Serialize NttVector + rho back into encapsulation key bytes.
76-
fn serialize_ek(t_hat: &NttVector, rho: &Rho) -> [u8; ENCAPSULATION_KEY_LEN] {
77-
let encoded = <NttVector as Encode<U12>>::encode(t_hat);
78-
let mut out = [0u8; ENCAPSULATION_KEY_LEN];
79-
out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
80-
out[T_HAT_BYTES_LEN..].copy_from_slice(rho);
81-
out
103+
impl std::ops::Add<&NttVector> for &EncapsulationKey {
104+
type Output = EncapsulationKey;
105+
106+
fn add(self, rhs: &NttVector) -> EncapsulationKey {
107+
EncapsulationKey {
108+
t_hat: &self.t_hat + rhs,
109+
rho: self.rho,
110+
}
111+
}
82112
}
83113

84114
/// XOF(rho, j, i) from FIPS 203, Algorithm 2 SHAKE128example.
85115
/// In Algorithm 13 (K-PKE.KeyGen), this is called as XOF(rho, j, i) where
86116
/// j is the column index (byte 32) and i is the row index (byte 33),
87117
/// using 0-based indexing.
88-
fn xof(seed: &Rho, j: u8, i: u8) -> impl XofReader {
118+
fn xof(seed: &Seed, j: u8, i: u8) -> impl XofReader {
89119
let mut h = Shake128::default();
90120
h.update(seed);
91121
h.update(&[i, j]);
@@ -103,11 +133,12 @@ fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
103133
const BUF_LEN: usize = 32 * 3;
104134
let mut poly = NttPolynomial::<MlKemField>::default();
105135
let mut buf = [0u8; BUF_LEN];
106-
let mut pos = BUF_LEN; // start at end to trigger first read
136+
xof.read(&mut buf);
137+
let mut pos = 0;
107138
let mut i = 0;
108139

109140
while i < NUM_COEFFICIENTS {
110-
// Read BUF_LEN chunks from the XOF, consume and then refill once exhausted.
141+
// Refill the buffer from the XOF stream when exhausted.
111142
if pos >= BUF_LEN {
112143
xof.read(&mut buf);
113144
pos = 0;
@@ -144,19 +175,22 @@ fn sample_ntt_vector(seed: &Seed) -> NttVector {
144175
)
145176
}
146177

147-
/// H(ek): hash-to-key. Maps an NttVector to another NttVector via SHA3-256.
178+
/// Maps an encapsulation key to an NttVector via SHA3-256.
179+
/// Only the t_hat component is hashed; rho is ignored.
148180
/// Corresponds to libOTe's `pkHash`.
149-
fn hash_to_key(t_hat: &NttVector) -> NttVector {
150-
use sha3::Digest;
151-
let encoded = <NttVector as Encode<U12>>::encode(t_hat);
152-
let seed: Rho = sha3::Sha3_256::digest(encoded.as_slice()).into();
181+
fn hash_to_key(ek: &EncapsulationKey) -> NttVector {
182+
let encoded = <NttVector as Encode<U12>>::encode(&ek.t_hat);
183+
let seed: Seed = sha3::Sha3_256::digest(encoded.as_slice()).into();
153184
sample_ntt_vector(&seed)
154185
}
155186

156-
/// RandomEK: generate a random NttVector from a random seed.
157-
fn random_ek(rng: &mut StdRng) -> NttVector {
158-
let seed: Rho = rng.random();
159-
sample_ntt_vector(&seed)
187+
/// RandomEK: generate a random encapsulation key from a random seed.
188+
fn random_ek(rng: &mut StdRng, rho: Seed) -> EncapsulationKey {
189+
let seed: Seed = rng.random();
190+
EncapsulationKey {
191+
t_hat: sample_ntt_vector(&seed),
192+
rho,
193+
}
160194
}
161195

162196
// ---------------------------------------------------------------------------
@@ -184,9 +218,9 @@ pub enum Error {
184218
}
185219

186220
#[derive(Copy, Clone, Serialize, Deserialize)]
187-
struct EncapKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
221+
struct EncapsulationKeyBytes(#[serde(with = "serde_bytes")] [u8; ENCAPSULATION_KEY_LEN]);
188222

189-
impl ConditionallySelectable for EncapKeyBytes {
223+
impl ConditionallySelectable for EncapsulationKeyBytes {
190224
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
191225
Self(<[u8; ENCAPSULATION_KEY_LEN]>::conditional_select(
192226
&a.0, &b.0, choice,
@@ -208,8 +242,8 @@ impl ConditionallySelectable for CtBytes {
208242
// Message from receiver to sender: two values (r_0, r_1) per OT.
209243
#[derive(Serialize, Deserialize)]
210244
struct EncapsulationKeysMessage {
211-
eks0: Vec<EncapKeyBytes>,
212-
eks1: Vec<EncapKeyBytes>,
245+
eks0: Vec<EncapsulationKeyBytes>,
246+
eks1: Vec<EncapsulationKeyBytes>,
213247
}
214248

215249
// Message from sender to receiver: two ciphertexts per OT.
@@ -277,16 +311,16 @@ impl RotSender for MlKemOt {
277311
.enumerate()
278312
{
279313
// Reconstruct encapsulation keys: ek_j = r_j + H(r_{1-j})
280-
let (r0, rho) = parse_ek(&r0_bytes.0);
281-
let (r1, _) = parse_ek(&r1_bytes.0);
314+
let r0 = EncapsulationKey::from_bytes(&r0_bytes.0);
315+
let r1 = EncapsulationKey::from_bytes(&r1_bytes.0);
282316

283-
let ek0_bytes = serialize_ek(&(&r0 + &hash_to_key(&r1)), &rho);
284-
let ek1_bytes = serialize_ek(&(&r1 + &hash_to_key(&r0)), &rho);
317+
let ek0 = &r0 + &hash_to_key(&r1);
318+
let ek1 = &r1 + &hash_to_key(&r0);
285319

286-
let (ct0, key0) = encapsulate(&EncapKeyBytes(ek0_bytes), &mut self.rng);
320+
let (ct0, key0) = encapsulate(&EncapsulationKeyBytes(ek0.to_bytes()), &mut self.rng);
287321
let key0 = hash(&key0, i);
288322

289-
let (ct1, key1) = encapsulate(&EncapKeyBytes(ek1_bytes), &mut self.rng);
323+
let (ct1, key1) = encapsulate(&EncapsulationKeyBytes(ek1.to_bytes()), &mut self.rng);
290324
let key1 = hash(&key1, i);
291325

292326
cts0.push(ct0);
@@ -331,23 +365,23 @@ impl RotReceiver for MlKemOt {
331365
.as_slice()
332366
.try_into()
333367
.expect("incorrect encapsulation key size");
334-
let (real_t_hat, rho) = parse_ek(&ek_bytes);
368+
let ek = EncapsulationKey::from_bytes(&ek_bytes);
335369

336370
// Step 2: Sample random key for position 1-b.
337-
let rand_t_hat = random_ek(&mut self.rng);
371+
let r_1_b = random_ek(&mut self.rng, ek.rho);
338372

339-
// Step 3: Compute correlated key for position b: r_b = ek - H(r_{1-b}).
340-
let correlated_t_hat = &real_t_hat - &hash_to_key(&rand_t_hat);
373+
// Step 3: Compute correlated key: r_b = ek - H(r_{1-b}).
374+
let r_b = &ek - &hash_to_key(&r_1_b);
341375

342-
// Serialize both keys with the same rho.
343-
let correlated_bytes = EncapKeyBytes(serialize_ek(&correlated_t_hat, &rho));
344-
let random_bytes = EncapKeyBytes(serialize_ek(&rand_t_hat, &rho));
376+
// Serialize both keys.
377+
let r_b_bytes = EncapsulationKeyBytes(r_b.to_bytes());
378+
let r_1_b_bytes = EncapsulationKeyBytes(r_1_b.to_bytes());
345379

346380
// Step 4: Select (r_0, r_1) based on choice bit (constant-time).
347-
// If b=0: r_0 = correlated (real side), r_1 = random.
348-
// If b=1: r_0 = random, r_1 = correlated (real side).
349-
let ek0 = EncapKeyBytes::conditional_select(&correlated_bytes, &random_bytes, *choice);
350-
let ek1 = EncapKeyBytes::conditional_select(&random_bytes, &correlated_bytes, *choice);
381+
// If b=0: r_0 = real, r_1 = random.
382+
// If b=1: r_0 = random, r_1 = real.
383+
let ek0 = EncapsulationKeyBytes::conditional_select(&r_b_bytes, &r_1_b_bytes, *choice);
384+
let ek1 = EncapsulationKeyBytes::conditional_select(&r_1_b_bytes, &r_b_bytes, *choice);
351385

352386
decap_keys.push(dk);
353387
eks0.push(ek0);
@@ -397,7 +431,7 @@ impl RotReceiver for MlKemOt {
397431
}
398432

399433
// Encapsulates to the given key, returning the ciphertext and the shared key.
400-
fn encapsulate(ek: &EncapKeyBytes, rng: &mut StdRng) -> (CtBytes, SharedKey<MlKem>) {
434+
fn encapsulate(ek: &EncapsulationKeyBytes, rng: &mut StdRng) -> (CtBytes, SharedKey<MlKem>) {
401435
let parsed_ek = MlKemEncapsulationKey::<MlKemParams>::from_bytes((&ek.0).into());
402436
let (ct, k): (MlKemCiphertext<MlKem>, SharedKey<MlKem>) = parsed_ek
403437
.encapsulate(&mut RngCompat(rng))

0 commit comments

Comments
 (0)