@@ -86,7 +86,7 @@ pub fn g(c: &[&[u8]]) -> ([u8; 32], Zeroizing<[u8; 32]>) {
8686}
8787
8888/// Internal PKE-related function, for generalizing over the three different PKE parameter-sets.
89- pub ( crate ) trait PkeParameters : Clone {
89+ pub ( crate ) trait PkeParameters {
9090 const N : usize = 256 ;
9191 const K : usize ;
9292 const ETA_1 : usize ;
@@ -523,7 +523,7 @@ pub(crate) struct DecapKey<
523523> {
524524 pub ( crate ) bytes : [ u8 ; ENCODED_SIZE_DK ] ,
525525 s_hat : [ RingElementNTT ; K ] ,
526- ek : EncapKey < K , ENCODED_SIZE_EK , Pke > ,
526+ _phantom : PhantomData < Pke > ,
527527}
528528
529529impl <
@@ -617,15 +617,10 @@ impl<
617617 ByteSerialization :: decode_12 ( dk_part, & mut s_hat_poly. coefficients ) ;
618618 }
619619
620- // Save the encapsulation key, such that it doesn't need to be re-computed in MLKEM-decap_internal().
621- let ek = EncapKey :: < K , ENCODED_SIZE_EK , Pke > :: from_slice (
622- & slice[ ENCODE_SIZE_POLY * K ..( 768 * K ) + 32 ] ,
623- ) ?;
624-
625620 Ok ( Self {
626621 bytes : slice. try_into ( ) . unwrap ( ) , // NOTE: Should never panic if decapsulation_key_check() succeeds.
627622 s_hat,
628- ek ,
623+ _phantom : PhantomData ,
629624 } )
630625 }
631626
@@ -682,11 +677,13 @@ impl<
682677 /// - decapsulation key dk ∈ B^{768k+96}.
683678 /// - ciphertext c ∈ B^{32*(d_u*k+d_v)}.
684679 /// - shared secret K ∈ B^{32}.
685- pub ( crate ) fn mlkem_decap_internal (
680+ pub ( crate ) fn mlkem_decap_internal_with_ek (
686681 & self ,
687682 c : & [ u8 ] ,
688683 c_prime : & mut [ u8 ] ,
684+ ek : & EncapKey < K , ENCODED_SIZE_EK , Pke > ,
689685 ) -> Result < [ u8 ; 32 ] , UnknownCryptoError > {
686+ debug_assert_eq ! ( self . get_encapsulation_key_bytes( ) , ek. as_ref( ) ) ;
690687 debug_assert_eq ! ( c. len( ) , Pke :: CIPHERTEXT_SIZE ) ;
691688
692689 // Step 1:
@@ -712,9 +709,8 @@ impl<
712709 xof. squeeze ( k_bar. as_mut ( ) ) ?;
713710
714711 // Step 8:
715- debug_assert_eq ! ( self . get_encapsulation_key_bytes( ) , self . ek. as_ref( ) ) ;
716712 debug_assert_eq ! ( self . get_encapsulation_key_bytes( ) , ek_pke) ;
717- self . ek . encrypt ( & m, r. as_ref ( ) , c_prime) ?;
713+ ek. encrypt ( & m, r. as_ref ( ) , c_prime) ?;
718714
719715 // Step 9:
720716 let ct_choice = c. ct_ne ( c_prime) ;
@@ -727,6 +723,19 @@ impl<
727723 Ok ( k)
728724 }
729725
726+ #[ cfg( feature = "safe_api" ) ] // used in from_keys which requires safe_api
727+ pub ( crate ) fn mlkem_decap_internal (
728+ & self ,
729+ c : & [ u8 ] ,
730+ c_prime : & mut [ u8 ] ,
731+ ) -> Result < [ u8 ; 32 ] , UnknownCryptoError > {
732+ // In this case we aren't provided a cached encapsulation key.
733+ let ek =
734+ EncapKey :: < K , ENCODED_SIZE_EK , Pke > :: from_slice ( self . get_encapsulation_key_bytes ( ) ) ?;
735+
736+ self . mlkem_decap_internal_with_ek ( c, c_prime, & ek)
737+ }
738+
730739 pub ( crate ) fn unprotected_as_bytes ( & self ) -> & [ u8 ] {
731740 & self . bytes
732741 }
@@ -745,6 +754,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
745754 /// k \in [2, 3, 4]
746755 fn keygen < const K : usize , const ENCODED_SIZE_EK : usize , const ENCODED_SIZE_DK : usize > (
747756 d : & [ u8 ] ,
757+ ek : & mut EncapKey < K , ENCODED_SIZE_EK , Pke > ,
748758 dk : & mut DecapKey < K , ENCODED_SIZE_EK , ENCODED_SIZE_DK , Pke > ,
749759 ) -> Result < ( ) , UnknownCryptoError > {
750760 let ( rho, sigma) = g ( & [ d, & [ Pke :: K as u8 ] ] ) ;
@@ -753,7 +763,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
753763 // Steps 3..7
754764 for i in 0 ..Pke :: K {
755765 for j in 0 ..Pke :: K {
756- dk . ek . mat_a [ i] [ j] = sample_ntt ( & rho, & [ j as u8 , i as u8 ] ) ?;
766+ ek. mat_a [ i] [ j] = sample_ntt ( & rho, & [ j as u8 , i as u8 ] ) ?;
757767 }
758768 }
759769
@@ -773,34 +783,33 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
773783
774784 for i in 0 ..Pke :: K {
775785 dk. s_hat [ i] = to_ntt ( & s[ i] ) ;
776- dk . ek . t_hat [ i] = to_ntt ( & e[ i] ) ;
786+ ek. t_hat [ i] = to_ntt ( & e[ i] ) ;
777787 }
778788
779789 s. zeroize ( ) ;
780790
781791 // t ← A ∘ ŝ + ê
782792 for i in 0 ..Pke :: K {
783793 for j in 0 ..Pke :: K {
784- dk . ek . t_hat [ i] += dk . ek . mat_a [ i] [ j] * dk. s_hat [ j] ;
794+ ek. t_hat [ i] += ek. mat_a [ i] [ j] * dk. s_hat [ j] ;
785795 }
786796 }
787797
788798 // Step 19
789- for ( re, ek_part) in dk
790- . ek
799+ for ( re, ek_part) in ek
791800 . t_hat
792801 . iter ( )
793- . zip ( dk . ek . bytes . chunks_exact_mut ( ENCODE_SIZE_POLY ) )
802+ . zip ( ek. bytes . chunks_exact_mut ( ENCODE_SIZE_POLY ) )
794803 {
795804 ByteSerialization :: encode_12 ( & re. coefficients , ek_part) ;
796805 }
797806
798807 let idx = ENCODED_SIZE_EK - rho. len ( ) ;
799- dk . ek . bytes [ idx..] . copy_from_slice ( & rho) ;
808+ ek. bytes [ idx..] . copy_from_slice ( & rho) ;
800809
801810 // Cache hash of ek so we don't need to re-compute for every encap().
802- let h_ek = Sha3_256 :: digest ( & dk . ek . bytes ) ?;
803- dk . ek . h_ek = h_ek. value ;
811+ let h_ek = Sha3_256 :: digest ( & ek. bytes ) ?;
812+ ek. h_ek = h_ek. value ;
804813
805814 // Step 20
806815 for ( re, dk_part) in dk
@@ -827,7 +836,7 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
827836 ) ,
828837 UnknownCryptoError ,
829838 > {
830- let ek = EncapKey :: < K , ENCODED_SIZE_EK , Pke > {
839+ let mut encap_key = EncapKey :: < K , ENCODED_SIZE_EK , Pke > {
831840 bytes : [ 0u8 ; ENCODED_SIZE_EK ] ,
832841 h_ek : [ 0u8 ; SHA3_256_OUTSIZE ] ,
833842 t_hat : [ RingElementNTT :: zero ( ) ; K ] ,
@@ -839,28 +848,27 @@ impl<Pke: PkeParameters> KeyPairInternal<Pke> {
839848 let mut decap_key = DecapKey :: < K , ENCODED_SIZE_EK , ENCODED_SIZE_DK , Pke > {
840849 bytes : [ 0u8 ; ENCODED_SIZE_DK ] ,
841850 s_hat : [ RingElementNTT :: zero ( ) ; K ] ,
842- ek ,
851+ _phantom : PhantomData ,
843852 } ;
844853
845854 // Step 1 + 2. (ekPKE, dkPKE) ← K-PKE.KeyGen(d)
846- Self :: keygen ( & seed. unprotected_as_bytes ( ) [ ..32 ] , & mut decap_key) ?;
855+ Self :: keygen (
856+ & seed. unprotected_as_bytes ( ) [ ..32 ] ,
857+ & mut encap_key,
858+ & mut decap_key,
859+ ) ?;
847860
848861 // Step 3. dk ← (dkPKE‖ek‖H(ek)‖z)
849- let ek_bytes = decap_key. ek . as_ref ( ) ;
850862 decap_key. bytes [ ( ENCODE_SIZE_POLY * K ) ..( ENCODE_SIZE_POLY * K ) + Pke :: EK_SIZE ]
851- . copy_from_slice ( ek_bytes ) ;
863+ . copy_from_slice ( & encap_key . bytes ) ;
852864 decap_key. bytes
853865 [ ( ENCODE_SIZE_POLY * K ) + Pke :: EK_SIZE ..( ENCODE_SIZE_POLY * K ) + Pke :: EK_SIZE + 32 ]
854- . copy_from_slice ( Sha3_256 :: digest ( ek_bytes ) . unwrap ( ) . as_ref ( ) ) ;
866+ . copy_from_slice ( Sha3_256 :: digest ( & encap_key . bytes ) . unwrap ( ) . as_ref ( ) ) ;
855867 decap_key. bytes [ ( ENCODE_SIZE_POLY * K ) + Pke :: EK_SIZE + 32
856868 ..( ENCODE_SIZE_POLY * K ) + Pke :: EK_SIZE + 32 + 32 ]
857869 . copy_from_slice ( & seed. unprotected_as_bytes ( ) [ 32 ..64 ] ) ;
858870
859- let encap_key = decap_key. ek . clone ( ) ; // TODO: Can we maybe get rid of this clone? Probably some internal API changes propagating upwards
860- debug_assert_eq ! (
861- decap_key. get_encapsulation_key_bytes( ) ,
862- decap_key. ek. as_ref( )
863- ) ;
871+ debug_assert_eq ! ( decap_key. get_encapsulation_key_bytes( ) , encap_key. as_ref( ) ) ;
864872
865873 Ok ( ( encap_key, decap_key) )
866874 }
@@ -927,20 +935,6 @@ mod tests {
927935 use crate :: hazardous:: kem:: ml_kem:: mlkem512:: KeyPair as MlKem512KeyPair ;
928936 use crate :: hazardous:: kem:: ml_kem:: mlkem768:: KeyPair as MlKem768KeyPair ;
929937
930- #[ test]
931- fn test_keypair_dk_ek_match_internal ( ) {
932- let seed = Seed :: from_slice ( & [ 128u8 ; 64 ] ) . unwrap ( ) ;
933-
934- let kp = MlKem512KeyPair :: try_from ( & seed) . unwrap ( ) ;
935- assert_eq ! ( kp. public( ) . value, kp. private( ) . value. ek) ;
936-
937- let kp = MlKem768KeyPair :: try_from ( & seed) . unwrap ( ) ;
938- assert_eq ! ( kp. public( ) . value, kp. private( ) . value. ek) ;
939-
940- let kp = MlKem1024KeyPair :: try_from ( & seed) . unwrap ( ) ;
941- assert_eq ! ( kp. public( ) . value, kp. private( ) . value. ek) ;
942- }
943-
944938 #[ test]
945939 #[ cfg( feature = "safe_api" ) ]
946940 fn test_seed_and_dk_mismatch ( ) {
0 commit comments