1111import org .bouncycastle .asn1 .ASN1OctetString ;
1212import org .bouncycastle .asn1 .ASN1Primitive ;
1313import org .bouncycastle .asn1 .ASN1Sequence ;
14- import org .bouncycastle .asn1 .ASN1TaggedObject ;
1514import org .bouncycastle .asn1 .DEROctetString ;
1615import org .bouncycastle .asn1 .bc .BCObjectIdentifiers ;
1716import org .bouncycastle .asn1 .nist .NISTObjectIdentifiers ;
4948import org .bouncycastle .pqc .crypto .mldsa .MLDSAPublicKeyParameters ;
5049import org .bouncycastle .pqc .crypto .mlkem .MLKEMParameters ;
5150import org .bouncycastle .pqc .crypto .mlkem .MLKEMPrivateKeyParameters ;
51+ import org .bouncycastle .pqc .crypto .mlkem .MLKEMPublicKeyParameters ;
5252import org .bouncycastle .pqc .crypto .newhope .NHPrivateKeyParameters ;
5353import org .bouncycastle .pqc .crypto .ntru .NTRUParameters ;
5454import org .bouncycastle .pqc .crypto .ntru .NTRUPrivateKeyParameters ;
@@ -92,7 +92,7 @@ public class PrivateKeyFactory
9292 * @throws IOException on an error decoding the key
9393 */
9494 public static AsymmetricKeyParameter createKey (byte [] privateKeyInfoData )
95- throws IOException
95+ throws IOException
9696 {
9797 if (privateKeyInfoData == null )
9898 {
@@ -114,7 +114,7 @@ public static AsymmetricKeyParameter createKey(byte[] privateKeyInfoData)
114114 * @throws IOException on an error decoding the key
115115 */
116116 public static AsymmetricKeyParameter createKey (InputStream inStr )
117- throws IOException
117+ throws IOException
118118 {
119119 return createKey (PrivateKeyInfo .getInstance (new ASN1InputStream (inStr ).readObject ()));
120120 }
@@ -127,7 +127,7 @@ public static AsymmetricKeyParameter createKey(InputStream inStr)
127127 * @throws IOException on an error decoding the key
128128 */
129129 public static AsymmetricKeyParameter createKey (PrivateKeyInfo keyInfo )
130- throws IOException
130+ throws IOException
131131 {
132132 if (keyInfo == null )
133133 {
@@ -146,7 +146,7 @@ public static AsymmetricKeyParameter createKey(PrivateKeyInfo keyInfo)
146146 else if (algOID .equals (PQCObjectIdentifiers .sphincs256 ))
147147 {
148148 return new SPHINCSPrivateKeyParameters (ASN1OctetString .getInstance (keyInfo .parsePrivateKey ()).getOctets (),
149- Utils .sphincs256LookupTreeAlgName (SPHINCS256KeyParams .getInstance (algId .getParameters ())));
149+ Utils .sphincs256LookupTreeAlgName (SPHINCS256KeyParams .getInstance (algId .getParameters ())));
150150 }
151151 else if (algOID .equals (PQCObjectIdentifiers .newHope ))
152152 {
@@ -176,7 +176,7 @@ else if (algOID.on(BCObjectIdentifiers.sphincsPlus) || algOID.on(BCObjectIdentif
176176 SPHINCSPLUSPrivateKey spKey = SPHINCSPLUSPrivateKey .getInstance (obj );
177177 SPHINCSPLUSPublicKey publicKey = spKey .getPublicKey ();
178178 return new SPHINCSPlusPrivateKeyParameters (spParams , spKey .getSkseed (), spKey .getSkprf (),
179- publicKey .getPkseed (), publicKey .getPkroot ());
179+ publicKey .getPkseed (), publicKey .getPkroot ());
180180 }
181181 else
182182 {
@@ -226,24 +226,29 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntru))
226226 return new NTRUPrivateKeyParameters (spParams , keyEnc );
227227 }
228228 else if (algOID .equals (NISTObjectIdentifiers .id_alg_ml_kem_512 ) ||
229- algOID .equals (NISTObjectIdentifiers .id_alg_ml_kem_768 ) ||
230- algOID .equals (NISTObjectIdentifiers .id_alg_ml_kem_1024 ))
229+ algOID .equals (NISTObjectIdentifiers .id_alg_ml_kem_768 ) ||
230+ algOID .equals (NISTObjectIdentifiers .id_alg_ml_kem_1024 ))
231231 {
232232 ASN1Primitive mlkemKey = parsePrimitiveString (keyInfo .getPrivateKey (), 64 );
233233 MLKEMParameters mlkemParams = Utils .mlkemParamsLookup (algOID );
234234
235+ MLKEMPublicKeyParameters pubParams = null ;
236+ if (keyInfo .getPublicKeyData () != null )
237+ {
238+ pubParams = PublicKeyFactory .MLKEMConverter .getPublicKeyParams (mlkemParams , keyInfo .getPublicKeyData ());
239+ }
240+
235241 if (mlkemKey instanceof ASN1Sequence )
236242 {
237243 ASN1Sequence keySeq = ASN1Sequence .getInstance (mlkemKey );
238244
239- if (keySeq .getObjectAt (0 ) instanceof ASN1OctetString )
240- {
241- return new MLKEMPrivateKeyParameters (mlkemParams , ASN1OctetString .getInstance (keySeq .getObjectAt (0 )).getOctets ());
242- }
243- else
245+ MLKEMPrivateKeyParameters mlkemPriv = new MLKEMPrivateKeyParameters (mlkemParams , ASN1OctetString .getInstance (keySeq .getObjectAt (0 )).getOctets (), pubParams );
246+ if (!Arrays .constantTimeAreEqual (mlkemPriv .getEncoded (), ASN1OctetString .getInstance (keySeq .getObjectAt (1 )).getOctets ()))
244247 {
245- return new MLKEMPrivateKeyParameters ( mlkemParams , ASN1OctetString . getInstance (( ASN1TaggedObject ) keySeq . getObjectAt ( 0 ), false ). getOctets () );
248+ throw new IllegalStateException ( "seed/expanded-key mismatch" );
246249 }
250+
251+ return mlkemPriv ;
247252 }
248253 else if (mlkemKey instanceof ASN1OctetString )
249254 {
@@ -259,10 +264,10 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntrulprime))
259264 NTRULPRimeParameters spParams = Utils .ntrulprimeParamsLookup (algOID );
260265
261266 return new NTRULPRimePrivateKeyParameters (spParams ,
262- ASN1OctetString .getInstance (keyEnc .getObjectAt (0 )).getOctets (),
263- ASN1OctetString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
264- ASN1OctetString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
265- ASN1OctetString .getInstance (keyEnc .getObjectAt (3 )).getOctets ());
267+ ASN1OctetString .getInstance (keyEnc .getObjectAt (0 )).getOctets (),
268+ ASN1OctetString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
269+ ASN1OctetString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
270+ ASN1OctetString .getInstance (keyEnc .getObjectAt (3 )).getOctets ());
266271 }
267272 else if (algOID .on (BCObjectIdentifiers .pqc_kem_sntruprime ))
268273 {
@@ -271,11 +276,11 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime))
271276 SNTRUPrimeParameters spParams = Utils .sntruprimeParamsLookup (algOID );
272277
273278 return new SNTRUPrimePrivateKeyParameters (spParams ,
274- ASN1OctetString .getInstance (keyEnc .getObjectAt (0 )).getOctets (),
275- ASN1OctetString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
276- ASN1OctetString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
277- ASN1OctetString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
278- ASN1OctetString .getInstance (keyEnc .getObjectAt (4 )).getOctets ());
279+ ASN1OctetString .getInstance (keyEnc .getObjectAt (0 )).getOctets (),
280+ ASN1OctetString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
281+ ASN1OctetString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
282+ ASN1OctetString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
283+ ASN1OctetString .getInstance (keyEnc .getObjectAt (4 )).getOctets ());
279284 }
280285 else if (Utils .mldsaParams .containsKey (algOID ))
281286 {
@@ -298,22 +303,21 @@ else if (keyObj instanceof ASN1Sequence)
298303 {
299304 ASN1Sequence keySeq = ASN1Sequence .getInstance (keyObj );
300305
301- if (keySeq .getObjectAt (0 ) instanceof ASN1OctetString )
302- {
303- return new MLDSAPrivateKeyParameters (spParams , ASN1OctetString .getInstance (keySeq .getObjectAt (0 )).getOctets (), pubParams );
304- }
305- else
306+ MLDSAPrivateKeyParameters mldsaPriv = new MLDSAPrivateKeyParameters (spParams , ASN1OctetString .getInstance (keySeq .getObjectAt (0 )).getOctets (), pubParams );
307+ if (!Arrays .constantTimeAreEqual (mldsaPriv .getEncoded (), ASN1OctetString .getInstance (keySeq .getObjectAt (1 )).getOctets ()))
306308 {
307- return new MLDSAPrivateKeyParameters ( spParams , ASN1OctetString . getInstance (( ASN1TaggedObject ) keySeq . getObjectAt ( 0 ), false ). getOctets (), pubParams );
309+ throw new IllegalStateException ( "seed/expanded-key mismatch" );
308310 }
311+
312+ return mldsaPriv ;
309313 }
310314 else
311315 {
312316 throw new IOException ("not supported" );
313317 }
314318 }
315319 else if (algOID .equals (BCObjectIdentifiers .dilithium2 )
316- || algOID .equals (BCObjectIdentifiers .dilithium3 ) || algOID .equals (BCObjectIdentifiers .dilithium5 ))
320+ || algOID .equals (BCObjectIdentifiers .dilithium3 ) || algOID .equals (BCObjectIdentifiers .dilithium5 ))
317321 {
318322 ASN1Encodable keyObj = keyInfo .parsePrivateKey ();
319323 DilithiumParameters dilParams = Utils .dilithiumParamsLookup (algOID );
@@ -333,24 +337,24 @@ else if (algOID.equals(BCObjectIdentifiers.dilithium2)
333337 DilithiumPublicKeyParameters pubParams = PublicKeyFactory .DilithiumConverter .getPublicKeyParams (dilParams , keyInfo .getPublicKeyData ());
334338
335339 return new DilithiumPrivateKeyParameters (dilParams ,
336- ASN1BitString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
337- ASN1BitString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
338- ASN1BitString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
339- ASN1BitString .getInstance (keyEnc .getObjectAt (4 )).getOctets (),
340- ASN1BitString .getInstance (keyEnc .getObjectAt (5 )).getOctets (),
341- ASN1BitString .getInstance (keyEnc .getObjectAt (6 )).getOctets (),
342- pubParams .getT1 ()); // encT1
340+ ASN1BitString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
341+ ASN1BitString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
342+ ASN1BitString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
343+ ASN1BitString .getInstance (keyEnc .getObjectAt (4 )).getOctets (),
344+ ASN1BitString .getInstance (keyEnc .getObjectAt (5 )).getOctets (),
345+ ASN1BitString .getInstance (keyEnc .getObjectAt (6 )).getOctets (),
346+ pubParams .getT1 ()); // encT1
343347 }
344348 else
345349 {
346350 return new DilithiumPrivateKeyParameters (dilParams ,
347- ASN1BitString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
348- ASN1BitString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
349- ASN1BitString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
350- ASN1BitString .getInstance (keyEnc .getObjectAt (4 )).getOctets (),
351- ASN1BitString .getInstance (keyEnc .getObjectAt (5 )).getOctets (),
352- ASN1BitString .getInstance (keyEnc .getObjectAt (6 )).getOctets (),
353- null );
351+ ASN1BitString .getInstance (keyEnc .getObjectAt (1 )).getOctets (),
352+ ASN1BitString .getInstance (keyEnc .getObjectAt (2 )).getOctets (),
353+ ASN1BitString .getInstance (keyEnc .getObjectAt (3 )).getOctets (),
354+ ASN1BitString .getInstance (keyEnc .getObjectAt (4 )).getOctets (),
355+ ASN1BitString .getInstance (keyEnc .getObjectAt (5 )).getOctets (),
356+ ASN1BitString .getInstance (keyEnc .getObjectAt (6 )).getOctets (),
357+ null );
354358 }
355359 }
356360 else if (keyObj instanceof DEROctetString )
@@ -409,12 +413,12 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss))
409413 try
410414 {
411415 XMSSPrivateKeyParameters .Builder keyBuilder = new XMSSPrivateKeyParameters
412- .Builder (new XMSSParameters (keyParams .getHeight (), Utils .getDigest (treeDigest )))
413- .withIndex (xmssPrivateKey .getIndex ())
414- .withSecretKeySeed (xmssPrivateKey .getSecretKeySeed ())
415- .withSecretKeyPRF (xmssPrivateKey .getSecretKeyPRF ())
416- .withPublicSeed (xmssPrivateKey .getPublicSeed ())
417- .withRoot (xmssPrivateKey .getRoot ());
416+ .Builder (new XMSSParameters (keyParams .getHeight (), Utils .getDigest (treeDigest )))
417+ .withIndex (xmssPrivateKey .getIndex ())
418+ .withSecretKeySeed (xmssPrivateKey .getSecretKeySeed ())
419+ .withSecretKeyPRF (xmssPrivateKey .getSecretKeyPRF ())
420+ .withPublicSeed (xmssPrivateKey .getPublicSeed ())
421+ .withRoot (xmssPrivateKey .getRoot ());
418422
419423 if (xmssPrivateKey .getVersion () != 0 )
420424 {
@@ -423,7 +427,7 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss))
423427
424428 if (xmssPrivateKey .getBdsState () != null )
425429 {
426- BDS bds = (BDS ) XMSSUtil .deserialize (xmssPrivateKey .getBdsState (), BDS .class );
430+ BDS bds = (BDS )XMSSUtil .deserialize (xmssPrivateKey .getBdsState (), BDS .class );
427431 keyBuilder .withBDSState (bds .withWOTSDigest (treeDigest ));
428432 }
429433
@@ -444,12 +448,12 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss_mt))
444448 XMSSMTPrivateKey xmssMtPrivateKey = XMSSMTPrivateKey .getInstance (keyInfo .parsePrivateKey ());
445449
446450 XMSSMTPrivateKeyParameters .Builder keyBuilder = new XMSSMTPrivateKeyParameters
447- .Builder (new XMSSMTParameters (keyParams .getHeight (), keyParams .getLayers (), Utils .getDigest (treeDigest )))
448- .withIndex (xmssMtPrivateKey .getIndex ())
449- .withSecretKeySeed (xmssMtPrivateKey .getSecretKeySeed ())
450- .withSecretKeyPRF (xmssMtPrivateKey .getSecretKeyPRF ())
451- .withPublicSeed (xmssMtPrivateKey .getPublicSeed ())
452- .withRoot (xmssMtPrivateKey .getRoot ());
451+ .Builder (new XMSSMTParameters (keyParams .getHeight (), keyParams .getLayers (), Utils .getDigest (treeDigest )))
452+ .withIndex (xmssMtPrivateKey .getIndex ())
453+ .withSecretKeySeed (xmssMtPrivateKey .getSecretKeySeed ())
454+ .withSecretKeyPRF (xmssMtPrivateKey .getSecretKeyPRF ())
455+ .withPublicSeed (xmssMtPrivateKey .getPublicSeed ())
456+ .withRoot (xmssMtPrivateKey .getRoot ());
453457
454458 if (xmssMtPrivateKey .getVersion () != 0 )
455459 {
@@ -458,7 +462,7 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss_mt))
458462
459463 if (xmssMtPrivateKey .getBdsState () != null )
460464 {
461- BDSStateMap bdsState = (BDSStateMap ) XMSSUtil .deserialize (xmssMtPrivateKey .getBdsState (), BDSStateMap .class );
465+ BDSStateMap bdsState = (BDSStateMap )XMSSUtil .deserialize (xmssMtPrivateKey .getBdsState (), BDSStateMap .class );
462466 keyBuilder .withBDSState (bdsState .withWOTSDigest (treeDigest ));
463467 }
464468
@@ -528,7 +532,7 @@ private static ASN1Primitive parsePrimitiveString(ASN1OctetString octStr, int ex
528532 // possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING
529533 // or possible SEQUENCE
530534 ASN1Encodable obj = Utils .parseData (data );
531-
535+
532536 if (obj instanceof ASN1OctetString )
533537 {
534538 return ASN1OctetString .getInstance (obj );
0 commit comments