Skip to content

Commit 7822234

Browse files
committed
updated parsing to handle CHOICE in ML-DSA.
1 parent 86381da commit 7822234

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAPrivateKeyParameters.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
public class MLDSAPrivateKeyParameters
66
extends MLDSAKeyParameters
77
{
8+
public static final int BOTH = 0;
9+
public static final int SEED_ONLY = 1;
10+
public static final int EXPANDED_KEY = 2;
11+
812
final byte[] rho;
913
final byte[] k;
1014
final byte[] tr;
@@ -15,6 +19,8 @@ public class MLDSAPrivateKeyParameters
1519
private final byte[] t1;
1620
private final byte[] seed;
1721

22+
private final int prefFormat;
23+
1824
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding)
1925
{
2026
this(params, encoding, null);
@@ -36,6 +42,7 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] rho, byte[] K, b
3642
this.t0 = Arrays.clone(t0);
3743
this.t1 = Arrays.clone(t1);
3844
this.seed = Arrays.clone(seed);
45+
this.prefFormat = (seed != null) ? BOTH : EXPANDED_KEY;
3946
}
4047

4148
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAPublicKeyParameters pubKey)
@@ -86,6 +93,36 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAP
8693

8794
this.seed = null;
8895
}
96+
this.prefFormat = (seed != null) ? BOTH : EXPANDED_KEY;
97+
}
98+
99+
private MLDSAPrivateKeyParameters(MLDSAPrivateKeyParameters params, int preferredFormat)
100+
{
101+
super(true, params.getParameters());
102+
103+
this.rho = params.rho;
104+
this.k = params.k;
105+
this.tr = params.tr;
106+
this.s1 = params.s1;
107+
this.s2 = params.s2;
108+
this.t0 = params.t0;
109+
this.t1 = params.t1;
110+
this.seed = params.seed;
111+
this.prefFormat = preferredFormat;
112+
}
113+
114+
public MLDSAPrivateKeyParameters getParametersWithFormat(int format)
115+
{
116+
if (this.seed == null && format == SEED_ONLY)
117+
{
118+
throw new IllegalArgumentException("no seed available");
119+
}
120+
return new MLDSAPrivateKeyParameters(this, format);
121+
}
122+
123+
public int getPreferredFormat()
124+
{
125+
return prefFormat;
89126
}
90127

91128
public byte[] getEncoded()

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyInfoFactory.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,13 @@ else if (privateKey instanceof MLDSAPrivateKeyParameters)
298298
AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.mldsaOidLookup(params.getParameters()));
299299

300300
byte[] seed = params.getSeed();
301-
if (Properties.isOverrideSet("org.bouncycastle.mldsa.seedOnly"))
301+
if (params.getPreferredFormat() == MLDSAPrivateKeyParameters.SEED_ONLY)
302302
{
303-
if (seed == null) // very difficult to imagine, but...
304-
{
305-
throw new IOException("no seed available");
306-
}
307-
return new PrivateKeyInfo(algorithmIdentifier, seed, attributes);
303+
return new PrivateKeyInfo(algorithmIdentifier, new DERTaggedObject(false, 0, new DEROctetString(params.getSeed())), attributes);
304+
}
305+
else if (params.getPreferredFormat() == MLDSAPrivateKeyParameters.EXPANDED_KEY)
306+
{
307+
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getEncoded()), attributes);
308308
}
309309
return new PrivateKeyInfo(algorithmIdentifier, getBasicPQCEncoding(params.getSeed(), params.getEncoded()), attributes);
310310
}
@@ -406,7 +406,7 @@ private static ASN1Sequence getBasicPQCEncoding(byte[] seed, byte[] expanded)
406406

407407
if (expanded != null)
408408
{
409-
v.add(new DERTaggedObject(false, 1, new DEROctetString(expanded)));
409+
v.add(new DEROctetString(expanded));
410410
}
411411

412412
return new DERSequence(v);

core/src/main/java/org/bouncycastle/pqc/crypto/util/Utils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.bouncycastle.asn1.ASN1OctetString;
99
import org.bouncycastle.asn1.ASN1Primitive;
1010
import org.bouncycastle.asn1.ASN1Sequence;
11+
import org.bouncycastle.asn1.ASN1TaggedObject;
1112
import org.bouncycastle.asn1.BERTags;
1213
import org.bouncycastle.asn1.DERNull;
1314
import org.bouncycastle.asn1.bc.BCObjectIdentifiers;
@@ -830,6 +831,11 @@ static ASN1Primitive parseData(byte[] data)
830831
{
831832
return ASN1OctetString.getInstance(data);
832833
}
834+
835+
if ((data[0] & 0xff) == BERTags.TAGGED)
836+
{
837+
return ASN1OctetString.getInstance(ASN1TaggedObject.getInstance(data), false);
838+
}
833839
}
834840

835841
return null;

0 commit comments

Comments
 (0)