Skip to content

Commit 9bfdbe4

Browse files
committed
added KeySpecs for ML-KEM
added public key recovery to ML-DSA
1 parent aea65ff commit 9bfdbe4

File tree

12 files changed

+386
-56
lines changed

12 files changed

+386
-56
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAEngine.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,53 @@ byte[][] generateKeyPairInternal(byte[] seed)
300300

301301
byte[][] sk = Packing.packSecretKey(rho, tr, key, t0, s1, s2, this);
302302

303-
return new byte[][]{ sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1, seed};
303+
return new byte[][]{sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1, seed};
304+
}
305+
306+
byte[] deriveT1(byte[] rho, byte[] key, byte[] tr, byte[] s1Enc, byte[] s2Enc, byte[] t0Enc)
307+
{
308+
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
309+
310+
PolyVecL s1 = new PolyVecL(this), s1hat;
311+
PolyVecK s2 = new PolyVecK(this), t1 = new PolyVecK(this), t0 = new PolyVecK(this);
312+
313+
Packing.unpackSecretKey(t0, s1, s2, t0Enc, s1Enc, s2Enc, this);
314+
315+
// System.out.print("rho = ");
316+
// Helper.printByteArray(rho);
317+
318+
// System.out.println("key = ");
319+
// Helper.printByteArray(key);
320+
321+
aMatrix.expandMatrix(rho);
322+
// System.out.print(aMatrix.toString("aMatrix"));
323+
324+
s1hat = new PolyVecL(this);
325+
326+
s1.copyPolyVecL(s1hat);
327+
s1hat.polyVecNtt();
328+
329+
// System.out.println(s1hat.toString("s1hat"));
330+
331+
aMatrix.pointwiseMontgomery(t1, s1hat);
332+
// System.out.println(t1.toString("t1"));
333+
334+
t1.reduce();
335+
t1.invNttToMont();
336+
337+
t1.addPolyVecK(s2);
338+
// System.out.println(s2.toString("s2"));
339+
// System.out.println(t1.toString("t1"));
340+
t1.conditionalAddQ();
341+
t1.power2Round(t0);
342+
343+
// System.out.println(t1.toString("t1"));
344+
// System.out.println(t0.toString("t0"));
345+
346+
byte[] encT1 = Packing.packPublicKey(t1, this);
347+
// System.out.println("enc t1 = ");
348+
// Helper.printByteArray(encT1);
349+
return encT1;
304350
}
305351

306352
SHAKEDigest getShake256Digest()

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAPrivateKeyParameters.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAP
8181
}
8282
else
8383
{
84-
this.t1 = null;
84+
this.t1 = eng.deriveT1(rho, k, tr, s1, s2, t0);;
8585
}
8686
this.seed = null;
8787
}
@@ -117,6 +117,11 @@ public byte[] getSeed()
117117

118118
public MLDSAPublicKeyParameters getPublicKeyParameters()
119119
{
120+
if (this.t1 == null)
121+
{
122+
return null;
123+
}
124+
120125
return new MLDSAPublicKeyParameters(getParameters(), rho, t1);
121126
}
122127

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAPublicKeyParameters.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,23 @@ public MLDSAPublicKeyParameters(MLDSAParameters params, byte[] encoding)
1818
super(false, params);
1919
this.rho = Arrays.copyOfRange(encoding, 0, MLDSAEngine.SeedBytes);
2020
this.t1 = Arrays.copyOfRange(encoding, MLDSAEngine.SeedBytes, encoding.length);
21+
if (t1.length == 0)
22+
{
23+
throw new IllegalArgumentException("encoding too short");
24+
}
2125
}
2226

2327
public MLDSAPublicKeyParameters(MLDSAParameters params, byte[] rho, byte[] t1)
2428
{
2529
super(false, params);
30+
if (rho == null)
31+
{
32+
throw new NullPointerException("rho cannot be null");
33+
}
34+
if (t1 == null)
35+
{
36+
throw new NullPointerException("t1 cannot be null");
37+
}
2638
this.rho = Arrays.clone(rho);
2739
this.t1 = Arrays.clone(t1);
2840
}

prov/src/main/java/org/bouncycastle/jcajce/provider/asymmetric/mldsa/BCMLDSAPrivateKey.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.bouncycastle.jcajce.interfaces.MLDSAPublicKey;
1111
import org.bouncycastle.jcajce.spec.MLDSAParameterSpec;
1212
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPrivateKeyParameters;
13+
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters;
1314
import org.bouncycastle.pqc.crypto.util.PrivateKeyFactory;
1415
import org.bouncycastle.pqc.jcajce.provider.util.KeyUtil;
1516
import org.bouncycastle.util.Arrays;
@@ -101,7 +102,12 @@ public byte[] getEncoded()
101102

102103
public MLDSAPublicKey getPublicKey()
103104
{
104-
return new BCMLDSAPublicKey(params.getPublicKeyParameters());
105+
MLDSAPublicKeyParameters publicKeyParameters = params.getPublicKeyParameters();
106+
if (publicKeyParameters == null)
107+
{
108+
return null;
109+
}
110+
return new BCMLDSAPublicKey(publicKeyParameters);
105111
}
106112

107113
@Override

prov/src/main/java/org/bouncycastle/jcajce/provider/asymmetric/mldsa/MLDSAKeyFactorySpi.java

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPrivateKeyParameters;
2323
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPublicKeyParameters;
2424
import org.bouncycastle.pqc.jcajce.provider.util.BaseKeyFactorySpi;
25+
import org.bouncycastle.util.Arrays;
2526

2627
public class MLDSAKeyFactorySpi
27-
extends BaseKeyFactorySpi
28+
extends BaseKeyFactorySpi
2829
{
2930
private static final Set<ASN1ObjectIdentifier> pureKeyOids = new HashSet<ASN1ObjectIdentifier>();
3031
private static final Set<ASN1ObjectIdentifier> hashKeyOids = new HashSet<ASN1ObjectIdentifier>();
@@ -54,7 +55,7 @@ public MLDSAKeyFactorySpi(ASN1ObjectIdentifier keyOid)
5455
}
5556

5657
public final KeySpec engineGetKeySpec(Key key, Class keySpec)
57-
throws InvalidKeySpecException
58+
throws InvalidKeySpecException
5859
{
5960
if (key instanceof BCMLDSAPrivateKey)
6061
{
@@ -93,15 +94,15 @@ else if (key instanceof BCMLDSAPublicKey)
9394
else
9495
{
9596
throw new InvalidKeySpecException("Unsupported key type: "
96-
+ key.getClass() + ".");
97+
+ key.getClass() + ".");
9798
}
9899

99100
throw new InvalidKeySpecException("Unknown key specification: "
100-
+ keySpec + ".");
101+
+ keySpec + ".");
101102
}
102103

103104
public final Key engineTranslateKey(Key key)
104-
throws InvalidKeyException
105+
throws InvalidKeyException
105106
{
106107
if (key instanceof BCMLDSAPrivateKey || key instanceof BCMLDSAPublicKey)
107108
{
@@ -127,16 +128,15 @@ public PrivateKey engineGeneratePrivate(
127128
}
128129
else
129130
{
131+
params = new MLDSAPrivateKeyParameters(
132+
mldsaParameters, spec.getPrivateData(), null);
130133
byte[] publicData = spec.getPublicData();
131134
if (publicData != null)
132135
{
133-
params = new MLDSAPrivateKeyParameters(
134-
mldsaParameters, spec.getPrivateData(), new MLDSAPublicKeyParameters(mldsaParameters, publicData));
135-
}
136-
else
137-
{
138-
params = new MLDSAPrivateKeyParameters(
139-
mldsaParameters, spec.getPrivateData(), null);
136+
if (!Arrays.areEqual(publicData, params.getPublicKey()))
137+
{
138+
throw new InvalidKeySpecException("public key data does not match private key data");
139+
}
140140
}
141141
}
142142

@@ -163,19 +163,19 @@ public PublicKey engineGeneratePublic(
163163
}
164164

165165
public PrivateKey generatePrivate(PrivateKeyInfo keyInfo)
166-
throws IOException
166+
throws IOException
167167
{
168168
return new BCMLDSAPrivateKey(keyInfo);
169169
}
170170

171171
public PublicKey generatePublic(SubjectPublicKeyInfo keyInfo)
172-
throws IOException
172+
throws IOException
173173
{
174174
return new BCMLDSAPublicKey(keyInfo);
175175
}
176176

177177
public static class Pure
178-
extends MLDSAKeyFactorySpi
178+
extends MLDSAKeyFactorySpi
179179
{
180180
public Pure()
181181
{
@@ -184,7 +184,7 @@ public Pure()
184184
}
185185

186186
public static class MLDSA44
187-
extends MLDSAKeyFactorySpi
187+
extends MLDSAKeyFactorySpi
188188
{
189189
public MLDSA44()
190190
{
@@ -193,7 +193,7 @@ public MLDSA44()
193193
}
194194

195195
public static class MLDSA65
196-
extends MLDSAKeyFactorySpi
196+
extends MLDSAKeyFactorySpi
197197
{
198198
public MLDSA65()
199199
{
@@ -202,7 +202,7 @@ public MLDSA65()
202202
}
203203

204204
public static class MLDSA87
205-
extends MLDSAKeyFactorySpi
205+
extends MLDSAKeyFactorySpi
206206
{
207207
public MLDSA87()
208208
{
@@ -211,7 +211,7 @@ public MLDSA87()
211211
}
212212

213213
public static class Hash
214-
extends MLDSAKeyFactorySpi
214+
extends MLDSAKeyFactorySpi
215215
{
216216
public Hash()
217217
{
@@ -220,7 +220,7 @@ public Hash()
220220
}
221221

222222
public static class HashMLDSA44
223-
extends MLDSAKeyFactorySpi
223+
extends MLDSAKeyFactorySpi
224224
{
225225
public HashMLDSA44()
226226
{
@@ -229,7 +229,7 @@ public HashMLDSA44()
229229
}
230230

231231
public static class HashMLDSA65
232-
extends MLDSAKeyFactorySpi
232+
extends MLDSAKeyFactorySpi
233233
{
234234
public HashMLDSA65()
235235
{
@@ -238,7 +238,7 @@ public HashMLDSA65()
238238
}
239239

240240
public static class HashMLDSA87
241-
extends MLDSAKeyFactorySpi
241+
extends MLDSAKeyFactorySpi
242242
{
243243
public HashMLDSA87()
244244
{

0 commit comments

Comments
 (0)