Skip to content

Commit 4257ff9

Browse files
committed
added seed based storage for ML-KEM and ML-DSA private keys.
1 parent 2418e7f commit 4257ff9

File tree

10 files changed

+122
-37
lines changed

10 files changed

+122
-37
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAEngine.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ else if (this.DilithiumGamma1 == (1 << 19))
227227
}
228228

229229
//Internal functions are deterministic. No randomness is sampled inside them
230-
private byte[][] generateKeyPairInternal(byte[] seed)
230+
byte[][] generateKeyPairInternal(byte[] seed)
231231
{
232232
byte[] buf = new byte[2 * SeedBytes + CrhBytes];
233233
byte[] tr = new byte[TrBytes];
@@ -301,7 +301,7 @@ private byte[][] generateKeyPairInternal(byte[] seed)
301301

302302
byte[][] sk = Packing.packSecretKey(rho, tr, key, t0, s1, s2, this);
303303

304-
return new byte[][]{sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1};
304+
return new byte[][]{ sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1, seed};
305305
}
306306

307307
SHAKEDigest getShake256Digest()

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAKeyPairGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public AsymmetricCipherKeyPair generateKeyPair()
2424

2525
byte[][] keyPair = engine.generateKeyPair();
2626
MLDSAPublicKeyParameters pubKey = new MLDSAPublicKeyParameters(dilithiumParams, keyPair[0], keyPair[6]);
27-
MLDSAPrivateKeyParameters privKey = new MLDSAPrivateKeyParameters(dilithiumParams, keyPair[0], keyPair[1], keyPair[2], keyPair[3], keyPair[4], keyPair[5], keyPair[6]);
27+
MLDSAPrivateKeyParameters privKey = new MLDSAPrivateKeyParameters(dilithiumParams, keyPair[0], keyPair[1], keyPair[2], keyPair[3], keyPair[4], keyPair[5], keyPair[6], keyPair[7]);
2828

2929
return new AsymmetricCipherKeyPair(pubKey, privKey);
3030
}

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAPrivateKeyParameters.java

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,29 @@ public class MLDSAPrivateKeyParameters
1313
final byte[] t0;
1414

1515
private final byte[] t1;
16+
private final byte[] seed;
17+
18+
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] seed)
19+
{
20+
super(true, params);
21+
byte[][] keyDetails = params.getEngine(null).generateKeyPairInternal(seed);
22+
23+
this.rho = keyDetails[0];
24+
this.k = keyDetails[1];
25+
this.tr = keyDetails[2];
26+
this.s1 = keyDetails[3];
27+
this.s2 = keyDetails[4];
28+
this.t0 = keyDetails[5];
29+
this.t1 = keyDetails[6];
30+
this.seed = keyDetails[7];
31+
}
1632

1733
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] rho, byte[] K, byte[] tr, byte[] s1, byte[] s2, byte[] t0, byte[] t1)
34+
{
35+
this(params, rho, K, tr, s1, s2, t0, t1, null);
36+
}
37+
38+
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] rho, byte[] K, byte[] tr, byte[] s1, byte[] s2, byte[] t0, byte[] t1, byte[] seed)
1839
{
1940
super(true, params);
2041
this.rho = Arrays.clone(rho);
@@ -24,6 +45,7 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] rho, byte[] K, b
2445
this.s2 = Arrays.clone(s2);
2546
this.t0 = Arrays.clone(t0);
2647
this.t1 = Arrays.clone(t1);
48+
this.seed = Arrays.clone(seed);
2749
}
2850

2951
public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAPublicKeyParameters pubKey)
@@ -32,15 +54,21 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAP
3254

3355
MLDSAEngine eng = params.getEngine(null);
3456
int index = 0;
35-
this.rho = Arrays.copyOfRange(encoding, 0, MLDSAEngine.SeedBytes); index += MLDSAEngine.SeedBytes;
36-
this.k = Arrays.copyOfRange(encoding, index, index + MLDSAEngine.SeedBytes); index += MLDSAEngine.SeedBytes;
37-
this.tr = Arrays.copyOfRange(encoding, index, index + MLDSAEngine.TrBytes); index += MLDSAEngine.TrBytes;
57+
this.rho = Arrays.copyOfRange(encoding, 0, MLDSAEngine.SeedBytes);
58+
index += MLDSAEngine.SeedBytes;
59+
this.k = Arrays.copyOfRange(encoding, index, index + MLDSAEngine.SeedBytes);
60+
index += MLDSAEngine.SeedBytes;
61+
this.tr = Arrays.copyOfRange(encoding, index, index + MLDSAEngine.TrBytes);
62+
index += MLDSAEngine.TrBytes;
3863
int delta = eng.getDilithiumL() * eng.getDilithiumPolyEtaPackedBytes();
39-
this.s1 = Arrays.copyOfRange(encoding, index, index + delta); index += delta;
64+
this.s1 = Arrays.copyOfRange(encoding, index, index + delta);
65+
index += delta;
4066
delta = eng.getDilithiumK() * eng.getDilithiumPolyEtaPackedBytes();
41-
this.s2 = Arrays.copyOfRange(encoding, index, index + delta); index += delta;
67+
this.s2 = Arrays.copyOfRange(encoding, index, index + delta);
68+
index += delta;
4269
delta = eng.getDilithiumK() * MLDSAEngine.DilithiumPolyT0PackedBytes;
43-
this.t0 = Arrays.copyOfRange(encoding, index, index + delta); index += delta;
70+
this.t0 = Arrays.copyOfRange(encoding, index, index + delta);
71+
index += delta;
4472

4573
if (pubKey != null)
4674
{
@@ -50,19 +78,22 @@ public MLDSAPrivateKeyParameters(MLDSAParameters params, byte[] encoding, MLDSAP
5078
{
5179
this.t1 = null;
5280
}
81+
this.seed = null;
5382
}
5483

5584
public byte[] getEncoded()
5685
{
57-
return Arrays.concatenate(new byte[][]{ rho, k, tr, s1, s2, t0 });
86+
return Arrays.concatenate(new byte[][]{rho, k, tr, s1, s2, t0});
5887
}
5988

6089
public byte[] getK()
6190
{
6291
return Arrays.clone(k);
6392
}
6493

65-
/** @deprecated Use {@link #getEncoded()} instead. */
94+
/**
95+
* @deprecated Use {@link #getEncoded()} instead.
96+
*/
6697
public byte[] getPrivateKey()
6798
{
6899
return getEncoded();
@@ -73,6 +104,11 @@ public byte[] getPublicKey()
73104
return MLDSAPublicKeyParameters.getEncoded(rho, t1);
74105
}
75106

107+
public byte[] getSeed()
108+
{
109+
return Arrays.clone(seed);
110+
}
111+
76112
public MLDSAPublicKeyParameters getPublicKeyParameters()
77113
{
78114
return new MLDSAPublicKeyParameters(getParameters(), rho, t1);

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMEngine.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ public byte[][] generateKemKeyPairInternal(byte[] d, byte[] z)
202202

203203
byte[] outputPublicKey = new byte[KyberIndCpaPublicKeyBytes];
204204
System.arraycopy(indCpaKeyPair[0], 0, outputPublicKey, 0, KyberIndCpaPublicKeyBytes);
205-
return new byte[][]{ Arrays.copyOfRange(outputPublicKey, 0, outputPublicKey.length - 32), Arrays.copyOfRange(outputPublicKey, outputPublicKey.length - 32, outputPublicKey.length), s, hashedPublicKey, z };
205+
return new byte[][]{ Arrays.copyOfRange(outputPublicKey, 0, outputPublicKey.length - 32), Arrays.copyOfRange(outputPublicKey, outputPublicKey.length - 32, outputPublicKey.length), s, hashedPublicKey, z, Arrays.concatenate(d, z)};
206206
}
207207

208208
public byte[][] kemEncryptInternal(byte[] publicKeyInput, byte[] randBytes)

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMKeyPairGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ private AsymmetricCipherKeyPair genKeyPair()
3030
byte[][] keyPair = engine.generateKemKeyPair();
3131

3232
MLKEMPublicKeyParameters pubKey = new MLKEMPublicKeyParameters(mlkemParams, keyPair[0], keyPair[1]);
33-
MLKEMPrivateKeyParameters privKey = new MLKEMPrivateKeyParameters(mlkemParams, keyPair[2], keyPair[3], keyPair[4], keyPair[0], keyPair[1]);
33+
MLKEMPrivateKeyParameters privKey = new MLKEMPrivateKeyParameters(mlkemParams, keyPair[2], keyPair[3], keyPair[4], keyPair[0], keyPair[1], keyPair[5]);
3434

3535
return new AsymmetricCipherKeyPair(pubKey, privKey);
3636
}

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMPrivateKeyParameters.java

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@ public class MLKEMPrivateKeyParameters
1010
final byte[] nonce;
1111
final byte[] t;
1212
final byte[] rho;
13-
13+
final byte[] seed;
14+
1415
public MLKEMPrivateKeyParameters(MLKEMParameters params, byte[] s, byte[] hpk, byte[] nonce, byte[] t, byte[] rho)
16+
{
17+
this(params, s, hpk, nonce, t, rho, null);
18+
}
19+
20+
public MLKEMPrivateKeyParameters(MLKEMParameters params, byte[] s, byte[] hpk, byte[] nonce, byte[] t, byte[] rho, byte[] seed)
1521
{
1622
super(true, params);
1723

@@ -20,19 +26,40 @@ public MLKEMPrivateKeyParameters(MLKEMParameters params, byte[] s, byte[] hpk, b
2026
this.nonce = Arrays.clone(nonce);
2127
this.t = Arrays.clone(t);
2228
this.rho = Arrays.clone(rho);
29+
this.seed = Arrays.clone(seed);
2330
}
2431

2532
public MLKEMPrivateKeyParameters(MLKEMParameters params, byte[] encoding)
2633
{
2734
super(true, params);
2835

2936
MLKEMEngine eng = params.getEngine();
30-
int index = 0;
31-
this.s = Arrays.copyOfRange(encoding, 0, eng.getKyberIndCpaSecretKeyBytes()); index += eng.getKyberIndCpaSecretKeyBytes();
32-
this.t = Arrays.copyOfRange(encoding, index, index + eng.getKyberIndCpaPublicKeyBytes() - MLKEMEngine.KyberSymBytes); index += eng.getKyberIndCpaPublicKeyBytes() - MLKEMEngine.KyberSymBytes;
33-
this.rho = Arrays.copyOfRange(encoding, index, index + 32); index += 32;
34-
this.hpk = Arrays.copyOfRange(encoding, index, index + 32); index += 32;
35-
this.nonce = Arrays.copyOfRange(encoding, index, index + MLKEMEngine.KyberSymBytes);
37+
if (encoding.length == MLKEMEngine.KyberSymBytes * 2)
38+
{
39+
byte[][] keyData = eng.generateKemKeyPairInternal(
40+
Arrays.copyOfRange(encoding, 0, MLKEMEngine.KyberSymBytes),
41+
Arrays.copyOfRange(encoding, MLKEMEngine.KyberSymBytes, encoding.length));
42+
this.s = keyData[2];
43+
this.hpk = keyData[3];
44+
this.nonce = keyData[4];
45+
this.t = keyData[0];
46+
this.rho = keyData[1];
47+
this.seed = keyData[5];
48+
}
49+
else
50+
{
51+
int index = 0;
52+
this.s = Arrays.copyOfRange(encoding, 0, eng.getKyberIndCpaSecretKeyBytes());
53+
index += eng.getKyberIndCpaSecretKeyBytes();
54+
this.t = Arrays.copyOfRange(encoding, index, index + eng.getKyberIndCpaPublicKeyBytes() - MLKEMEngine.KyberSymBytes);
55+
index += eng.getKyberIndCpaPublicKeyBytes() - MLKEMEngine.KyberSymBytes;
56+
this.rho = Arrays.copyOfRange(encoding, index, index + 32);
57+
index += 32;
58+
this.hpk = Arrays.copyOfRange(encoding, index, index + 32);
59+
index += 32;
60+
this.nonce = Arrays.copyOfRange(encoding, index, index + MLKEMEngine.KyberSymBytes);
61+
this.seed = null;
62+
}
3663
}
3764

3865
public byte[] getEncoded()
@@ -74,4 +101,9 @@ public byte[] getT()
74101
{
75102
return Arrays.clone(t);
76103
}
104+
105+
public byte[] getSeed()
106+
{
107+
return Arrays.clone(seed);
108+
}
77109
}

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ else if (keyObj instanceof DEROctetString)
329329
MLDSAPublicKeyParameters pubParams = PublicKeyFactory.MLDSAConverter.getPublicKeyParams(spParams, keyInfo.getPublicKeyData());
330330
return new MLDSAPrivateKeyParameters(spParams, data, pubParams);
331331
}
332-
return new MLDSAPrivateKeyParameters(spParams, data, null);
332+
return new MLDSAPrivateKeyParameters(spParams, data);
333333
}
334334
else
335335
{

core/src/main/java/org/bouncycastle/pqc/crypto/util/PrivateKeyInfoFactory.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,15 @@ else if (privateKey instanceof MLKEMPrivateKeyParameters)
247247

248248
AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.mlkemOidLookup(params.getParameters()));
249249

250-
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getEncoded()), attributes);
250+
byte[] seed = params.getSeed();
251+
if (seed == null)
252+
{
253+
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getEncoded()), attributes);
254+
}
255+
else
256+
{
257+
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(seed), attributes);
258+
}
251259
}
252260
else if (privateKey instanceof NTRULPRimePrivateKeyParameters)
253261
{
@@ -286,9 +294,19 @@ else if (privateKey instanceof MLDSAPrivateKeyParameters)
286294

287295
AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.mldsaOidLookup(params.getParameters()));
288296

289-
MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
297+
byte[] seed = params.getSeed();
298+
if (seed == null)
299+
{
300+
MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
301+
302+
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getEncoded()), attributes, pubParams.getEncoded());
303+
}
304+
else
305+
{
306+
MLDSAPublicKeyParameters pubParams = params.getPublicKeyParameters();
290307

291-
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getEncoded()), attributes, pubParams.getEncoded());
308+
return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(params.getSeed()), attributes);
309+
}
292310
}
293311
else if (privateKey instanceof DilithiumPrivateKeyParameters)
294312
{

0 commit comments

Comments
 (0)