@@ -21,6 +21,7 @@ use std::mem::MaybeUninit;
2121use crate :: cvt;
2222use crate :: error:: ErrorStack ;
2323use crate :: ffi;
24+ use crate :: memcmp:: eq;
2425
2526// CBS_init is inline in BoringSSL, so bindgen can't generate bindings for it.
2627#[ inline]
@@ -131,16 +132,25 @@ impl MlKemPublicKey {
131132 /// Encapsulates a shared secret to the given public key, returning
132133 /// `(ciphertext, shared_secret)`.
133134 pub fn encapsulate ( & self ) -> Result < ( Vec < u8 > , MlKemSharedSecret ) , ErrorStack > {
134- match & self . 0 {
135+ let mut ss = [ 0 ; SHARED_SECRET_BYTES ] ;
136+ let ct = match & self . 0 {
135137 Either :: MlKem768 ( pk) => {
136- let ( ct, ss) = pk. encapsulate ( ) ;
137- Ok ( ( ct. to_vec ( ) , ss) )
138+ let mut ct = vec ! [ 0 ; MlKem768PrivateKey :: CIPHERTEXT_BYTES ] ;
139+ pk. encapsulate_into ( ct. as_mut_slice ( ) . try_into ( ) . unwrap ( ) , & mut ss) ;
140+ ct
138141 }
139142 Either :: MlKem1024 ( pk) => {
140- let ( ct, ss) = pk. encapsulate ( ) ;
141- Ok ( ( ct. to_vec ( ) , ss) )
143+ let mut ct = vec ! [ 0 ; MlKem1024PrivateKey :: CIPHERTEXT_BYTES ] ;
144+ pk. encapsulate_into ( ct. as_mut_slice ( ) . try_into ( ) . unwrap ( ) , & mut ss) ;
145+ ct
142146 }
147+ } ;
148+ if eq ( ss, [ 0 ; SHARED_SECRET_BYTES ] ) {
149+ return Err ( ErrorStack :: internal_error_str (
150+ "Amazing! I've got the same combination on my luggage" ,
151+ ) ) ;
143152 }
153+ Ok ( ( ct, ss) )
144154 }
145155
146156 /// Query public key and ciphertext length
@@ -229,14 +239,17 @@ impl MlKem768PrivateKey {
229239 ffi:: init ( ) ;
230240 let mut public_key_bytes: MaybeUninit < [ u8 ; MlKem768PublicKey :: PUBLIC_KEY_BYTES ] > =
231241 MaybeUninit :: uninit ( ) ;
232- let mut seed: MaybeUninit < MlKemPrivateKeySeed > = MaybeUninit :: uninit ( ) ;
242+ let mut seed = [ 0 ; PRIVATE_KEY_SEED_BYTES ] ;
233243 let mut expanded: MaybeUninit < ffi:: MLKEM768_private_key > = MaybeUninit :: uninit ( ) ;
234244
235245 ffi:: MLKEM768_generate_key (
236246 public_key_bytes. as_mut_ptr ( ) . cast ( ) ,
237- seed. as_mut_ptr ( ) . cast ( ) ,
247+ seed. as_mut_ptr ( ) ,
238248 expanded. as_mut_ptr ( ) ,
239249 ) ;
250+ if eq ( seed, [ 0 ; PRIVATE_KEY_SEED_BYTES ] ) {
251+ panic ! ( "Amazing! I've got the same combination on my luggage" ) ;
252+ }
240253
241254 let bytes = public_key_bytes. assume_init ( ) ;
242255
@@ -251,7 +264,7 @@ impl MlKem768PrivateKey {
251264 parsed : parsed. assume_init ( ) ,
252265 } ) ,
253266 Box :: new ( MlKem768PrivateKey {
254- seed : seed . assume_init ( ) ,
267+ seed,
255268 expanded : expanded. assume_init ( ) ,
256269 } ) ,
257270 )
@@ -388,25 +401,19 @@ impl MlKem768PublicKey {
388401 }
389402
390403 /// Encapsulate: returns (ciphertext, shared_secret).
391- fn encapsulate (
404+ fn encapsulate_into (
392405 & self ,
393- ) -> (
394- [ u8 ; MlKem768PrivateKey :: CIPHERTEXT_BYTES ] ,
395- MlKemSharedSecret ,
406+ ciphertext : & mut [ u8 ; MlKem768PrivateKey :: CIPHERTEXT_BYTES ] ,
407+ shared_secret : & mut MlKemSharedSecret ,
396408 ) {
397409 // SAFETY: buffers correctly sized, parsed key is valid
398410 unsafe {
399411 ffi:: init ( ) ;
400- let mut ciphertext = [ 0u8 ; MlKem768PrivateKey :: CIPHERTEXT_BYTES ] ;
401- let mut shared_secret = [ 0u8 ; SHARED_SECRET_BYTES ] ;
402-
403412 ffi:: MLKEM768_encap (
404413 ciphertext. as_mut_ptr ( ) ,
405414 shared_secret. as_mut_ptr ( ) ,
406415 & self . parsed ,
407416 ) ;
408-
409- ( ciphertext, shared_secret)
410417 }
411418 }
412419}
@@ -446,14 +453,17 @@ impl MlKem1024PrivateKey {
446453 ffi:: init ( ) ;
447454 let mut public_key_bytes: MaybeUninit < [ u8 ; MlKem1024PublicKey :: PUBLIC_KEY_BYTES ] > =
448455 MaybeUninit :: uninit ( ) ;
449- let mut seed: MaybeUninit < MlKemPrivateKeySeed > = MaybeUninit :: uninit ( ) ;
456+ let mut seed = [ 0 ; PRIVATE_KEY_SEED_BYTES ] ;
450457 let mut expanded: MaybeUninit < ffi:: MLKEM1024_private_key > = MaybeUninit :: uninit ( ) ;
451458
452459 ffi:: MLKEM1024_generate_key (
453460 public_key_bytes. as_mut_ptr ( ) . cast ( ) ,
454- seed. as_mut_ptr ( ) . cast ( ) ,
461+ seed. as_mut_ptr ( ) ,
455462 expanded. as_mut_ptr ( ) ,
456463 ) ;
464+ if eq ( seed, [ 0 ; PRIVATE_KEY_SEED_BYTES ] ) {
465+ panic ! ( "Amazing! I've got the same combination on my luggage" ) ;
466+ }
457467
458468 let bytes = public_key_bytes. assume_init ( ) ;
459469
@@ -468,7 +478,7 @@ impl MlKem1024PrivateKey {
468478 parsed : parsed. assume_init ( ) ,
469479 } ) ,
470480 Box :: new ( MlKem1024PrivateKey {
471- seed : seed . assume_init ( ) ,
481+ seed,
472482 expanded : expanded. assume_init ( ) ,
473483 } ) ,
474484 )
@@ -607,25 +617,19 @@ impl MlKem1024PublicKey {
607617 }
608618
609619 /// Encapsulate: returns (ciphertext, shared_secret).
610- fn encapsulate (
620+ fn encapsulate_into (
611621 & self ,
612- ) -> (
613- [ u8 ; MlKem1024PrivateKey :: CIPHERTEXT_BYTES ] ,
614- [ u8 ; SHARED_SECRET_BYTES ] ,
622+ ciphertext : & mut [ u8 ; MlKem1024PrivateKey :: CIPHERTEXT_BYTES ] ,
623+ shared_secret : & mut [ u8 ; SHARED_SECRET_BYTES ] ,
615624 ) {
616625 // SAFETY: buffers correctly sized, parsed key is valid
617626 unsafe {
618627 ffi:: init ( ) ;
619- let mut ciphertext = [ 0u8 ; MlKem1024PrivateKey :: CIPHERTEXT_BYTES ] ;
620- let mut shared_secret = [ 0u8 ; SHARED_SECRET_BYTES ] ;
621-
622628 ffi:: MLKEM1024_encap (
623629 ciphertext. as_mut_ptr ( ) ,
624630 shared_secret. as_mut_ptr ( ) ,
625631 & self . parsed ,
626632 ) ;
627-
628- ( ciphertext, shared_secret)
629633 }
630634 }
631635}
@@ -650,7 +654,9 @@ mod tests {
650654 #[ test]
651655 fn roundtrip( ) {
652656 let ( pk, sk) = <$priv>:: generate( ) ;
653- let ( ct, ss1) = pk. encapsulate( ) ;
657+ let mut ct = [ 0 ; _] ;
658+ let mut ss1 = [ 0 ; _] ;
659+ pk. encapsulate_into( & mut ct, & mut ss1) ;
654660 let ss2 = sk. decapsulate( & ct) ;
655661 assert_eq!( ss1, ss2) ;
656662 }
@@ -659,7 +665,9 @@ mod tests {
659665 fn seed_roundtrip( ) {
660666 let ( pk, sk) = <$priv>:: generate( ) ;
661667 let sk2 = <$priv>:: from_seed( & sk. seed) . unwrap( ) ;
662- let ( ct, ss1) = pk. encapsulate( ) ;
668+ let mut ct = [ 0 ; _] ;
669+ let mut ss1 = [ 0 ; _] ;
670+ pk. encapsulate_into( & mut ct, & mut ss1) ;
663671 let ss2 = sk2. decapsulate( & ct) ;
664672 assert_eq!( ss1, ss2) ;
665673 }
0 commit comments