11//! Post-quantum base OT using ML-KEM.
2+ //!
3+ //! Implements the MR19 protocol (Masny-Rindal, ePrint 2019/706, Figure 8)
4+ //! instantiated with ML-KEM as per Section D.3.
5+ //! See `docs/mlkem-ot-protocol.md` for the full protocol description.
26
3- use std:: io ;
7+ use std:: { io , mem :: size_of } ;
48
59use cryprot_core:: { Block , buf:: Buf , rand_compat:: RngCompat , random_oracle:: RandomOracle } ;
610use cryprot_net:: { Connection , ConnectionError } ;
711use futures:: { SinkExt , StreamExt } ;
12+ use hybrid_array:: typenum:: Unsigned ;
813// ML-KEM variant: change to MlKem512/MlKem512Params or MlKem768/MlKem768Params
914// for different security levels.
1015use ml_kem:: {
1116 Ciphertext as MlKemCiphertext , EncodedSizeUser , KemCore , MlKem1024 as MlKem ,
12- MlKem1024Params as MlKemParams , SharedKey ,
13- array:: typenum:: Unsigned ,
17+ MlKem1024Params as MlKemParams , ParameterSet , SharedKey ,
1418 kem:: { Decapsulate , DecapsulationKey , Encapsulate , EncapsulationKey as MlKemEncapsulationKey } ,
1519} ;
20+ use module_lattice:: { Encode , Field , NttPolynomial } ;
1621use rand:: { RngExt , rngs:: StdRng } ;
1722use serde:: { Deserialize , Serialize } ;
23+ use sha3:: {
24+ Shake128 ,
25+ digest:: { ExtendableOutput , Update , XofReader } ,
26+ } ;
1827use subtle:: { Choice , ConditionallySelectable } ;
1928use tracing:: Level ;
2029
2130use crate :: { Connected , RotReceiver , RotSender , SemiHonest , phase} ;
2231
32+ // Define the ML-KEM base field (q = 3329).
33+ module_lattice:: define_field!( MlKemField , u16 , u32 , u64 , 3329 ) ;
34+
35+ // Module dimension derived from the chosen ML-KEM parameter set.
36+ type K = <MlKemParams as ParameterSet >:: K ;
37+
38+ type NttVector = module_lattice:: NttVector < MlKemField , K > ;
39+
40+ type U12 = hybrid_array:: typenum:: U12 ;
41+
2342const ENCAPSULATION_KEY_LEN : usize =
2443 <MlKemEncapsulationKey < MlKemParams > as EncodedSizeUser >:: EncodedSize :: USIZE ;
2544const CIPHERTEXT_LEN : usize = <MlKem as KemCore >:: CiphertextSize :: USIZE ;
2645const HASH_DOMAIN_SEPARATOR : & [ u8 ] = b"MlKemOt" ;
2746
47+ // Number of coefficients per polynomial (FIPS 203, Section 2: n = 256).
48+ const NUM_COEFFICIENTS : usize = 256 ;
49+
50+ /// rho is a 32-byte seed used to derive the public matrix A_hat (FIPS 203).
51+ type Rho = [ u8 ; 32 ] ;
52+
53+ // Serialized t_hat is the encapsulation key minus the rho suffix.
54+ const T_HAT_BYTES_LEN : usize = ENCAPSULATION_KEY_LEN - size_of :: < Rho > ( ) ;
55+
56+ // ---------------------------------------------------------------------------
57+ // Protocol helper functions (see docs/mlkem-ot-protocol.md)
58+ // ---------------------------------------------------------------------------
59+
60+ /// Parse serialized encapsulation key bytes into (NttVector, rho).
61+ /// The input is a fixed-size array so the slicing is infallible.
62+ fn parse_ek ( bytes : & [ u8 ; ENCAPSULATION_KEY_LEN ] ) -> ( NttVector , Rho ) {
63+ let enc = bytes[ ..T_HAT_BYTES_LEN ]
64+ . try_into ( )
65+ . expect ( "t_hat length mismatch" ) ;
66+ let t_hat = <NttVector as Encode < U12 > >:: decode ( enc) ;
67+ let rho = bytes[ T_HAT_BYTES_LEN ..]
68+ . try_into ( )
69+ . expect ( "rho length mismatch" ) ;
70+ ( t_hat, rho)
71+ }
72+
73+ /// Serialize NttVector + rho back into encapsulation key bytes.
74+ fn serialize_ek ( t_hat : & NttVector , rho : & Rho ) -> [ u8 ; ENCAPSULATION_KEY_LEN ] {
75+ let encoded = <NttVector as Encode < U12 > >:: encode ( t_hat) ;
76+ let mut out = [ 0u8 ; ENCAPSULATION_KEY_LEN ] ;
77+ out[ ..T_HAT_BYTES_LEN ] . copy_from_slice ( encoded. as_slice ( ) ) ;
78+ out[ T_HAT_BYTES_LEN ..] . copy_from_slice ( rho) ;
79+ out
80+ }
81+
82+ /// XOF(rho, j, i) from FIPS 203, Algorithm 2 SHAKE128example.
83+ /// In Algorithm 13 (K-PKE.KeyGen), this is called as XOF(rho, j, i) where
84+ /// j is the column index (byte 32) and i is the row index (byte 33),
85+ /// using 0-based indexing.
86+ fn xof ( seed : & Rho , j : u8 , i : u8 ) -> impl XofReader {
87+ let mut h = Shake128 :: default ( ) ;
88+ h. update ( seed) ;
89+ h. update ( & [ i, j] ) ;
90+ h. finalize_xof ( )
91+ }
92+
93+ /// FIPS 203 Algorithm 7: SampleNTT.
94+ /// Rejection sampling from a byte stream to produce a pseudorandom NTT
95+ /// polynomial.
96+ ///
97+ /// Adapted from the ml-kem crate's `sample_ntt`.
98+ fn sample_ntt_poly ( xof : & mut impl XofReader ) -> NttPolynomial < MlKemField > {
99+ const Q : u16 = MlKemField :: Q ;
100+ // Read 32 triples (3 bytes each) at a time from the XOF.
101+ const BUF_LEN : usize = 96 ;
102+ let mut poly = NttPolynomial :: < MlKemField > :: default ( ) ;
103+ let mut buf = [ 0u8 ; BUF_LEN ] ;
104+ let mut pos = BUF_LEN ; // start at end to trigger first read
105+ let mut i = 0 ;
106+
107+ while i < NUM_COEFFICIENTS {
108+ if pos >= BUF_LEN {
109+ xof. read ( & mut buf) ;
110+ pos = 0 ;
111+ }
112+
113+ let d1 = u16:: from ( buf[ pos] ) | ( ( u16:: from ( buf[ pos + 1 ] ) & 0x0F ) << 8 ) ;
114+ let d2 = ( u16:: from ( buf[ pos + 1 ] ) >> 4 ) | ( u16:: from ( buf[ pos + 2 ] ) << 4 ) ;
115+ pos += 3 ;
116+
117+ if d1 < Q {
118+ poly. 0 [ i] = module_lattice:: Elem :: new ( d1) ;
119+ i += 1 ;
120+ }
121+ if i < NUM_COEFFICIENTS && d2 < Q {
122+ poly. 0 [ i] = module_lattice:: Elem :: new ( d2) ;
123+ i += 1 ;
124+ }
125+ }
126+
127+ poly
128+ }
129+
130+ /// SampleNTTVector: call SampleNTT k times with FIPS 203 domain separation.
131+ /// Produces a pseudorandom NttVector<k> from a 32-byte seed.
132+ /// Each polynomial j uses XOF(seed || j || 0).
133+ fn sample_ntt_vector ( seed : & Rho ) -> NttVector {
134+ NttVector :: new (
135+ ( 0 ..K :: USIZE )
136+ . map ( |j| {
137+ let mut reader = xof ( seed, j as u8 , 0 ) ;
138+ sample_ntt_poly ( & mut reader)
139+ } )
140+ . collect ( ) ,
141+ )
142+ }
143+
144+ /// H(ek): hash-to-key. Maps an NttVector to another NttVector via SHA3-256.
145+ /// Corresponds to libOTe's `pkHash`.
146+ fn hash_to_key ( t_hat : & NttVector ) -> NttVector {
147+ use sha3:: Digest ;
148+ let encoded = <NttVector as Encode < U12 > >:: encode ( t_hat) ;
149+ let seed: Rho = sha3:: Sha3_256 :: digest ( encoded. as_slice ( ) ) . into ( ) ;
150+ sample_ntt_vector ( & seed)
151+ }
152+
153+ /// RandomEK: generate a random NttVector from a random seed.
154+ fn random_ek ( rng : & mut StdRng ) -> NttVector {
155+ let seed: Rho = rng. random ( ) ;
156+ sample_ntt_vector ( & seed)
157+ }
158+
159+ // ---------------------------------------------------------------------------
160+ // Wire types and protocol implementation
161+ // ---------------------------------------------------------------------------
162+
28163#[ derive( thiserror:: Error , Debug ) ]
29164pub enum Error {
30165 #[ error( "quic connection error" ) ]
@@ -67,8 +202,7 @@ impl ConditionallySelectable for CtBytes {
67202 }
68203}
69204
70- // Message from receiver to sender: two encapsulation keys per OT.
71- // For choice bit c, ek{c} is a real key, ek{1-c} is random bytes.
205+ // Message from receiver to sender: two values (r_0, r_1) per OT.
72206#[ derive( Serialize , Deserialize ) ]
73207struct EncapsulationKeysMessage {
74208 eks0 : Vec < EncapKeyBytes > ,
@@ -133,16 +267,23 @@ impl RotSender for MlKemOt {
133267
134268 let mut cts0 = Vec :: with_capacity ( count) ;
135269 let mut cts1 = Vec :: with_capacity ( count) ;
136- for ( i, ( ek0 , ek1 ) ) in receiver_msg
270+ for ( i, ( r0_bytes , r1_bytes ) ) in receiver_msg
137271 . eks0
138272 . iter ( )
139273 . zip ( receiver_msg. eks1 . iter ( ) )
140274 . enumerate ( )
141275 {
142- let ( ct0, key0) = encapsulate ( ek0, & mut self . rng ) ;
276+ // Reconstruct encapsulation keys: ek_j = r_j + H(r_{1-j})
277+ let ( r0, rho) = parse_ek ( & r0_bytes. 0 ) ;
278+ let ( r1, _) = parse_ek ( & r1_bytes. 0 ) ;
279+
280+ let ek0_bytes = serialize_ek ( & ( & r0 + & hash_to_key ( & r1) ) , & rho) ;
281+ let ek1_bytes = serialize_ek ( & ( & r1 + & hash_to_key ( & r0) ) , & rho) ;
282+
283+ let ( ct0, key0) = encapsulate ( & EncapKeyBytes ( ek0_bytes) , & mut self . rng ) ;
143284 let key0 = hash ( & key0, i) ;
144285
145- let ( ct1, key1) = encapsulate ( ek1 , & mut self . rng ) ;
286+ let ( ct1, key1) = encapsulate ( & EncapKeyBytes ( ek1_bytes ) , & mut self . rng ) ;
146287 let key1 = hash ( & key1, i) ;
147288
148289 cts0. push ( ct0) ;
@@ -180,18 +321,30 @@ impl RotReceiver for MlKemOt {
180321 let mut eks1 = Vec :: with_capacity ( count) ;
181322
182323 for choice in choices. iter ( ) {
183- // Generate real keypair.
324+ // Step 1: Generate real keypair.
184325 let ( dk, ek) = MlKem :: generate ( & mut RngCompat ( & mut self . rng ) ) ;
185- let real_ek = EncapKeyBytes (
186- ek. as_bytes ( )
187- . as_slice ( )
188- . try_into ( )
189- . expect ( "incorrect encapsulation key size" ) ,
190- ) ;
191- let fake_ek = EncapKeyBytes ( self . rng . random ( ) ) ;
326+ let ek_bytes: [ u8 ; ENCAPSULATION_KEY_LEN ] = ek
327+ . as_bytes ( )
328+ . as_slice ( )
329+ . try_into ( )
330+ . expect ( "incorrect encapsulation key size" ) ;
331+ let ( real_t_hat, rho) = parse_ek ( & ek_bytes) ;
332+
333+ // Step 2: Sample random key for position 1-b.
334+ let rand_t_hat = random_ek ( & mut self . rng ) ;
335+
336+ // Step 3: Compute correlated key for position b: r_b = ek - H(r_{1-b}).
337+ let correlated_t_hat = & real_t_hat - & hash_to_key ( & rand_t_hat) ;
338+
339+ // Serialize both keys with the same rho.
340+ let correlated_bytes = EncapKeyBytes ( serialize_ek ( & correlated_t_hat, & rho) ) ;
341+ let random_bytes = EncapKeyBytes ( serialize_ek ( & rand_t_hat, & rho) ) ;
192342
193- let ek0 = EncapKeyBytes :: conditional_select ( & real_ek, & fake_ek, * choice) ;
194- let ek1 = EncapKeyBytes :: conditional_select ( & fake_ek, & real_ek, * choice) ;
343+ // Step 4: Select (r_0, r_1) based on choice bit (constant-time).
344+ // If b=0: r_0 = correlated (real side), r_1 = random.
345+ // If b=1: r_0 = random, r_1 = correlated (real side).
346+ let ek0 = EncapKeyBytes :: conditional_select ( & correlated_bytes, & random_bytes, * choice) ;
347+ let ek1 = EncapKeyBytes :: conditional_select ( & random_bytes, & correlated_bytes, * choice) ;
195348
196349 decap_keys. push ( dk) ;
197350 eks0. push ( ek0) ;
@@ -217,15 +370,18 @@ impl RotReceiver for MlKemOt {
217370 } ) ;
218371 }
219372
220- // Decapsulate the chosen ciphertext for each OT.
373+ // Step 10-11: Decapsulate the chosen ciphertext and derive OT key .
221374 for ( i, ( ( dk, choice) , ( ct0, ct1) ) ) in decap_keys
222375 . iter ( )
223376 . zip ( choices. iter ( ) )
224377 . zip ( sender_msg. cts0 . iter ( ) . zip ( sender_msg. cts1 . iter ( ) ) )
225378 . enumerate ( )
226379 {
227- let chosen_ct: MlKemCiphertext < MlKem > =
228- CtBytes :: conditional_select ( ct0, ct1, * choice) . 0 . into ( ) ;
380+ let ct_bytes = CtBytes :: conditional_select ( ct0, ct1, * choice) . 0 ;
381+ let chosen_ct: MlKemCiphertext < MlKem > = ct_bytes
382+ . as_slice ( )
383+ . try_into ( )
384+ . expect ( "incorrect ciphertext size" ) ;
229385 let shared_key = dk
230386 . decapsulate ( & chosen_ct)
231387 . map_err ( |_| Error :: Decapsulation ) ?;
0 commit comments