@@ -25,6 +25,16 @@ const ENCAPSULATION_KEY_LEN: usize =
2525const CIPHERTEXT_LEN : usize = <MlKem as KemCore >:: CiphertextSize :: USIZE ;
2626const HASH_DOMAIN_SEPARATOR : & [ u8 ] = b"MlKemOt" ;
2727
28+ // ML-KEM polynomial arithmetic constants for the MR19 protocol.
29+ // The encapsulation key is encoded as ek = ByteEncode₁₂(t̂) ‖ ρ,
30+ // where t̂ is a vector of k polynomials in NTT domain (k=4 for ML-KEM-1024)
31+ // and ρ is a 32-byte seed for the public matrix A.
32+ const Q : u16 = 3329 ;
33+ const RHO_BYTES : usize = 32 ;
34+ const T_HAT_BYTES : usize = ENCAPSULATION_KEY_LEN - RHO_BYTES ;
35+ const NUM_COEFFS : usize = T_HAT_BYTES * 2 / 3 ;
36+ const MR19_HASH_DOMAIN : & [ u8 ] = b"MlKemOtMR19" ;
37+
2838#[ derive( thiserror:: Error , Debug ) ]
2939pub enum Error {
3040 #[ error( "quic connection error" ) ]
@@ -67,8 +77,8 @@ impl ConditionallySelectable for CtBytes {
6777 }
6878}
6979
70- // Message from receiver to sender: two encapsulation keys per OT.
71- // For choice bit c, ek{c} is a real key, ek {1-c} is random bytes .
80+ // Message from receiver to sender: two values (r_0, r_1) per OT.
81+ // MR19 protocol: sender reconstructs pk_i = r_i + H(r_ {1-i}) .
7282#[ derive( Serialize , Deserialize ) ]
7383struct EncapsulationKeysMessage {
7484 eks0 : Vec < EncapKeyBytes > ,
@@ -133,16 +143,21 @@ impl RotSender for MlKemOt {
133143
134144 let mut cts0 = Vec :: with_capacity ( count) ;
135145 let mut cts1 = Vec :: with_capacity ( count) ;
136- for ( i, ( ek0 , ek1 ) ) in receiver_msg
146+ for ( i, ( r0 , r1 ) ) in receiver_msg
137147 . eks0
138148 . iter ( )
139149 . zip ( receiver_msg. eks1 . iter ( ) )
140150 . enumerate ( )
141151 {
142- let ( ct0, key0) = encapsulate ( ek0, & mut self . rng ) ;
152+ // MR19: reconstruct encapsulation keys from (r_0, r_1).
153+ // pk_0 = r_0 + H(r_1), pk_1 = r_1 + H(r_0)
154+ let pk0 = EncapKeyBytes ( reconstruct_ek ( & r0. 0 , & r1. 0 ) ) ;
155+ let pk1 = EncapKeyBytes ( reconstruct_ek ( & r1. 0 , & r0. 0 ) ) ;
156+
157+ let ( ct0, key0) = encapsulate ( & pk0, & mut self . rng ) ;
143158 let key0 = hash ( & key0, i) ;
144159
145- let ( ct1, key1) = encapsulate ( ek1 , & mut self . rng ) ;
160+ let ( ct1, key1) = encapsulate ( & pk1 , & mut self . rng ) ;
146161 let key1 = hash ( & key1, i) ;
147162
148163 cts0. push ( ct0) ;
@@ -182,16 +197,34 @@ impl RotReceiver for MlKemOt {
182197 for choice in choices. iter ( ) {
183198 // Generate real keypair.
184199 let ( dk, ek) = MlKem :: generate ( & mut RngCompat ( & mut self . rng ) ) ;
185- let real_ek = EncapKeyBytes (
186- ek. as_bytes ( )
187- . as_slice ( )
188- . try_into ( )
189- . expect ( "incorrect encapsulation key size" ) ,
190- ) ;
191- let fake_ek = EncapKeyBytes ( self . rng . random ( ) ) ;
192-
193- let ek0 = EncapKeyBytes :: conditional_select ( & real_ek, & fake_ek, * choice) ;
194- let ek1 = EncapKeyBytes :: conditional_select ( & fake_ek, & real_ek, * choice) ;
200+ let ek_bytes: [ u8 ; ENCAPSULATION_KEY_LEN ] = ek
201+ . as_bytes ( )
202+ . as_slice ( )
203+ . try_into ( )
204+ . expect ( "incorrect encapsulation key size" ) ;
205+
206+ // MR19 protocol: construct (r_b, r_{1-b}) such that
207+ // r_b + H(r_{1-b}) = ek (on the t̂ polynomial vector, with shared ρ).
208+ let rho = & ek_bytes[ T_HAT_BYTES ..] ;
209+
210+ // Generate a random fake ek sharing the same ρ (same matrix A).
211+ let fake_t_hat = random_coeffs ( & mut self . rng ) ;
212+ let fake_ek = assemble_ek ( & fake_t_hat, rho) ;
213+
214+ // r_b = ek - H(fake_ek) on t̂ coefficients.
215+ let t_hat_real = decode_t_hat ( & ek_bytes[ ..T_HAT_BYTES ] ) ;
216+ let h_fake = hash_ek_to_coeffs ( & fake_ek) ;
217+ let r_b_t_hat = sub_mod_q ( & t_hat_real, & h_fake) ;
218+ let r_b_bytes = assemble_ek ( & r_b_t_hat, rho) ;
219+
220+ let r_b = EncapKeyBytes ( r_b_bytes) ;
221+ let r_1_minus_b = EncapKeyBytes ( fake_ek) ;
222+
223+ // Constant-time selection based on choice bit.
224+ // choice=0: ek0=r_b (r_0), ek1=r_{1-b} (r_1) → pk_0 = r_0+H(r_1) = ek
225+ // choice=1: ek0=r_{1-b} (r_0), ek1=r_b (r_1) → pk_1 = r_1+H(r_0) = ek
226+ let ek0 = EncapKeyBytes :: conditional_select ( & r_b, & r_1_minus_b, * choice) ;
227+ let ek1 = EncapKeyBytes :: conditional_select ( & r_1_minus_b, & r_b, * choice) ;
195228
196229 decap_keys. push ( dk) ;
197230 eks0. push ( ek0) ;
@@ -237,6 +270,118 @@ impl RotReceiver for MlKemOt {
237270 }
238271}
239272
273+ // === MR19 polynomial arithmetic helpers ===
274+ // These operate on the t̂ portion of ML-KEM encapsulation keys, which is
275+ // encoded using ByteEncode₁₂ (FIPS 203): each 3 bytes encode 2 coefficients
276+ // of 12 bits each, with all coefficients in [0, q).
277+
278+ /// Decode ByteEncode₁₂: 3 bytes → 2 coefficients (12 bits each), reduced mod q.
279+ fn decode_t_hat ( bytes : & [ u8 ] ) -> [ u16 ; NUM_COEFFS ] {
280+ debug_assert_eq ! ( bytes. len( ) , T_HAT_BYTES ) ;
281+ let mut coeffs = [ 0u16 ; NUM_COEFFS ] ;
282+ for ( i, chunk) in bytes. chunks_exact ( 3 ) . enumerate ( ) {
283+ let d0 = chunk[ 0 ] as u16 ;
284+ let d1 = chunk[ 1 ] as u16 ;
285+ let d2 = chunk[ 2 ] as u16 ;
286+ coeffs[ 2 * i] = ( d0 | ( ( d1 & 0x0F ) << 8 ) ) % Q ;
287+ coeffs[ 2 * i + 1 ] = ( ( d1 >> 4 ) | ( d2 << 4 ) ) % Q ;
288+ }
289+ coeffs
290+ }
291+
292+ /// Encode coefficients as ByteEncode₁₂: 2 coefficients → 3 bytes.
293+ fn encode_t_hat ( coeffs : & [ u16 ; NUM_COEFFS ] ) -> [ u8 ; T_HAT_BYTES ] {
294+ let mut bytes = [ 0u8 ; T_HAT_BYTES ] ;
295+ for ( i, pair) in coeffs. chunks_exact ( 2 ) . enumerate ( ) {
296+ let a = pair[ 0 ] ;
297+ let b = pair[ 1 ] ;
298+ bytes[ 3 * i] = ( a & 0xFF ) as u8 ;
299+ bytes[ 3 * i + 1 ] = ( ( a >> 8 ) | ( ( b & 0x0F ) << 4 ) ) as u8 ;
300+ bytes[ 3 * i + 2 ] = ( b >> 4 ) as u8 ;
301+ }
302+ bytes
303+ }
304+
305+ /// Add two coefficient vectors element-wise mod q.
306+ fn add_mod_q (
307+ a : & [ u16 ; NUM_COEFFS ] ,
308+ b : & [ u16 ; NUM_COEFFS ] ,
309+ ) -> [ u16 ; NUM_COEFFS ] {
310+ let mut result = [ 0u16 ; NUM_COEFFS ] ;
311+ for i in 0 ..NUM_COEFFS {
312+ result[ i] = ( a[ i] + b[ i] ) % Q ;
313+ }
314+ result
315+ }
316+
317+ /// Subtract two coefficient vectors element-wise mod q.
318+ fn sub_mod_q (
319+ a : & [ u16 ; NUM_COEFFS ] ,
320+ b : & [ u16 ; NUM_COEFFS ] ,
321+ ) -> [ u16 ; NUM_COEFFS ] {
322+ let mut result = [ 0u16 ; NUM_COEFFS ] ;
323+ for i in 0 ..NUM_COEFFS {
324+ result[ i] = ( a[ i] + Q - b[ i] ) % Q ;
325+ }
326+ result
327+ }
328+
329+ /// Hash an encoded encapsulation key to a t̂ coefficient vector mod q.
330+ /// Uses BLAKE3 XOF with rejection sampling (12 bits, accept if < q).
331+ fn hash_ek_to_coeffs ( ek : & [ u8 ; ENCAPSULATION_KEY_LEN ] ) -> [ u16 ; NUM_COEFFS ] {
332+ let mut ro = RandomOracle :: new ( ) ;
333+ ro. update ( MR19_HASH_DOMAIN ) ;
334+ ro. update ( ek) ;
335+ let mut xof = ro. finalize_xof ( ) ;
336+ let mut coeffs = [ 0u16 ; NUM_COEFFS ] ;
337+ for c in coeffs. iter_mut ( ) {
338+ loop {
339+ let mut buf = [ 0u8 ; 2 ] ;
340+ xof. fill ( & mut buf) ;
341+ let val = u16:: from_le_bytes ( buf) & 0x0FFF ;
342+ if val < Q {
343+ * c = val;
344+ break ;
345+ }
346+ }
347+ }
348+ coeffs
349+ }
350+
351+ /// Generate random coefficients mod q using rejection sampling.
352+ fn random_coeffs ( rng : & mut impl Rng ) -> [ u16 ; NUM_COEFFS ] {
353+ let mut coeffs = [ 0u16 ; NUM_COEFFS ] ;
354+ for c in coeffs. iter_mut ( ) {
355+ loop {
356+ let val: u16 = rng. random :: < u16 > ( ) & 0x0FFF ;
357+ if val < Q {
358+ * c = val;
359+ break ;
360+ }
361+ }
362+ }
363+ coeffs
364+ }
365+
366+ /// Assemble a full encapsulation key encoding from t̂ coefficients and ρ.
367+ fn assemble_ek ( t_hat : & [ u16 ; NUM_COEFFS ] , rho : & [ u8 ] ) -> [ u8 ; ENCAPSULATION_KEY_LEN ] {
368+ let mut ek = [ 0u8 ; ENCAPSULATION_KEY_LEN ] ;
369+ ek[ ..T_HAT_BYTES ] . copy_from_slice ( & encode_t_hat ( t_hat) ) ;
370+ ek[ T_HAT_BYTES ..] . copy_from_slice ( rho) ;
371+ ek
372+ }
373+
374+ /// MR19 key reconstruction: pk = r + H(other), using ρ from r.
375+ fn reconstruct_ek (
376+ r : & [ u8 ; ENCAPSULATION_KEY_LEN ] ,
377+ other : & [ u8 ; ENCAPSULATION_KEY_LEN ] ,
378+ ) -> [ u8 ; ENCAPSULATION_KEY_LEN ] {
379+ let r_coeffs = decode_t_hat ( & r[ ..T_HAT_BYTES ] ) ;
380+ let h_other = hash_ek_to_coeffs ( other) ;
381+ let pk_coeffs = add_mod_q ( & r_coeffs, & h_other) ;
382+ assemble_ek ( & pk_coeffs, & r[ T_HAT_BYTES ..] )
383+ }
384+
240385// Encapsulates to the given key, returning the ciphertext and the shared key.
241386fn encapsulate ( ek : & EncapKeyBytes , rng : & mut StdRng ) -> ( CtBytes , SharedKey < MlKem > ) {
242387 let parsed_ek = MlKemEncapsulationKey :: < MlKemParams > :: from_bytes ( ( & ek. 0 ) . into ( ) ) ;
@@ -271,6 +416,56 @@ mod tests {
271416 use super :: MlKemOt ;
272417 use crate :: { RotReceiver , RotSender , random_choices} ;
273418
419+ #[ test]
420+ fn encode_decode_roundtrip ( ) {
421+ use super :: * ;
422+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
423+ let ( _, ek) = MlKem :: generate ( & mut RngCompat ( & mut rng) ) ;
424+ let ek_bytes: [ u8 ; ENCAPSULATION_KEY_LEN ] =
425+ ek. as_bytes ( ) . as_slice ( ) . try_into ( ) . unwrap ( ) ;
426+
427+ // Check all coefficients are < Q
428+ let coeffs = decode_t_hat ( & ek_bytes[ ..T_HAT_BYTES ] ) ;
429+ for ( i, & c) in coeffs. iter ( ) . enumerate ( ) {
430+ assert ! ( c < Q , "coefficient {i} = {c} >= Q={Q}" ) ;
431+ }
432+
433+ // Check encode roundtrip
434+ let re_encoded = encode_t_hat ( & coeffs) ;
435+ assert_eq ! (
436+ & ek_bytes[ ..T_HAT_BYTES ] ,
437+ & re_encoded[ ..] ,
438+ "encode/decode roundtrip failed"
439+ ) ;
440+ }
441+
442+ #[ test]
443+ fn mr19_reconstruction ( ) {
444+ use super :: * ;
445+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
446+ let ( _, ek) = MlKem :: generate ( & mut RngCompat ( & mut rng) ) ;
447+ let ek_bytes: [ u8 ; ENCAPSULATION_KEY_LEN ] =
448+ ek. as_bytes ( ) . as_slice ( ) . try_into ( ) . unwrap ( ) ;
449+ let rho = & ek_bytes[ T_HAT_BYTES ..] ;
450+
451+ // Generate fake key
452+ let fake_t_hat = random_coeffs ( & mut rng) ;
453+ let fake_ek = assemble_ek ( & fake_t_hat, rho) ;
454+
455+ // Compute r_b = ek - H(fake_ek)
456+ let t_hat_real = decode_t_hat ( & ek_bytes[ ..T_HAT_BYTES ] ) ;
457+ let h_fake = hash_ek_to_coeffs ( & fake_ek) ;
458+ let r_b_t_hat = sub_mod_q ( & t_hat_real, & h_fake) ;
459+ let r_b_bytes = assemble_ek ( & r_b_t_hat, rho) ;
460+
461+ // Reconstruct: pk = r_b + H(fake_ek)
462+ let reconstructed = reconstruct_ek ( & r_b_bytes, & fake_ek) ;
463+ assert_eq ! (
464+ ek_bytes, reconstructed,
465+ "MR19 reconstruction failed: pk != ek"
466+ ) ;
467+ }
468+
274469 #[ tokio:: test]
275470 async fn mlkem_base_rot_random_choices ( ) -> Result < ( ) > {
276471 let _g = init_tracing ( ) ;
0 commit comments