Skip to content

Commit ca4632b

Browse files
committed
added both check for key/seed verification for ML-KEM and ML-DSA.
1 parent d6af532 commit ca4632b

File tree

3 files changed

+101
-66
lines changed

3 files changed

+101
-66
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyFactory.java

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.bouncycastle.asn1.ASN1OctetString;
1212
import org.bouncycastle.asn1.ASN1Primitive;
1313
import org.bouncycastle.asn1.ASN1Sequence;
14-
import org.bouncycastle.asn1.ASN1TaggedObject;
1514
import org.bouncycastle.asn1.DEROctetString;
1615
import org.bouncycastle.asn1.bc.BCObjectIdentifiers;
1716
import org.bouncycastle.asn1.nist.NISTObjectIdentifiers;
@@ -49,6 +48,7 @@
4948
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters;
5049
import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters;
5150
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
51+
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters;
5252
import org.bouncycastle.pqc.crypto.newhope.NHPrivateKeyParameters;
5353
import org.bouncycastle.pqc.crypto.ntru.NTRUParameters;
5454
import org.bouncycastle.pqc.crypto.ntru.NTRUPrivateKeyParameters;
@@ -92,7 +92,7 @@ public class PrivateKeyFactory
9292
* @throws IOException on an error decoding the key
9393
*/
9494
public static AsymmetricKeyParameter createKey(byte[] privateKeyInfoData)
95-
throws IOException
95+
throws IOException
9696
{
9797
if (privateKeyInfoData == null)
9898
{
@@ -114,7 +114,7 @@ public static AsymmetricKeyParameter createKey(byte[] privateKeyInfoData)
114114
* @throws IOException on an error decoding the key
115115
*/
116116
public static AsymmetricKeyParameter createKey(InputStream inStr)
117-
throws IOException
117+
throws IOException
118118
{
119119
return createKey(PrivateKeyInfo.getInstance(new ASN1InputStream(inStr).readObject()));
120120
}
@@ -127,7 +127,7 @@ public static AsymmetricKeyParameter createKey(InputStream inStr)
127127
* @throws IOException on an error decoding the key
128128
*/
129129
public static AsymmetricKeyParameter createKey(PrivateKeyInfo keyInfo)
130-
throws IOException
130+
throws IOException
131131
{
132132
if (keyInfo == null)
133133
{
@@ -146,7 +146,7 @@ public static AsymmetricKeyParameter createKey(PrivateKeyInfo keyInfo)
146146
else if (algOID.equals(PQCObjectIdentifiers.sphincs256))
147147
{
148148
return new SPHINCSPrivateKeyParameters(ASN1OctetString.getInstance(keyInfo.parsePrivateKey()).getOctets(),
149-
Utils.sphincs256LookupTreeAlgName(SPHINCS256KeyParams.getInstance(algId.getParameters())));
149+
Utils.sphincs256LookupTreeAlgName(SPHINCS256KeyParams.getInstance(algId.getParameters())));
150150
}
151151
else if (algOID.equals(PQCObjectIdentifiers.newHope))
152152
{
@@ -176,7 +176,7 @@ else if (algOID.on(BCObjectIdentifiers.sphincsPlus) || algOID.on(BCObjectIdentif
176176
SPHINCSPLUSPrivateKey spKey = SPHINCSPLUSPrivateKey.getInstance(obj);
177177
SPHINCSPLUSPublicKey publicKey = spKey.getPublicKey();
178178
return new SPHINCSPlusPrivateKeyParameters(spParams, spKey.getSkseed(), spKey.getSkprf(),
179-
publicKey.getPkseed(), publicKey.getPkroot());
179+
publicKey.getPkseed(), publicKey.getPkroot());
180180
}
181181
else
182182
{
@@ -226,24 +226,29 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntru))
226226
return new NTRUPrivateKeyParameters(spParams, keyEnc);
227227
}
228228
else if (algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_512) ||
229-
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_768) ||
230-
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_1024))
229+
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_768) ||
230+
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_1024))
231231
{
232232
ASN1Primitive mlkemKey = parsePrimitiveString(keyInfo.getPrivateKey(), 64);
233233
MLKEMParameters mlkemParams = Utils.mlkemParamsLookup(algOID);
234234

235+
MLKEMPublicKeyParameters pubParams = null;
236+
if (keyInfo.getPublicKeyData() != null)
237+
{
238+
pubParams = PublicKeyFactory.MLKEMConverter.getPublicKeyParams(mlkemParams, keyInfo.getPublicKeyData());
239+
}
240+
235241
if (mlkemKey instanceof ASN1Sequence)
236242
{
237243
ASN1Sequence keySeq = ASN1Sequence.getInstance(mlkemKey);
238244

239-
if (keySeq.getObjectAt(0) instanceof ASN1OctetString)
240-
{
241-
return new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets());
242-
}
243-
else
245+
MLKEMPrivateKeyParameters mlkemPriv = new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(), pubParams);
246+
if (!Arrays.constantTimeAreEqual(mlkemPriv.getEncoded(), ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets()))
244247
{
245-
return new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance((ASN1TaggedObject)keySeq.getObjectAt(0), false).getOctets());
248+
throw new IllegalStateException("seed/expanded-key mismatch");
246249
}
250+
251+
return mlkemPriv;
247252
}
248253
else if (mlkemKey instanceof ASN1OctetString)
249254
{
@@ -259,10 +264,10 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntrulprime))
259264
NTRULPRimeParameters spParams = Utils.ntrulprimeParamsLookup(algOID);
260265

261266
return new NTRULPRimePrivateKeyParameters(spParams,
262-
ASN1OctetString.getInstance(keyEnc.getObjectAt(0)).getOctets(),
263-
ASN1OctetString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
264-
ASN1OctetString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
265-
ASN1OctetString.getInstance(keyEnc.getObjectAt(3)).getOctets());
267+
ASN1OctetString.getInstance(keyEnc.getObjectAt(0)).getOctets(),
268+
ASN1OctetString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
269+
ASN1OctetString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
270+
ASN1OctetString.getInstance(keyEnc.getObjectAt(3)).getOctets());
266271
}
267272
else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime))
268273
{
@@ -271,11 +276,11 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime))
271276
SNTRUPrimeParameters spParams = Utils.sntruprimeParamsLookup(algOID);
272277

273278
return new SNTRUPrimePrivateKeyParameters(spParams,
274-
ASN1OctetString.getInstance(keyEnc.getObjectAt(0)).getOctets(),
275-
ASN1OctetString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
276-
ASN1OctetString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
277-
ASN1OctetString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
278-
ASN1OctetString.getInstance(keyEnc.getObjectAt(4)).getOctets());
279+
ASN1OctetString.getInstance(keyEnc.getObjectAt(0)).getOctets(),
280+
ASN1OctetString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
281+
ASN1OctetString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
282+
ASN1OctetString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
283+
ASN1OctetString.getInstance(keyEnc.getObjectAt(4)).getOctets());
279284
}
280285
else if (Utils.mldsaParams.containsKey(algOID))
281286
{
@@ -298,22 +303,21 @@ else if (keyObj instanceof ASN1Sequence)
298303
{
299304
ASN1Sequence keySeq = ASN1Sequence.getInstance(keyObj);
300305

301-
if (keySeq.getObjectAt(0) instanceof ASN1OctetString)
302-
{
303-
return new MLDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(), pubParams);
304-
}
305-
else
306+
MLDSAPrivateKeyParameters mldsaPriv = new MLDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(), pubParams);
307+
if (!Arrays.constantTimeAreEqual(mldsaPriv.getEncoded(), ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets()))
306308
{
307-
return new MLDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance((ASN1TaggedObject)keySeq.getObjectAt(0), false).getOctets(), pubParams);
309+
throw new IllegalStateException("seed/expanded-key mismatch");
308310
}
311+
312+
return mldsaPriv;
309313
}
310314
else
311315
{
312316
throw new IOException("not supported");
313317
}
314318
}
315319
else if (algOID.equals(BCObjectIdentifiers.dilithium2)
316-
|| algOID.equals(BCObjectIdentifiers.dilithium3) || algOID.equals(BCObjectIdentifiers.dilithium5))
320+
|| algOID.equals(BCObjectIdentifiers.dilithium3) || algOID.equals(BCObjectIdentifiers.dilithium5))
317321
{
318322
ASN1Encodable keyObj = keyInfo.parsePrivateKey();
319323
DilithiumParameters dilParams = Utils.dilithiumParamsLookup(algOID);
@@ -333,24 +337,24 @@ else if (algOID.equals(BCObjectIdentifiers.dilithium2)
333337
DilithiumPublicKeyParameters pubParams = PublicKeyFactory.DilithiumConverter.getPublicKeyParams(dilParams, keyInfo.getPublicKeyData());
334338

335339
return new DilithiumPrivateKeyParameters(dilParams,
336-
ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
337-
ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
338-
ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
339-
ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(),
340-
ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(),
341-
ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(),
342-
pubParams.getT1()); // encT1
340+
ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
341+
ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
342+
ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
343+
ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(),
344+
ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(),
345+
ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(),
346+
pubParams.getT1()); // encT1
343347
}
344348
else
345349
{
346350
return new DilithiumPrivateKeyParameters(dilParams,
347-
ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
348-
ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
349-
ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
350-
ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(),
351-
ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(),
352-
ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(),
353-
null);
351+
ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
352+
ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
353+
ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
354+
ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(),
355+
ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets(),
356+
ASN1BitString.getInstance(keyEnc.getObjectAt(6)).getOctets(),
357+
null);
354358
}
355359
}
356360
else if (keyObj instanceof DEROctetString)
@@ -409,12 +413,12 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss))
409413
try
410414
{
411415
XMSSPrivateKeyParameters.Builder keyBuilder = new XMSSPrivateKeyParameters
412-
.Builder(new XMSSParameters(keyParams.getHeight(), Utils.getDigest(treeDigest)))
413-
.withIndex(xmssPrivateKey.getIndex())
414-
.withSecretKeySeed(xmssPrivateKey.getSecretKeySeed())
415-
.withSecretKeyPRF(xmssPrivateKey.getSecretKeyPRF())
416-
.withPublicSeed(xmssPrivateKey.getPublicSeed())
417-
.withRoot(xmssPrivateKey.getRoot());
416+
.Builder(new XMSSParameters(keyParams.getHeight(), Utils.getDigest(treeDigest)))
417+
.withIndex(xmssPrivateKey.getIndex())
418+
.withSecretKeySeed(xmssPrivateKey.getSecretKeySeed())
419+
.withSecretKeyPRF(xmssPrivateKey.getSecretKeyPRF())
420+
.withPublicSeed(xmssPrivateKey.getPublicSeed())
421+
.withRoot(xmssPrivateKey.getRoot());
418422

419423
if (xmssPrivateKey.getVersion() != 0)
420424
{
@@ -423,7 +427,7 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss))
423427

424428
if (xmssPrivateKey.getBdsState() != null)
425429
{
426-
BDS bds = (BDS) XMSSUtil.deserialize(xmssPrivateKey.getBdsState(), BDS.class);
430+
BDS bds = (BDS)XMSSUtil.deserialize(xmssPrivateKey.getBdsState(), BDS.class);
427431
keyBuilder.withBDSState(bds.withWOTSDigest(treeDigest));
428432
}
429433

@@ -444,12 +448,12 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss_mt))
444448
XMSSMTPrivateKey xmssMtPrivateKey = XMSSMTPrivateKey.getInstance(keyInfo.parsePrivateKey());
445449

446450
XMSSMTPrivateKeyParameters.Builder keyBuilder = new XMSSMTPrivateKeyParameters
447-
.Builder(new XMSSMTParameters(keyParams.getHeight(), keyParams.getLayers(), Utils.getDigest(treeDigest)))
448-
.withIndex(xmssMtPrivateKey.getIndex())
449-
.withSecretKeySeed(xmssMtPrivateKey.getSecretKeySeed())
450-
.withSecretKeyPRF(xmssMtPrivateKey.getSecretKeyPRF())
451-
.withPublicSeed(xmssMtPrivateKey.getPublicSeed())
452-
.withRoot(xmssMtPrivateKey.getRoot());
451+
.Builder(new XMSSMTParameters(keyParams.getHeight(), keyParams.getLayers(), Utils.getDigest(treeDigest)))
452+
.withIndex(xmssMtPrivateKey.getIndex())
453+
.withSecretKeySeed(xmssMtPrivateKey.getSecretKeySeed())
454+
.withSecretKeyPRF(xmssMtPrivateKey.getSecretKeyPRF())
455+
.withPublicSeed(xmssMtPrivateKey.getPublicSeed())
456+
.withRoot(xmssMtPrivateKey.getRoot());
453457

454458
if (xmssMtPrivateKey.getVersion() != 0)
455459
{
@@ -458,7 +462,7 @@ else if (algOID.equals(PQCObjectIdentifiers.xmss_mt))
458462

459463
if (xmssMtPrivateKey.getBdsState() != null)
460464
{
461-
BDSStateMap bdsState = (BDSStateMap) XMSSUtil.deserialize(xmssMtPrivateKey.getBdsState(), BDSStateMap.class);
465+
BDSStateMap bdsState = (BDSStateMap)XMSSUtil.deserialize(xmssMtPrivateKey.getBdsState(), BDSStateMap.class);
462466
keyBuilder.withBDSState(bdsState.withWOTSDigest(treeDigest));
463467
}
464468

@@ -528,7 +532,7 @@ private static ASN1Primitive parsePrimitiveString(ASN1OctetString octStr, int ex
528532
// possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING
529533
// or possible SEQUENCE
530534
ASN1Encodable obj = Utils.parseData(data);
531-
535+
532536
if (obj instanceof ASN1OctetString)
533537
{
534538
return ASN1OctetString.getInstance(obj);

core/src/main/java/org/bouncycastle/pqc/crypto/util/PublicKeyFactory.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,12 @@ public class PublicKeyFactory
186186
converters.put(BCObjectIdentifiers.ntruhrss1373, new NtruConverter());
187187
converters.put(BCObjectIdentifiers.falcon_512, new FalconConverter());
188188
converters.put(BCObjectIdentifiers.falcon_1024, new FalconConverter());
189-
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_512, new KyberConverter());
190-
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_768, new KyberConverter());
191-
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_1024, new KyberConverter());
192-
converters.put(BCObjectIdentifiers.kyber512_aes, new KyberConverter());
193-
converters.put(BCObjectIdentifiers.kyber768_aes, new KyberConverter());
194-
converters.put(BCObjectIdentifiers.kyber1024_aes, new KyberConverter());
189+
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_512, new MLKEMConverter());
190+
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_768, new MLKEMConverter());
191+
converters.put(NISTObjectIdentifiers.id_alg_ml_kem_1024, new MLKEMConverter());
192+
converters.put(BCObjectIdentifiers.kyber512_aes, new MLKEMConverter());
193+
converters.put(BCObjectIdentifiers.kyber768_aes, new MLKEMConverter());
194+
converters.put(BCObjectIdentifiers.kyber1024_aes, new MLKEMConverter());
195195
converters.put(BCObjectIdentifiers.ntrulpr653, new NTRULPrimeConverter());
196196
converters.put(BCObjectIdentifiers.ntrulpr761, new NTRULPrimeConverter());
197197
converters.put(BCObjectIdentifiers.ntrulpr857, new NTRULPrimeConverter());
@@ -602,7 +602,7 @@ AsymmetricKeyParameter getPublicKeyParameters(SubjectPublicKeyInfo keyInfo, Obje
602602
}
603603
}
604604

605-
private static class KyberConverter
605+
static class MLKEMConverter
606606
extends SubjectPublicKeyInfoConverter
607607
{
608608
AsymmetricKeyParameter getPublicKeyParameters(SubjectPublicKeyInfo keyInfo, Object defaultParams)
@@ -623,6 +623,33 @@ AsymmetricKeyParameter getPublicKeyParameters(SubjectPublicKeyInfo keyInfo, Obje
623623
return new MLKEMPublicKeyParameters(kyberParameters, keyInfo.getPublicKeyData().getOctets());
624624
}
625625
}
626+
627+
static MLKEMPublicKeyParameters getPublicKeyParams(MLKEMParameters dilithiumParams, ASN1BitString publicKeyData)
628+
{
629+
try
630+
{
631+
ASN1Primitive obj = ASN1Primitive.fromByteArray(publicKeyData.getOctets());
632+
if (obj instanceof ASN1Sequence)
633+
{
634+
ASN1Sequence keySeq = ASN1Sequence.getInstance(obj);
635+
636+
return new MLKEMPublicKeyParameters(dilithiumParams,
637+
ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(),
638+
ASN1OctetString.getInstance(keySeq.getObjectAt(1)).getOctets());
639+
}
640+
else
641+
{
642+
byte[] encKey = ASN1OctetString.getInstance(obj).getOctets();
643+
644+
return new MLKEMPublicKeyParameters(dilithiumParams, encKey);
645+
}
646+
}
647+
catch (Exception e)
648+
{
649+
// we're a raw encoding
650+
return new MLKEMPublicKeyParameters(dilithiumParams, publicKeyData.getOctets());
651+
}
652+
}
626653
}
627654

628655
private static class NTRULPrimeConverter

prov/src/main/java/org/bouncycastle/pqc/jcajce/provider/util/BaseKeyFactorySpi.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ public PrivateKey engineGeneratePrivate(KeySpec keySpec)
6464
{
6565
throw e;
6666
}
67+
catch (IllegalStateException e)
68+
{
69+
throw new InvalidKeySpecException(e.getMessage());
70+
}
6771
catch (Exception e)
6872
{
6973
throw new InvalidKeySpecException(e.toString());

0 commit comments

Comments
 (0)