Skip to content

Commit 24797f6

Browse files
committed
Update of PQC key encodings for ML-KEM and ML-DSA (no encoding preservation)
1 parent 6e801d1 commit 24797f6

File tree

8 files changed

+181
-59
lines changed

8 files changed

+181
-59
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyFactory.java

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.bouncycastle.asn1.ASN1OctetString;
1212
import org.bouncycastle.asn1.ASN1Primitive;
1313
import org.bouncycastle.asn1.ASN1Sequence;
14+
import org.bouncycastle.asn1.ASN1TaggedObject;
1415
import org.bouncycastle.asn1.DEROctetString;
1516
import org.bouncycastle.asn1.bc.BCObjectIdentifiers;
1617
import org.bouncycastle.asn1.nist.NISTObjectIdentifiers;
@@ -228,10 +229,28 @@ else if (algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_512) ||
228229
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_768) ||
229230
algOID.equals(NISTObjectIdentifiers.id_alg_ml_kem_1024))
230231
{
231-
ASN1OctetString mlkemKey = parseOctetString(keyInfo.getPrivateKey(), 64);
232+
ASN1Primitive mlkemKey = parsePrimitiveString(keyInfo.getPrivateKey(), 64);
232233
MLKEMParameters mlkemParams = Utils.mlkemParamsLookup(algOID);
233234

234-
return new MLKEMPrivateKeyParameters(mlkemParams, mlkemKey.getOctets());
235+
if (mlkemKey instanceof ASN1Sequence)
236+
{
237+
ASN1Sequence keySeq = ASN1Sequence.getInstance(mlkemKey);
238+
239+
if (keySeq.getObjectAt(0) instanceof ASN1OctetString)
240+
{
241+
return new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets());
242+
}
243+
else
244+
{
245+
return new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance((ASN1TaggedObject)keySeq.getObjectAt(0), false).getOctets());
246+
}
247+
}
248+
else if (mlkemKey instanceof ASN1OctetString)
249+
{
250+
return new MLKEMPrivateKeyParameters(mlkemParams, ASN1OctetString.getInstance(mlkemKey).getOctets());
251+
}
252+
253+
throw new IllegalArgumentException("unknown key format");
235254
}
236255
else if (algOID.on(BCObjectIdentifiers.pqc_kem_ntrulprime))
237256
{
@@ -260,18 +279,33 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_sntruprime))
260279
}
261280
else if (Utils.mldsaParams.containsKey(algOID))
262281
{
263-
ASN1Encodable keyObj = parseOctetString(keyInfo.getPrivateKey(), 32);
282+
ASN1Encodable keyObj = parsePrimitiveString(keyInfo.getPrivateKey(), 32);
264283
MLDSAParameters spParams = Utils.mldsaParamsLookup(algOID);
265284

266-
if (keyObj instanceof DEROctetString)
285+
MLDSAPublicKeyParameters pubParams = null;
286+
if (keyInfo.getPublicKeyData() != null)
287+
{
288+
pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData());
289+
}
290+
291+
if (keyObj instanceof ASN1OctetString)
267292
{
268293
byte[] data = ASN1OctetString.getInstance(keyObj).getOctets();
269-
if (keyInfo.getPublicKeyData() != null)
294+
295+
return new MLDSAPrivateKeyParameters(spParams, data, pubParams);
296+
}
297+
else if (keyObj instanceof ASN1Sequence)
298+
{
299+
ASN1Sequence keySeq = ASN1Sequence.getInstance(keyObj);
300+
301+
if (keySeq.getObjectAt(0) instanceof ASN1OctetString)
302+
{
303+
return new MLDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance(keySeq.getObjectAt(0)).getOctets(), pubParams);
304+
}
305+
else
270306
{
271-
MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData());
272-
return new MLDSAPrivateKeyParameters(spParams, data, pubParams);
307+
return new MLDSAPrivateKeyParameters(spParams, ASN1OctetString.getInstance((ASN1TaggedObject)keySeq.getObjectAt(0), false).getOctets(), pubParams);
273308
}
274-
return new MLDSAPrivateKeyParameters(spParams, data);
275309
}
276310
else
277311
{
@@ -464,15 +498,49 @@ private static ASN1OctetString parseOctetString(ASN1OctetString octStr, int expe
464498

465499
//
466500
// possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING
467-
data = Utils.readOctetString(data);
468-
if (data != null)
501+
ASN1OctetString obj = Utils.parseOctetData(data);
502+
503+
if (obj != null)
504+
{
505+
return ASN1OctetString.getInstance(obj);
506+
}
507+
508+
return octStr;
509+
}
510+
511+
/**
512+
* So it seems for the new PQC algorithms, there's a couple of approaches to what goes in the OCTET STRING
513+
* and in this case there may also be SEQUENCE.
514+
*/
515+
private static ASN1Primitive parsePrimitiveString(ASN1OctetString octStr, int expectedLength)
516+
throws IOException
517+
{
518+
byte[] data = octStr.getOctets();
519+
//
520+
// it's the right length for a RAW encoding, just return it.
521+
//
522+
if (data.length == expectedLength)
523+
{
524+
return octStr;
525+
}
526+
527+
//
528+
// possible internal OCTET STRING, possibly long form with or without the internal OCTET STRING
529+
// or possible SEQUENCE
530+
ASN1Encodable obj = Utils.parseData(data);
531+
532+
if (obj instanceof ASN1OctetString)
469533
{
470-
return new DEROctetString(data);
534+
return ASN1OctetString.getInstance(obj);
535+
}
536+
if (obj instanceof ASN1Sequence)
537+
{
538+
return ASN1Sequence.getInstance(obj);
471539
}
472540

473541
return octStr;
474542
}
475-
543+
476544
private static short[] convert(byte[] octets)
477545
{
478546
short[] rv = new short[octets.length / 2];

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyInfoFactory.java

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import java.io.IOException;
44

55
import org.bouncycastle.asn1.ASN1EncodableVector;
6+
import org.bouncycastle.asn1.ASN1Sequence;
67
import org.bouncycastle.asn1.ASN1Set;
78
import org.bouncycastle.asn1.DEROctetString;
89
import org.bouncycastle.asn1.DERSequence;
10+
import org.bouncycastle.asn1.DERTaggedObject;
911
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
1012
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
1113
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
@@ -32,7 +34,6 @@
3234
import org.bouncycastle.pqc.crypto.lms.HSSPrivateKeyParameters;
3335
import org.bouncycastle.pqc.crypto.lms.LMSPrivateKeyParameters;
3436
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPrivateKeyParameters;
35-
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters;
3637
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
3738
import org.bouncycastle.pqc.crypto.newhope.NHPrivateKeyParameters;
3839
import org.bouncycastle.pqc.crypto.ntru.NTRUPrivateKeyParameters;
@@ -246,16 +247,18 @@ else if (privateKey instanceof MLKEMPrivateKeyParameters)
246247
MLKEMPrivateKeyParameters params = (MLKEMPrivateKeyParameters)privateKey;
247248

248249
AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.mlkemOidLookup(params.getParameters()));
249-
250-
byte[] seed = params.getSeed();
251-
if (seed == null)
252-
{
253-
return new PrivateKeyInfo(algorithmIdentifier, params.getEncoded(), attributes);
254-
}
255-
else
256-
{
257-
return new PrivateKeyInfo(algorithmIdentifier, seed, attributes);
258-
}
250+
251+
return new PrivateKeyInfo(algorithmIdentifier, getBasicPQCEncoding(params.getSeed(), params.getEncoded()), attributes);
252+
// byte[] seed = params.getSeed();
253+
//
254+
// if (seed == null)
255+
// {
256+
// return new PrivateKeyInfo(algorithmIdentifier, params.getEncoded(), attributes);
257+
// }
258+
// else
259+
// {
260+
// return new PrivateKeyInfo(algorithmIdentifier, seed, attributes);
261+
// }
259262
}
260263
else if (privateKey instanceof NTRULPRimePrivateKeyParameters)
261264
{
@@ -294,19 +297,20 @@ else if (privateKey instanceof MLDSAPrivateKeyParameters)
294297

295298
AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.mldsaOidLookup(params.getParameters()));
296299

297-
byte[] seed = params.getSeed();
298-
if (seed == null)
299-
{
300-
MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
301-
302-
return new PrivateKeyInfo(algorithmIdentifier, params.getEncoded(), attributes, pubParams.getEncoded());
303-
}
304-
else
305-
{
306-
MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
307-
308-
return new PrivateKeyInfo(algorithmIdentifier, params.getSeed(), attributes);
309-
}
300+
return new PrivateKeyInfo(algorithmIdentifier, getBasicPQCEncoding(params.getSeed(), params.getEncoded()), attributes);
301+
// byte[] seed = params.getSeed();
302+
// if (seed == null)
303+
// {
304+
// MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
305+
//
306+
// return new PrivateKeyInfo(algorithmIdentifier, params.getEncoded(), attributes, pubParams.getEncoded());
307+
// }
308+
// else
309+
// {
310+
// MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
311+
//
312+
// return new PrivateKeyInfo(algorithmIdentifier, seed, attributes, pubParams.getEncoded());
313+
// }
310314
}
311315
else if (privateKey instanceof DilithiumPrivateKeyParameters)
312316
{
@@ -395,6 +399,23 @@ private static XMSSPrivateKey xmssCreateKeyStructure(XMSSPrivateKeyParameters ke
395399
}
396400
}
397401

402+
private static ASN1Sequence getBasicPQCEncoding(byte[] seed, byte[] expanded)
403+
{
404+
ASN1EncodableVector v = new ASN1EncodableVector(2);
405+
406+
if (seed != null)
407+
{
408+
v.add(new DEROctetString(seed));
409+
}
410+
411+
if (expanded != null)
412+
{
413+
v.add(new DERTaggedObject(false, 1, new DEROctetString(expanded)));
414+
}
415+
416+
return new DERSequence(v);
417+
}
418+
398419
private static XMSSMTPrivateKey xmssmtCreateKeyStructure(XMSSMTPrivateKeyParameters keyParams)
399420
throws IOException
400421
{

core/src/main/java/org/bouncycastle/pqc/crypto/util/PublicKeyFactory.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,11 +438,11 @@ AsymmetricKeyParameter getPublicKeyParameters(SubjectPublicKeyInfo keyInfo, Obje
438438
throws IOException
439439
{
440440
byte[] keyEnc = keyInfo.getPublicKeyData().getOctets();
441-
byte[] data = Utils.readOctetString(keyEnc);
441+
ASN1OctetString data = (ASN1OctetString)Utils.parseData(keyEnc);
442442

443443
if (data != null)
444444
{
445-
return getLmsKeyParameters(data);
445+
return getLmsKeyParameters(data.getOctets());
446446
}
447447

448448
return getLmsKeyParameters(keyEnc);
@@ -567,11 +567,11 @@ AsymmetricKeyParameter getPublicKeyParameters(SubjectPublicKeyInfo keyInfo, Obje
567567
throws IOException
568568
{
569569
byte[] keyEnc = keyInfo.getPublicKeyData().getOctets();
570-
byte[] data = Utils.readOctetString(keyEnc);
570+
ASN1OctetString data = Utils.parseOctetData(keyEnc);
571571

572572
if (data != null)
573573
{
574-
return getNtruPublicKeyParameters(keyInfo, data);
574+
return getNtruPublicKeyParameters(keyInfo, data.getOctets());
575575
}
576576

577577
return getNtruPublicKeyParameters(keyInfo, keyEnc);

core/src/main/java/org/bouncycastle/pqc/crypto/util/Utils.java

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package org.bouncycastle.pqc.crypto.util;
22

33
import java.io.ByteArrayInputStream;
4-
import java.io.IOException;
54
import java.util.HashMap;
65
import java.util.Map;
76

87
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
8+
import org.bouncycastle.asn1.ASN1OctetString;
9+
import org.bouncycastle.asn1.ASN1Primitive;
10+
import org.bouncycastle.asn1.ASN1Sequence;
911
import org.bouncycastle.asn1.BERTags;
1012
import org.bouncycastle.asn1.DERNull;
1113
import org.bouncycastle.asn1.bc.BCObjectIdentifiers;
@@ -38,7 +40,6 @@
3840
import org.bouncycastle.pqc.crypto.xmss.XMSSKeyParameters;
3941
import org.bouncycastle.pqc.legacy.crypto.qtesla.QTESLASecurityCategory;
4042
import org.bouncycastle.util.Integers;
41-
import org.bouncycastle.util.io.Streams;
4243

4344
class Utils
4445
{
@@ -783,18 +784,48 @@ static RainbowParameters rainbowParamsLookup(ASN1ObjectIdentifier oid)
783784
return (RainbowParameters)rainbowParams.get(oid);
784785
}
785786

786-
static byte[] readOctetString(byte[] data)
787-
throws IOException
787+
private static boolean isRaw(byte[] data)
788788
{
789-
if (data[0] == BERTags.OCTET_STRING)
789+
// check well-formed first
790+
ByteArrayInputStream bIn = new ByteArrayInputStream(data);
791+
792+
int tag = bIn.read();
793+
int len = readLen(bIn);
794+
if (len != bIn.available())
795+
{
796+
return true;
797+
}
798+
799+
return false;
800+
}
801+
802+
static ASN1OctetString parseOctetData(byte[] data)
803+
{
804+
// check well-formed first
805+
if (!isRaw(data))
806+
{
807+
if (data[0] == BERTags.OCTET_STRING)
808+
{
809+
return ASN1OctetString.getInstance(data);
810+
}
811+
}
812+
813+
return null;
814+
}
815+
816+
static ASN1Primitive parseData(byte[] data)
817+
{
818+
// check well-formed first
819+
if (!isRaw(data))
790820
{
791-
ByteArrayInputStream bIn = new ByteArrayInputStream(data);
821+
if (data[0] == (BERTags.SEQUENCE | BERTags.CONSTRUCTED))
822+
{
823+
return ASN1Sequence.getInstance(data);
824+
}
792825

793-
int tag = bIn.read();
794-
int len = readLen(bIn);
795-
if (len == bIn.available())
826+
if (data[0] == BERTags.OCTET_STRING)
796827
{
797-
return Streams.readAll(bIn);
828+
return ASN1OctetString.getInstance(data);
798829
}
799830
}
800831

prov/src/main/java/org/bouncycastle/jcajce/provider/asymmetric/mlkem/BCMLKEMPrivateKey.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public BCMLKEMPrivateKey(PrivateKeyInfo keyInfo)
4242
private void init(PrivateKeyInfo keyInfo)
4343
throws IOException
4444
{
45-
this.attributes = keyInfo.getAttributes();;
45+
this.attributes = keyInfo.getAttributes();
4646
this.params = (MLKEMPrivateKeyParameters)PrivateKeyFactory.createKey(keyInfo);
4747
this.algorithm = Strings.toUpperCase(MLKEMParameterSpec.fromName(params.getParameters().getName()).getName());
4848
}

0 commit comments

Comments
 (0)