Skip to content

Commit 624ac1a

Browse files
committed
Implement MR19 protocol for ML-KEM OT encapsulation key security
Replace the broken approach (real ek + random bytes) with proper MR19 protocol arithmetic on the t̂ polynomial vectors of ML-KEM encapsulation keys. The receiver now computes r_b = ek - H(r_{1-b}) and the sender reconstructs pk_i = r_i + H(r_{1-i}), ensuring the choice bit is hidden by the algebraic relationship rather than random padding. Key changes: - Add ByteEncode₁₂/ByteDecode₁₂ helpers for coefficient-level arithmetic - Add mod-q add/subtract on t̂ vectors and hash-to-coefficients via BLAKE3 XOF - Receiver: generate fake ek with shared ρ, compute r_b from MR19 relation - Sender: reconstruct encapsulation keys before encapsulating - Add unit tests for encode/decode roundtrip and MR19 reconstruction WIP: integration tests still failing, needs debugging. https://claude.ai/code/session_01ENct61kKUXnsJZMpMxiYdT
1 parent 11dc9aa commit 624ac1a

File tree

1 file changed

+210
-15
lines changed

1 file changed

+210
-15
lines changed

cryprot-ot/src/mlkem_ot.rs

Lines changed: 210 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ const ENCAPSULATION_KEY_LEN: usize =
2525
const CIPHERTEXT_LEN: usize = <MlKem as KemCore>::CiphertextSize::USIZE;
2626
const HASH_DOMAIN_SEPARATOR: &[u8] = b"MlKemOt";
2727

28+
// ML-KEM polynomial arithmetic constants for the MR19 protocol.
29+
// The encapsulation key is encoded as ek = ByteEncode₁₂(t̂) ‖ ρ,
30+
// where t̂ is a vector of k polynomials in NTT domain (k=4 for ML-KEM-1024)
31+
// and ρ is a 32-byte seed for the public matrix A.
32+
const Q: u16 = 3329;
33+
const RHO_BYTES: usize = 32;
34+
const T_HAT_BYTES: usize = ENCAPSULATION_KEY_LEN - RHO_BYTES;
35+
const NUM_COEFFS: usize = T_HAT_BYTES * 2 / 3;
36+
const MR19_HASH_DOMAIN: &[u8] = b"MlKemOtMR19";
37+
2838
#[derive(thiserror::Error, Debug)]
2939
pub enum Error {
3040
#[error("quic connection error")]
@@ -67,8 +77,8 @@ impl ConditionallySelectable for CtBytes {
6777
}
6878
}
6979

70-
// Message from receiver to sender: two encapsulation keys per OT.
71-
// For choice bit c, ek{c} is a real key, ek{1-c} is random bytes.
80+
// Message from receiver to sender: two values (r_0, r_1) per OT.
81+
// MR19 protocol: sender reconstructs pk_i = r_i + H(r_{1-i}).
7282
#[derive(Serialize, Deserialize)]
7383
struct EncapsulationKeysMessage {
7484
eks0: Vec<EncapKeyBytes>,
@@ -133,16 +143,21 @@ impl RotSender for MlKemOt {
133143

134144
let mut cts0 = Vec::with_capacity(count);
135145
let mut cts1 = Vec::with_capacity(count);
136-
for (i, (ek0, ek1)) in receiver_msg
146+
for (i, (r0, r1)) in receiver_msg
137147
.eks0
138148
.iter()
139149
.zip(receiver_msg.eks1.iter())
140150
.enumerate()
141151
{
142-
let (ct0, key0) = encapsulate(ek0, &mut self.rng);
152+
// MR19: reconstruct encapsulation keys from (r_0, r_1).
153+
// pk_0 = r_0 + H(r_1), pk_1 = r_1 + H(r_0)
154+
let pk0 = EncapKeyBytes(reconstruct_ek(&r0.0, &r1.0));
155+
let pk1 = EncapKeyBytes(reconstruct_ek(&r1.0, &r0.0));
156+
157+
let (ct0, key0) = encapsulate(&pk0, &mut self.rng);
143158
let key0 = hash(&key0, i);
144159

145-
let (ct1, key1) = encapsulate(ek1, &mut self.rng);
160+
let (ct1, key1) = encapsulate(&pk1, &mut self.rng);
146161
let key1 = hash(&key1, i);
147162

148163
cts0.push(ct0);
@@ -182,16 +197,34 @@ impl RotReceiver for MlKemOt {
182197
for choice in choices.iter() {
183198
// Generate real keypair.
184199
let (dk, ek) = MlKem::generate(&mut RngCompat(&mut self.rng));
185-
let real_ek = EncapKeyBytes(
186-
ek.as_bytes()
187-
.as_slice()
188-
.try_into()
189-
.expect("incorrect encapsulation key size"),
190-
);
191-
let fake_ek = EncapKeyBytes(self.rng.random());
192-
193-
let ek0 = EncapKeyBytes::conditional_select(&real_ek, &fake_ek, *choice);
194-
let ek1 = EncapKeyBytes::conditional_select(&fake_ek, &real_ek, *choice);
200+
let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] = ek
201+
.as_bytes()
202+
.as_slice()
203+
.try_into()
204+
.expect("incorrect encapsulation key size");
205+
206+
// MR19 protocol: construct (r_b, r_{1-b}) such that
207+
// r_b + H(r_{1-b}) = ek (on the t̂ polynomial vector, with shared ρ).
208+
let rho = &ek_bytes[T_HAT_BYTES..];
209+
210+
// Generate a random fake ek sharing the same ρ (same matrix A).
211+
let fake_t_hat = random_coeffs(&mut self.rng);
212+
let fake_ek = assemble_ek(&fake_t_hat, rho);
213+
214+
// r_b = ek - H(fake_ek) on t̂ coefficients.
215+
let t_hat_real = decode_t_hat(&ek_bytes[..T_HAT_BYTES]);
216+
let h_fake = hash_ek_to_coeffs(&fake_ek);
217+
let r_b_t_hat = sub_mod_q(&t_hat_real, &h_fake);
218+
let r_b_bytes = assemble_ek(&r_b_t_hat, rho);
219+
220+
let r_b = EncapKeyBytes(r_b_bytes);
221+
let r_1_minus_b = EncapKeyBytes(fake_ek);
222+
223+
// Constant-time selection based on choice bit.
224+
// choice=0: ek0=r_b (r_0), ek1=r_{1-b} (r_1) → pk_0 = r_0+H(r_1) = ek
225+
// choice=1: ek0=r_{1-b} (r_0), ek1=r_b (r_1) → pk_1 = r_1+H(r_0) = ek
226+
let ek0 = EncapKeyBytes::conditional_select(&r_b, &r_1_minus_b, *choice);
227+
let ek1 = EncapKeyBytes::conditional_select(&r_1_minus_b, &r_b, *choice);
195228

196229
decap_keys.push(dk);
197230
eks0.push(ek0);
@@ -237,6 +270,118 @@ impl RotReceiver for MlKemOt {
237270
}
238271
}
239272

273+
// === MR19 polynomial arithmetic helpers ===
274+
// These operate on the t̂ portion of ML-KEM encapsulation keys, which is
275+
// encoded using ByteEncode₁₂ (FIPS 203): each 3 bytes encode 2 coefficients
276+
// of 12 bits each, with all coefficients in [0, q).
277+
278+
/// Decode ByteEncode₁₂: 3 bytes → 2 coefficients (12 bits each), reduced mod q.
279+
fn decode_t_hat(bytes: &[u8]) -> [u16; NUM_COEFFS] {
280+
debug_assert_eq!(bytes.len(), T_HAT_BYTES);
281+
let mut coeffs = [0u16; NUM_COEFFS];
282+
for (i, chunk) in bytes.chunks_exact(3).enumerate() {
283+
let d0 = chunk[0] as u16;
284+
let d1 = chunk[1] as u16;
285+
let d2 = chunk[2] as u16;
286+
coeffs[2 * i] = (d0 | ((d1 & 0x0F) << 8)) % Q;
287+
coeffs[2 * i + 1] = ((d1 >> 4) | (d2 << 4)) % Q;
288+
}
289+
coeffs
290+
}
291+
292+
/// Encode coefficients as ByteEncode₁₂: 2 coefficients → 3 bytes.
293+
fn encode_t_hat(coeffs: &[u16; NUM_COEFFS]) -> [u8; T_HAT_BYTES] {
294+
let mut bytes = [0u8; T_HAT_BYTES];
295+
for (i, pair) in coeffs.chunks_exact(2).enumerate() {
296+
let a = pair[0];
297+
let b = pair[1];
298+
bytes[3 * i] = (a & 0xFF) as u8;
299+
bytes[3 * i + 1] = ((a >> 8) | ((b & 0x0F) << 4)) as u8;
300+
bytes[3 * i + 2] = (b >> 4) as u8;
301+
}
302+
bytes
303+
}
304+
305+
/// Add two coefficient vectors element-wise mod q.
306+
fn add_mod_q(
307+
a: &[u16; NUM_COEFFS],
308+
b: &[u16; NUM_COEFFS],
309+
) -> [u16; NUM_COEFFS] {
310+
let mut result = [0u16; NUM_COEFFS];
311+
for i in 0..NUM_COEFFS {
312+
result[i] = (a[i] + b[i]) % Q;
313+
}
314+
result
315+
}
316+
317+
/// Subtract two coefficient vectors element-wise mod q.
318+
fn sub_mod_q(
319+
a: &[u16; NUM_COEFFS],
320+
b: &[u16; NUM_COEFFS],
321+
) -> [u16; NUM_COEFFS] {
322+
let mut result = [0u16; NUM_COEFFS];
323+
for i in 0..NUM_COEFFS {
324+
result[i] = (a[i] + Q - b[i]) % Q;
325+
}
326+
result
327+
}
328+
329+
/// Hash an encoded encapsulation key to a t̂ coefficient vector mod q.
330+
/// Uses BLAKE3 XOF with rejection sampling (12 bits, accept if < q).
331+
fn hash_ek_to_coeffs(ek: &[u8; ENCAPSULATION_KEY_LEN]) -> [u16; NUM_COEFFS] {
332+
let mut ro = RandomOracle::new();
333+
ro.update(MR19_HASH_DOMAIN);
334+
ro.update(ek);
335+
let mut xof = ro.finalize_xof();
336+
let mut coeffs = [0u16; NUM_COEFFS];
337+
for c in coeffs.iter_mut() {
338+
loop {
339+
let mut buf = [0u8; 2];
340+
xof.fill(&mut buf);
341+
let val = u16::from_le_bytes(buf) & 0x0FFF;
342+
if val < Q {
343+
*c = val;
344+
break;
345+
}
346+
}
347+
}
348+
coeffs
349+
}
350+
351+
/// Generate random coefficients mod q using rejection sampling.
352+
fn random_coeffs(rng: &mut impl Rng) -> [u16; NUM_COEFFS] {
353+
let mut coeffs = [0u16; NUM_COEFFS];
354+
for c in coeffs.iter_mut() {
355+
loop {
356+
let val: u16 = rng.random::<u16>() & 0x0FFF;
357+
if val < Q {
358+
*c = val;
359+
break;
360+
}
361+
}
362+
}
363+
coeffs
364+
}
365+
366+
/// Assemble a full encapsulation key encoding from t̂ coefficients and ρ.
367+
fn assemble_ek(t_hat: &[u16; NUM_COEFFS], rho: &[u8]) -> [u8; ENCAPSULATION_KEY_LEN] {
368+
let mut ek = [0u8; ENCAPSULATION_KEY_LEN];
369+
ek[..T_HAT_BYTES].copy_from_slice(&encode_t_hat(t_hat));
370+
ek[T_HAT_BYTES..].copy_from_slice(rho);
371+
ek
372+
}
373+
374+
/// MR19 key reconstruction: pk = r + H(other), using ρ from r.
375+
fn reconstruct_ek(
376+
r: &[u8; ENCAPSULATION_KEY_LEN],
377+
other: &[u8; ENCAPSULATION_KEY_LEN],
378+
) -> [u8; ENCAPSULATION_KEY_LEN] {
379+
let r_coeffs = decode_t_hat(&r[..T_HAT_BYTES]);
380+
let h_other = hash_ek_to_coeffs(other);
381+
let pk_coeffs = add_mod_q(&r_coeffs, &h_other);
382+
assemble_ek(&pk_coeffs, &r[T_HAT_BYTES..])
383+
}
384+
240385
// Encapsulates to the given key, returning the ciphertext and the shared key.
241386
fn encapsulate(ek: &EncapKeyBytes, rng: &mut StdRng) -> (CtBytes, SharedKey<MlKem>) {
242387
let parsed_ek = MlKemEncapsulationKey::<MlKemParams>::from_bytes((&ek.0).into());
@@ -271,6 +416,56 @@ mod tests {
271416
use super::MlKemOt;
272417
use crate::{RotReceiver, RotSender, random_choices};
273418

419+
#[test]
420+
fn encode_decode_roundtrip() {
421+
use super::*;
422+
let mut rng = StdRng::seed_from_u64(42);
423+
let (_, ek) = MlKem::generate(&mut RngCompat(&mut rng));
424+
let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] =
425+
ek.as_bytes().as_slice().try_into().unwrap();
426+
427+
// Check all coefficients are < Q
428+
let coeffs = decode_t_hat(&ek_bytes[..T_HAT_BYTES]);
429+
for (i, &c) in coeffs.iter().enumerate() {
430+
assert!(c < Q, "coefficient {i} = {c} >= Q={Q}");
431+
}
432+
433+
// Check encode roundtrip
434+
let re_encoded = encode_t_hat(&coeffs);
435+
assert_eq!(
436+
&ek_bytes[..T_HAT_BYTES],
437+
&re_encoded[..],
438+
"encode/decode roundtrip failed"
439+
);
440+
}
441+
442+
#[test]
443+
fn mr19_reconstruction() {
444+
use super::*;
445+
let mut rng = StdRng::seed_from_u64(42);
446+
let (_, ek) = MlKem::generate(&mut RngCompat(&mut rng));
447+
let ek_bytes: [u8; ENCAPSULATION_KEY_LEN] =
448+
ek.as_bytes().as_slice().try_into().unwrap();
449+
let rho = &ek_bytes[T_HAT_BYTES..];
450+
451+
// Generate fake key
452+
let fake_t_hat = random_coeffs(&mut rng);
453+
let fake_ek = assemble_ek(&fake_t_hat, rho);
454+
455+
// Compute r_b = ek - H(fake_ek)
456+
let t_hat_real = decode_t_hat(&ek_bytes[..T_HAT_BYTES]);
457+
let h_fake = hash_ek_to_coeffs(&fake_ek);
458+
let r_b_t_hat = sub_mod_q(&t_hat_real, &h_fake);
459+
let r_b_bytes = assemble_ek(&r_b_t_hat, rho);
460+
461+
// Reconstruct: pk = r_b + H(fake_ek)
462+
let reconstructed = reconstruct_ek(&r_b_bytes, &fake_ek);
463+
assert_eq!(
464+
ek_bytes, reconstructed,
465+
"MR19 reconstruction failed: pk != ek"
466+
);
467+
}
468+
274469
#[tokio::test]
275470
async fn mlkem_base_rot_random_choices() -> Result<()> {
276471
let _g = init_tracing();

0 commit comments

Comments
 (0)