@@ -10,18 +10,22 @@ use cryprot_core::{Block, buf::Buf, rand_compat::RngCompat, random_oracle::Rando
1010use cryprot_net:: { Connection , ConnectionError } ;
1111use futures:: { SinkExt , StreamExt } ;
1212use hybrid_array:: typenum:: Unsigned ;
13- // ML-KEM variant: change to MlKem512/MlKem512Params or MlKem768/MlKem768Params
14- // for different security levels.
1513use ml_kem:: {
16- Ciphertext as MlKemCiphertext , EncodedSizeUser , KemCore , MlKem1024 as MlKem ,
17- MlKem1024Params as MlKemParams , ParameterSet , SharedKey ,
14+ Ciphertext as MlKemCiphertext , EncodedSizeUser , KemCore , ParameterSet , SharedKey ,
1815 kem:: { Decapsulate , DecapsulationKey , Encapsulate , EncapsulationKey as MlKemEncapsulationKey } ,
1916} ;
17+ // ML-KEM parameter set selection.
18+ #[ cfg( feature = "ml-kem-base-ot-512" ) ]
19+ use ml_kem:: { MlKem512 as MlKem , MlKem512Params as MlKemParams } ;
20+ #[ cfg( feature = "ml-kem-base-ot-768" ) ]
21+ use ml_kem:: { MlKem768 as MlKem , MlKem768Params as MlKemParams } ;
22+ #[ cfg( feature = "ml-kem-base-ot-1024" ) ]
23+ use ml_kem:: { MlKem1024 as MlKem , MlKem1024Params as MlKemParams } ;
2024use module_lattice:: { Encode , Field , NttPolynomial } ;
2125use rand:: { RngExt , rngs:: StdRng } ;
2226use serde:: { Deserialize , Serialize } ;
2327use sha3:: {
24- Shake128 ,
28+ Digest , Shake128 ,
2529 digest:: { ExtendableOutput , Update , XofReader } ,
2630} ;
2731use subtle:: { Choice , ConditionallySelectable } ;
@@ -49,43 +53,69 @@ const NUM_COEFFICIENTS: usize = 256;
4953
5054type Seed = [ u8 ; 32 ] ;
5155
52- /// rho is the seed used to derive the public matrix A_hat (FIPS 203).
53- type Rho = Seed ;
54-
5556// Serialized t_hat is the encapsulation key minus the rho suffix.
56- const T_HAT_BYTES_LEN : usize = ENCAPSULATION_KEY_LEN - size_of :: < Rho > ( ) ;
57+ const T_HAT_BYTES_LEN : usize = ENCAPSULATION_KEY_LEN - size_of :: < Seed > ( ) ;
5758
5859// ---------------------------------------------------------------------------
5960// Protocol helper functions (see docs/mlkem-ot-protocol.md)
6061// ---------------------------------------------------------------------------
6162
62- /// Parse serialized encapsulation key bytes into (NttVector, rho).
63- /// The input is a fixed-size array so the slicing is infallible.
64- fn parse_ek ( bytes : & [ u8 ; ENCAPSULATION_KEY_LEN ] ) -> ( NttVector , Rho ) {
65- let enc = bytes[ ..T_HAT_BYTES_LEN ]
66- . try_into ( )
67- . expect ( "t_hat length mismatch" ) ;
68- let t_hat = <NttVector as Encode < U12 > >:: decode ( enc) ;
69- let rho = bytes[ T_HAT_BYTES_LEN ..]
70- . try_into ( )
71- . expect ( "rho length mismatch" ) ;
72- ( t_hat, rho)
63+ /// Parsed encapsulation key: ek = (t_hat, rho).
64+ struct EncapsulationKey {
65+ t_hat : NttVector ,
66+ rho : Seed ,
67+ }
68+
69+ impl EncapsulationKey {
70+ /// Parse from serialized bytes.
71+ fn from_bytes ( bytes : & [ u8 ; ENCAPSULATION_KEY_LEN ] ) -> Self {
72+ let enc = bytes[ ..T_HAT_BYTES_LEN ]
73+ . try_into ( )
74+ . expect ( "t_hat length mismatch" ) ;
75+ let t_hat = <NttVector as Encode < U12 > >:: decode ( enc) ;
76+ let rho = bytes[ T_HAT_BYTES_LEN ..]
77+ . try_into ( )
78+ . expect ( "rho length mismatch" ) ;
79+ Self { t_hat, rho }
80+ }
81+
82+ /// Serialize to bytes.
83+ fn to_bytes ( & self ) -> [ u8 ; ENCAPSULATION_KEY_LEN ] {
84+ let encoded = <NttVector as Encode < U12 > >:: encode ( & self . t_hat ) ;
85+ let mut out = [ 0u8 ; ENCAPSULATION_KEY_LEN ] ;
86+ out[ ..T_HAT_BYTES_LEN ] . copy_from_slice ( encoded. as_slice ( ) ) ;
87+ out[ T_HAT_BYTES_LEN ..] . copy_from_slice ( & self . rho ) ;
88+ out
89+ }
90+ }
91+
92+ impl std:: ops:: Sub < & NttVector > for & EncapsulationKey {
93+ type Output = EncapsulationKey ;
94+
95+ fn sub ( self , rhs : & NttVector ) -> EncapsulationKey {
96+ EncapsulationKey {
97+ t_hat : & self . t_hat - rhs,
98+ rho : self . rho ,
99+ }
100+ }
73101}
74102
75- /// Serialize NttVector + rho back into encapsulation key bytes.
76- fn serialize_ek ( t_hat : & NttVector , rho : & Rho ) -> [ u8 ; ENCAPSULATION_KEY_LEN ] {
77- let encoded = <NttVector as Encode < U12 > >:: encode ( t_hat) ;
78- let mut out = [ 0u8 ; ENCAPSULATION_KEY_LEN ] ;
79- out[ ..T_HAT_BYTES_LEN ] . copy_from_slice ( encoded. as_slice ( ) ) ;
80- out[ T_HAT_BYTES_LEN ..] . copy_from_slice ( rho) ;
81- out
103+ impl std:: ops:: Add < & NttVector > for & EncapsulationKey {
104+ type Output = EncapsulationKey ;
105+
106+ fn add ( self , rhs : & NttVector ) -> EncapsulationKey {
107+ EncapsulationKey {
108+ t_hat : & self . t_hat + rhs,
109+ rho : self . rho ,
110+ }
111+ }
82112}
83113
84114/// XOF(rho, j, i) from FIPS 203, Algorithm 2 SHAKE128example.
85115/// In Algorithm 13 (K-PKE.KeyGen), this is called as XOF(rho, j, i) where
86116/// j is the column index (byte 32) and i is the row index (byte 33),
87117/// using 0-based indexing.
88- fn xof ( seed : & Rho , j : u8 , i : u8 ) -> impl XofReader {
118+ fn xof ( seed : & Seed , j : u8 , i : u8 ) -> impl XofReader {
89119 let mut h = Shake128 :: default ( ) ;
90120 h. update ( seed) ;
91121 h. update ( & [ i, j] ) ;
@@ -103,11 +133,12 @@ fn sample_ntt_poly(xof: &mut impl XofReader) -> NttPolynomial<MlKemField> {
103133 const BUF_LEN : usize = 32 * 3 ;
104134 let mut poly = NttPolynomial :: < MlKemField > :: default ( ) ;
105135 let mut buf = [ 0u8 ; BUF_LEN ] ;
106- let mut pos = BUF_LEN ; // start at end to trigger first read
136+ xof. read ( & mut buf) ;
137+ let mut pos = 0 ;
107138 let mut i = 0 ;
108139
109140 while i < NUM_COEFFICIENTS {
110- // Read BUF_LEN chunks from the XOF, consume and then refill once exhausted.
141+ // Refill the buffer from the XOF stream when exhausted.
111142 if pos >= BUF_LEN {
112143 xof. read ( & mut buf) ;
113144 pos = 0 ;
@@ -144,19 +175,22 @@ fn sample_ntt_vector(seed: &Seed) -> NttVector {
144175 )
145176}
146177
147- /// H(ek): hash-to-key. Maps an NttVector to another NttVector via SHA3-256.
178+ /// Maps an encapsulation key to an NttVector via SHA3-256.
179+ /// Only the t_hat component is hashed; rho is ignored.
148180/// Corresponds to libOTe's `pkHash`.
149- fn hash_to_key ( t_hat : & NttVector ) -> NttVector {
150- use sha3:: Digest ;
151- let encoded = <NttVector as Encode < U12 > >:: encode ( t_hat) ;
152- let seed: Rho = sha3:: Sha3_256 :: digest ( encoded. as_slice ( ) ) . into ( ) ;
181+ fn hash_to_key ( ek : & EncapsulationKey ) -> NttVector {
182+ let encoded = <NttVector as Encode < U12 > >:: encode ( & ek. t_hat ) ;
183+ let seed: Seed = sha3:: Sha3_256 :: digest ( encoded. as_slice ( ) ) . into ( ) ;
153184 sample_ntt_vector ( & seed)
154185}
155186
156- /// RandomEK: generate a random NttVector from a random seed.
157- fn random_ek ( rng : & mut StdRng ) -> NttVector {
158- let seed: Rho = rng. random ( ) ;
159- sample_ntt_vector ( & seed)
187+ /// RandomEK: generate a random encapsulation key from a random seed.
188+ fn random_ek ( rng : & mut StdRng , rho : Seed ) -> EncapsulationKey {
189+ let seed: Seed = rng. random ( ) ;
190+ EncapsulationKey {
191+ t_hat : sample_ntt_vector ( & seed) ,
192+ rho,
193+ }
160194}
161195
162196// ---------------------------------------------------------------------------
@@ -184,9 +218,9 @@ pub enum Error {
184218}
185219
186220#[ derive( Copy , Clone , Serialize , Deserialize ) ]
187- struct EncapKeyBytes ( #[ serde( with = "serde_bytes" ) ] [ u8 ; ENCAPSULATION_KEY_LEN ] ) ;
221+ struct EncapsulationKeyBytes ( #[ serde( with = "serde_bytes" ) ] [ u8 ; ENCAPSULATION_KEY_LEN ] ) ;
188222
189- impl ConditionallySelectable for EncapKeyBytes {
223+ impl ConditionallySelectable for EncapsulationKeyBytes {
190224 fn conditional_select ( a : & Self , b : & Self , choice : Choice ) -> Self {
191225 Self ( <[ u8 ; ENCAPSULATION_KEY_LEN ] >:: conditional_select (
192226 & a. 0 , & b. 0 , choice,
@@ -208,8 +242,8 @@ impl ConditionallySelectable for CtBytes {
208242// Message from receiver to sender: two values (r_0, r_1) per OT.
209243#[ derive( Serialize , Deserialize ) ]
210244struct EncapsulationKeysMessage {
211- eks0 : Vec < EncapKeyBytes > ,
212- eks1 : Vec < EncapKeyBytes > ,
245+ eks0 : Vec < EncapsulationKeyBytes > ,
246+ eks1 : Vec < EncapsulationKeyBytes > ,
213247}
214248
215249// Message from sender to receiver: two ciphertexts per OT.
@@ -277,16 +311,16 @@ impl RotSender for MlKemOt {
277311 . enumerate ( )
278312 {
279313 // Reconstruct encapsulation keys: ek_j = r_j + H(r_{1-j})
280- let ( r0 , rho ) = parse_ek ( & r0_bytes. 0 ) ;
281- let ( r1 , _ ) = parse_ek ( & r1_bytes. 0 ) ;
314+ let r0 = EncapsulationKey :: from_bytes ( & r0_bytes. 0 ) ;
315+ let r1 = EncapsulationKey :: from_bytes ( & r1_bytes. 0 ) ;
282316
283- let ek0_bytes = serialize_ek ( & ( & r0 + & hash_to_key ( & r1) ) , & rho ) ;
284- let ek1_bytes = serialize_ek ( & ( & r1 + & hash_to_key ( & r0) ) , & rho ) ;
317+ let ek0 = & r0 + & hash_to_key ( & r1) ;
318+ let ek1 = & r1 + & hash_to_key ( & r0) ;
285319
286- let ( ct0, key0) = encapsulate ( & EncapKeyBytes ( ek0_bytes ) , & mut self . rng ) ;
320+ let ( ct0, key0) = encapsulate ( & EncapsulationKeyBytes ( ek0 . to_bytes ( ) ) , & mut self . rng ) ;
287321 let key0 = hash ( & key0, i) ;
288322
289- let ( ct1, key1) = encapsulate ( & EncapKeyBytes ( ek1_bytes ) , & mut self . rng ) ;
323+ let ( ct1, key1) = encapsulate ( & EncapsulationKeyBytes ( ek1 . to_bytes ( ) ) , & mut self . rng ) ;
290324 let key1 = hash ( & key1, i) ;
291325
292326 cts0. push ( ct0) ;
@@ -331,23 +365,23 @@ impl RotReceiver for MlKemOt {
331365 . as_slice ( )
332366 . try_into ( )
333367 . expect ( "incorrect encapsulation key size" ) ;
334- let ( real_t_hat , rho ) = parse_ek ( & ek_bytes) ;
368+ let ek = EncapsulationKey :: from_bytes ( & ek_bytes) ;
335369
336370 // Step 2: Sample random key for position 1-b.
337- let rand_t_hat = random_ek ( & mut self . rng ) ;
371+ let r_1_b = random_ek ( & mut self . rng , ek . rho ) ;
338372
339- // Step 3: Compute correlated key for position b : r_b = ek - H(r_{1-b}).
340- let correlated_t_hat = & real_t_hat - & hash_to_key ( & rand_t_hat ) ;
373+ // Step 3: Compute correlated key: r_b = ek - H(r_{1-b}).
374+ let r_b = & ek - & hash_to_key ( & r_1_b ) ;
341375
342- // Serialize both keys with the same rho .
343- let correlated_bytes = EncapKeyBytes ( serialize_ek ( & correlated_t_hat , & rho ) ) ;
344- let random_bytes = EncapKeyBytes ( serialize_ek ( & rand_t_hat , & rho ) ) ;
376+ // Serialize both keys.
377+ let r_b_bytes = EncapsulationKeyBytes ( r_b . to_bytes ( ) ) ;
378+ let r_1_b_bytes = EncapsulationKeyBytes ( r_1_b . to_bytes ( ) ) ;
345379
346380 // Step 4: Select (r_0, r_1) based on choice bit (constant-time).
347- // If b=0: r_0 = correlated ( real side) , r_1 = random.
348- // If b=1: r_0 = random, r_1 = correlated ( real side) .
349- let ek0 = EncapKeyBytes :: conditional_select ( & correlated_bytes , & random_bytes , * choice) ;
350- let ek1 = EncapKeyBytes :: conditional_select ( & random_bytes , & correlated_bytes , * choice) ;
381+ // If b=0: r_0 = real, r_1 = random.
382+ // If b=1: r_0 = random, r_1 = real.
383+ let ek0 = EncapsulationKeyBytes :: conditional_select ( & r_b_bytes , & r_1_b_bytes , * choice) ;
384+ let ek1 = EncapsulationKeyBytes :: conditional_select ( & r_1_b_bytes , & r_b_bytes , * choice) ;
351385
352386 decap_keys. push ( dk) ;
353387 eks0. push ( ek0) ;
@@ -397,7 +431,7 @@ impl RotReceiver for MlKemOt {
397431}
398432
399433// Encapsulates to the given key, returning the ciphertext and the shared key.
400- fn encapsulate ( ek : & EncapKeyBytes , rng : & mut StdRng ) -> ( CtBytes , SharedKey < MlKem > ) {
434+ fn encapsulate ( ek : & EncapsulationKeyBytes , rng : & mut StdRng ) -> ( CtBytes , SharedKey < MlKem > ) {
401435 let parsed_ek = MlKemEncapsulationKey :: < MlKemParams > :: from_bytes ( ( & ek. 0 ) . into ( ) ) ;
402436 let ( ct, k) : ( MlKemCiphertext < MlKem > , SharedKey < MlKem > ) = parsed_ek
403437 . encapsulate ( & mut RngCompat ( rng) )
0 commit comments