Skip to content

Commit e0c39a8

Browse files
committed
initial implementation using proper keys
1 parent 484eb25 commit e0c39a8

File tree

5 files changed

+197
-22
lines changed

5 files changed

+197
-22
lines changed

Cargo.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ futures = "0.3.32"
3232
hybrid-array = { version = "0.4.7", features = ["bytemuck"] }
3333
libc = "0.2.181"
3434
ml-kem = "0.2.2"
35+
module-lattice = "0.1.0"
36+
sha3 = "0.10.8"
3537
ndarray = "0.17.2"
3638
num-traits = "0.2.19"
3739
rand = "0.10.0"

cryprot-ot/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ cryprot-net.workspace = true
2929
cryprot-pprf.workspace = true
3030
curve25519-dalek = { workspace = true, features = ["rand_core", "serde"] }
3131
futures.workspace = true
32+
hybrid-array.workspace = true
3233
ml-kem.workspace = true
34+
module-lattice.workspace = true
3335
rand.workspace = true
36+
sha3.workspace = true
3437
serde_bytes.workspace = true
3538
serde = { workspace = true, features = ["derive"] }
3639
subtle.workspace = true

cryprot-ot/docs/mlkem-ot-protocol.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ ek_b = r_b + H(r_{1-b})
178178
= (ek - H(r_{1-b})) + H(r_{1-b})
179179
= ek
180180
```
181-
So `ek_b = ek`, the real public key. In step 10, the receiver calls `ML-KEM.Decaps(dk, ct_b)` and recovers the same shared secret `ss_b` that the sender computed via `ML-KEM.Encaps(ek_b)` in step 7.
181+
So `ek_b = ek`, the real public key. In step 10, the receiver calls `ML-KEM.Decaps(dk, ct_b)` and
182+
recovers the same shared secret `ss_b` that the sender computed via `ML-KEM.Encaps(ek_b)` in step 7.
182183
183184
**Security:**
184185

cryprot-ot/src/mlkem_ot.rs

Lines changed: 177 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,165 @@
11
//! Post-quantum base OT using ML-KEM.
2+
//!
3+
//! Implements the MR19 protocol (Masny-Rindal, ePrint 2019/706, Figure 8)
4+
//! instantiated with ML-KEM as per Section D.3.
5+
//! See `docs/mlkem-ot-protocol.md` for the full protocol description.
26
3-
use std::io;
7+
use std::{io, mem::size_of};
48

59
use cryprot_core::{Block, buf::Buf, rand_compat::RngCompat, random_oracle::RandomOracle};
610
use cryprot_net::{Connection, ConnectionError};
711
use futures::{SinkExt, StreamExt};
12+
use hybrid_array::typenum::Unsigned;
813
// ML-KEM variant: change to MlKem512/MlKem512Params or MlKem768/MlKem768Params
914
// for different security levels.
1015
use ml_kem::{
1116
Ciphertext as MlKemCiphertext, EncodedSizeUser, KemCore, MlKem1024 as MlKem,
12-
MlKem1024Params as MlKemParams, SharedKey,
13-
array::typenum::Unsigned,
17+
MlKem1024Params as MlKemParams, ParameterSet, SharedKey,
1418
kem::{Decapsulate, DecapsulationKey, Encapsulate, EncapsulationKey as MlKemEncapsulationKey},
1519
};
20+
use module_lattice::{Encode, Field, NttPolynomial};
1621
use rand::{RngExt, rngs::StdRng};
1722
use serde::{Deserialize, Serialize};
23+
use sha3::{
24+
Shake128,
25+
digest::{ExtendableOutput, Update, XofReader},
26+
};
1827
use subtle::{Choice, ConditionallySelectable};
1928
use tracing::Level;
2029

2130
use crate::{Connected, RotReceiver, RotSender, SemiHonest, phase};
2231

32+
// Define the ML-KEM base field (q = 3329).
33+
module_lattice::define_field!(MlKemField, u16, u32, u64, 3329);
34+
35+
// Module dimension derived from the chosen ML-KEM parameter set.
36+
type K = <MlKemParams as ParameterSet>::K;
37+
38+
type NttVector = module_lattice::NttVector<MlKemField, K>;
39+
40+
type U12 = hybrid_array::typenum::U12;
41+
2342
const ENCAPSULATION_KEY_LEN: usize =
2443
<MlKemEncapsulationKey<MlKemParams> as EncodedSizeUser>::EncodedSize::USIZE;
2544
const CIPHERTEXT_LEN: usize = <MlKem as KemCore>::CiphertextSize::USIZE;
2645
const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";
2746

47+
// Number of coefficients per polynomial (FIPS 203, Section 2: n = 256).
48+
const NUM_COEFFICIENTS: usize = 256;
49+
50+
/// rho is a 32-byte seed used to derive the public matrix A_hat (FIPS 203).
51+
type Rho = [u8; 32];
52+
53+
// Serialized t_hat is the encapsulation key minus the rho suffix.
54+
const T_HAT_BYTES_LEN: usize = ENCAPSULATION_KEY_LEN - size_of::<Rho>();
55+
56+
// ---------------------------------------------------------------------------
57+
// Protocol helper functions (see docs/mlkem-ot-protocol.md)
58+
// ---------------------------------------------------------------------------
59+
60+
/// Parse serialized encapsulation key bytes into (NttVector, rho).
61+
/// The input is a fixed-size array so the slicing is infallible.
62+
fn parse_ek(bytes: &[u8; ENCAPSULATION_KEY_LEN]) -> (NttVector, Rho) {
63+
let enc = bytes[..T_HAT_BYTES_LEN]
64+
.try_into()
65+
.expect("t_hat length mismatch");
66+
let t_hat = <NttVector as Encode<U12>>::decode(enc);
67+
let rho = bytes[T_HAT_BYTES_LEN..]
68+
.try_into()
69+
.expect("rho length mismatch");
70+
(t_hat, rho)
71+
}
72+
73+
/// Serialize NttVector + rho back into encapsulation key bytes.
74+
fn serialize_ek(t_hat: &NttVector, rho: &Rho) -> [u8; ENCAPSULATION_KEY_LEN] {
75+
let encoded = <NttVector as Encode<U12>>::encode(t_hat);
76+
let mut out = [0u8; ENCAPSULATION_KEY_LEN];
77+
out[..T_HAT_BYTES_LEN].copy_from_slice(encoded.as_slice());
78+
out[T_HAT_BYTES_LEN..].copy_from_slice(rho);
79+
out
80+
}
81+
82+
/// XOF(rho, j, i) from FIPS 203, Algorithm 2 SHAKE128example.
83+
/// In Algorithm 13 (K-PKE.KeyGen), this is called as XOF(rho, j, i) where
84+
/// j is the column index (byte 32) and i is the row index (byte 33),
85+
/// using 0-based indexing.
86+
fn xof(seed: &Rho, j: u8, i: u8) -> impl XofReader {
87+
let mut h = Shake128::default();
88+
h.update(seed);
89+
h.update(&[i, j]);
90+
h.finalize_xof()
91+
}
92+
93+
/// FIPS 203 Algorithm 7: SampleNTT.
94+
/// Rejection sampling from a byte stream to produce a pseudorandom NTT
95+
/// polynomial.
96+
///
97+
/// Adapted from the ml-kem crate's `sample_ntt`.
98+
fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
99+
const Q: u16 = MlKemField::Q;
100+
// Read 32 triples (3 bytes each) at a time from the XOF.
101+
const BUF_LEN: usize = 96;
102+
let mut poly = NttPolynomial::<MlKemField>::default();
103+
let mut buf = [0u8; BUF_LEN];
104+
let mut pos = BUF_LEN; // start at end to trigger first read
105+
let mut i = 0;
106+
107+
while i < NUM_COEFFICIENTS {
108+
if pos >= BUF_LEN {
109+
xof.read(&mut buf);
110+
pos = 0;
111+
}
112+
113+
let d1 = u16::from(buf[pos]) | ((u16::from(buf[pos + 1]) & 0x0F) << 8);
114+
let d2 = (u16::from(buf[pos + 1]) >> 4) | (u16::from(buf[pos + 2]) << 4);
115+
pos += 3;
116+
117+
if d1 < Q {
118+
poly.0[i] = module_lattice::Elem::new(d1);
119+
i += 1;
120+
}
121+
if i < NUM_COEFFICIENTS && d2 < Q {
122+
poly.0[i] = module_lattice::Elem::new(d2);
123+
i += 1;
124+
}
125+
}
126+
127+
poly
128+
}
129+
130+
/// SampleNTTVector: call SampleNTT k times with FIPS 203 domain separation.
131+
/// Produces a pseudorandom NttVector<k> from a 32-byte seed.
132+
/// Each polynomial j uses XOF(seed || j || 0).
133+
fn sample_ntt_vector(seed: &Rho) -> NttVector {
134+
NttVector::new(
135+
(0..K::USIZE)
136+
.map(|j| {
137+
let mut reader = xof(seed, j as u8, 0);
138+
sample_ntt_poly(&mut reader)
139+
})
140+
.collect(),
141+
)
142+
}
143+
144+
/// H(ek): hash-to-key. Maps an NttVector to another NttVector via SHA3-256.
145+
/// Corresponds to libOTe's `pkHash`.
146+
fn hash_to_key(t_hat: &NttVector) -> NttVector {
147+
use sha3::Digest;
148+
let encoded = <NttVector as Encode<U12>>::encode(t_hat);
149+
let seed: Rho = sha3::Sha3_256::digest(encoded.as_slice()).into();
150+
sample_ntt_vector(&seed)
151+
}
152+
153+
/// RandomEK: generate a random NttVector from a random seed.
154+
fn random_ek(rng: &mut StdRng) -> NttVector {
155+
let seed: Rho = rng.random();
156+
sample_ntt_vector(&seed)
157+
}
158+
159+
// ---------------------------------------------------------------------------
160+
// Wire types and protocol implementation
161+
// ---------------------------------------------------------------------------
162+
28163
#[derive(thiserror::Error, Debug)]
29164
pub enum Error {
30165
#[error("quic connection error")]
@@ -67,8 +202,7 @@ impl ConditionallySelectable for CtBytes {
67202
}
68203
}
69204

70-
// Message from receiver to sender: two encapsulation keys per OT.
71-
// For choice bit c, ek{c} is a real key, ek{1-c} is random bytes.
205+
// Message from receiver to sender: two values (r_0, r_1) per OT.
72206
#[derive(Serialize, Deserialize)]
73207
struct EncapsulationKeysMessage {
74208
eks0: Vec<EncapKeyBytes>,
@@ -133,16 +267,23 @@ impl RotSender for MlKemOt {
133267

134268
let mut cts0 = Vec::with_capacity(count);
135269
let mut cts1 = Vec::with_capacity(count);
136-
for (i, (ek0, ek1)) in receiver_msg
270+
for (i, (r0_bytes, r1_bytes)) in receiver_msg
137271
.eks0
138272
.iter()
139273
.zip(receiver_msg.eks1.iter())
140274
.enumerate()
141275
{
142-
let (ct0, key0) = encapsulate(ek0, &mut self.rng);
276+
// Reconstruct encapsulation keys: ek_j = r_j + H(r_{1-j})
277+
let (r0, rho) = parse_ek(&r0_bytes.0);
278+
let (r1, _) = parse_ek(&r1_bytes.0);
279+
280+
let ek0_bytes = serialize_ek(&(&r0 + &hash_to_key(&r1)), &rho);
281+
let ek1_bytes = serialize_ek(&(&r1 + &hash_to_key(&r0)), &rho);
282+
283+
let (ct0, key0) = encapsulate(&EncapKeyBytes(ek0_bytes), &mut self.rng);
143284
let key0 = hash(&key0, i);
144285

145-
let (ct1, key1) = encapsulate(ek1, &mut self.rng);
286+
let (ct1, key1) = encapsulate(&EncapKeyBytes(ek1_bytes), &mut self.rng);
146287
let key1 = hash(&key1, i);
147288

148289
cts0.push(ct0);
@@ -180,18 +321,30 @@ impl RotReceiver for MlKemOt {
180321
let mut eks1 = Vec::with_capacity(count);
181322

182323
for choice in choices.iter() {
183-
// Generate real keypair.
324+
// Step 1: Generate real keypair.
184325
let (dk, ek) = MlKem::generate(&mut RngCompat(&mut self.rng));
185-
let real_ek = EncapKeyBytes(
186-
ek.as_bytes()
187-
.as_slice()
188-
.try_into()
189-
.expect("incorrect encapsulation key size"),
190-
);
191-
let fake_ek = EncapKeyBytes(self.rng.random());
326+
let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
327+
.as_bytes()
328+
.as_slice()
329+
.try_into()
330+
.expect("incorrect encapsulation key size");
331+
let (real_t_hat, rho) = parse_ek(&ek_bytes);
332+
333+
// Step 2: Sample random key for position 1-b.
334+
let rand_t_hat = random_ek(&mut self.rng);
335+
336+
// Step 3: Compute correlated key for position b: r_b = ek - H(r_{1-b}).
337+
let correlated_t_hat = &real_t_hat - &hash_to_key(&rand_t_hat);
338+
339+
// Serialize both keys with the same rho.
340+
let correlated_bytes = EncapKeyBytes(serialize_ek(&correlated_t_hat, &rho));
341+
let random_bytes = EncapKeyBytes(serialize_ek(&rand_t_hat, &rho));
192342

193-
let ek0 = EncapKeyBytes::conditional_select(&real_ek, &fake_ek, *choice);
194-
let ek1 = EncapKeyBytes::conditional_select(&fake_ek, &real_ek, *choice);
343+
// Step 4: Select (r_0, r_1) based on choice bit (constant-time).
344+
// If b=0: r_0 = correlated (real side), r_1 = random.
345+
// If b=1: r_0 = random, r_1 = correlated (real side).
346+
let ek0 = EncapKeyBytes::conditional_select(&correlated_bytes, &random_bytes, *choice);
347+
let ek1 = EncapKeyBytes::conditional_select(&random_bytes, &correlated_bytes, *choice);
195348

196349
decap_keys.push(dk);
197350
eks0.push(ek0);
@@ -217,15 +370,18 @@ impl RotReceiver for MlKemOt {
217370
});
218371
}
219372

220-
// Decapsulate the chosen ciphertext for each OT.
373+
// Step 10-11: Decapsulate the chosen ciphertext and derive OT key.
221374
for (i, ((dk, choice), (ct0, ct1))) in decap_keys
222375
.iter()
223376
.zip(choices.iter())
224377
.zip(sender_msg.cts0.iter().zip(sender_msg.cts1.iter()))
225378
.enumerate()
226379
{
227-
let chosen_ct: MlKemCiphertext<MlKem> =
228-
CtBytes::conditional_select(ct0, ct1, *choice).0.into();
380+
let ct_bytes = CtBytes::conditional_select(ct0, ct1, *choice).0;
381+
let chosen_ct: MlKemCiphertext<MlKem> = ct_bytes
382+
.as_slice()
383+
.try_into()
384+
.expect("incorrect ciphertext size");
229385
let shared_key = dk
230386
.decapsulate(&chosen_ct)
231387
.map_err(|_| Error::Decapsulation)?;

0 commit comments

Comments
 (0)