@@ -155,16 +155,50 @@ where
155155 P :: concat_ek ( t_hat, self . rho . clone ( ) )
156156 }
157157
158- /// Parse an encryption key from a byte array `(t_hat || rho)`
159- // TODO(tarcieri): validate decoded keys
160- #[ allow( clippy:: unnecessary_wraps) ]
158+ /// Parse an encryption key from a byte array `(t_hat || rho)`.
159+ ///
160+ /// # Errors
161+ /// Returns [`Error`] in the event that the key fails the encapsulation key checks specified in
162+ /// FIPS 203 §7.2.
161163 pub fn from_bytes ( enc : & EncodedEncryptionKey < P > ) -> Result < Self , Error > {
162164 let ( t_hat, rho) = P :: split_ek ( enc) ;
163165 let t_hat = P :: decode_u12 ( t_hat) ;
164- Ok ( Self {
166+ let ret = Self {
165167 t_hat,
166168 rho : rho. clone ( ) ,
167- } )
169+ } ;
170+
171+ // Check the candidate encapsulation key is valid using the method specified in FIPS 203
172+ // §7.2 ML-KEM Encapsulation:
173+ //
174+ // > Encapsulation key check. To check a candidate encapsulation key `ek`, perform the
175+ // > following:
176+ // >
177+ // > 1. (Type check) If `ek` is not an array of bytes of length 384𝑘+32 for the value of 𝑘
178+ // > specified by the relevant parameter set, then input checking failed.
179+ // > 2. (Modulus check) Perform the computation:
180+ // >
181+ // > test ← ByteEncode₁₂(ByteDecode₁₂(ek[0:384𝑘]))
182+ // >
183+ // > (see Section 4.2.1). If `test ≠ ek[0∶384𝑘]`, then input checking failed. This
184+ // > check ensures that the integers encoded in the public key are in the valid range
185+ // > `[0,q-1]`.
186+ // >
187+ // > If both checks pass, then `ML-KEM.Encaps` can be run with input `ek`. It is important
188+ // > to note that this checking process does not guarantee that ek is a properly produced
189+ // > output of `ML-KEM.KeyGen`.
190+ // >
191+ // > `ML-KEM.Encaps` shall not be run with an encapsulation key that has not been checked as
192+ // > above.
193+ //
194+ // #1 is performed by the `EncodedEncryptionKey` type, and the following check vicariously
195+ // performs #2 by encoding the integer-mod-q array using our implementation of ByteEncode₁₂
196+ // and comparing the resulting serialization to see if it round-trips.
197+ if & ret. as_bytes ( ) == enc {
198+ Ok ( ret)
199+ } else {
200+ Err ( Error )
201+ }
168202 }
169203}
170204
@@ -221,4 +255,12 @@ mod test {
221255 codec_test :: < MlKem768Params > ( ) ;
222256 codec_test :: < MlKem1024Params > ( ) ;
223257 }
258+
259+ #[ test]
260+ fn reject_invalid_encryption_keys ( ) {
261+ // Create an invalid key: all bytes set to 0xFF
262+ // When decoded as 12-bit coefficients, this produces values of 0xFFF = 4095 > 3329
263+ let invalid_key = [ 0xFF ; 1184 ] ;
264+ assert ! ( EncryptionKey :: <MlKem768Params >:: from_bytes( & invalid_key. into( ) ) . is_err( ) ) ;
265+ }
224266}
0 commit comments