Skip to content

Commit d858403

Browse files
committed
refactoring of ML-KEM public key validation into PublicKeyParameters - relates to github #1974
1 parent 33ce50e commit d858403

File tree

4 files changed

+47
-35
lines changed

4 files changed

+47
-35
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMEngine.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
class MLKEMEngine
88
{
99
private SecureRandom random;
10-
private MLKEMIndCpa indCpa;
10+
11+
private final MLKEMIndCpa indCpa;
1112

1213
// constant parameters
1314
public final static int KyberN = 256;
@@ -281,29 +282,18 @@ public byte[] kemDecryptInternal(byte[] secretKey, byte[] cipherText)
281282
return Arrays.copyOfRange(kr, 0, sessionKeyLength);
282283
}
283284

284-
public byte[][] kemEncrypt(byte[] publicKeyInput, byte[] randBytes)
285+
MLKEMIndCpa getIndCpa()
285286
{
286-
//TODO: do input validation elsewhere?
287-
// Input validation (6.2 ML-KEM Encaps)
288-
// Type Check
289-
if (publicKeyInput.length != KyberIndCpaPublicKeyBytes)
290-
{
291-
throw new IllegalArgumentException("Input validation Error: Type check failed for ml-kem encapsulation");
292-
}
293-
// Modulus Check
294-
PolyVec polyVec = new PolyVec(this);
295-
byte[] seed = indCpa.unpackPublicKey(polyVec, publicKeyInput);
296-
byte[] ek = indCpa.packPublicKey(polyVec, seed);
297-
if (!Arrays.areEqual(ek, publicKeyInput))
298-
{
299-
throw new IllegalArgumentException("Input validation: Modulus check failed for ml-kem encapsulation");
300-
}
287+
return indCpa;
288+
}
301289

290+
byte[][] kemEncrypt(byte[] publicKeyInput, byte[] randBytes)
291+
{
302292
return kemEncryptInternal(publicKeyInput, randBytes);
303293
}
304-
public byte[] kemDecrypt(byte[] secretKey, byte[] cipherText)
294+
295+
byte[] kemDecrypt(byte[] secretKey, byte[] cipherText)
305296
{
306-
//TODO: do input validation
307297
return kemDecryptInternal(secretKey, cipherText);
308298
}
309299

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMIndCpa.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
class MLKEMIndCpa
66
{
7-
private MLKEMEngine engine;
8-
private int kyberK;
9-
private int indCpaPublicKeyBytes;
10-
private int polyVecBytes;
11-
private int indCpaBytes;
12-
private int polyVecCompressedBytes;
13-
private int polyCompressedBytes;
7+
private final MLKEMEngine engine;
8+
private final int kyberK;
9+
private final int indCpaPublicKeyBytes;
10+
private final int polyVecBytes;
11+
private final int indCpaBytes;
12+
private final int polyVecCompressedBytes;
13+
private final int polyCompressedBytes;
1414

1515
private Symmetric symmetric;
1616

core/src/main/java/org/bouncycastle/pqc/crypto/mlkem/MLKEMPublicKeyParameters.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,44 @@ static byte[] getEncoded(byte[] t, byte[] rho)
1616
public MLKEMPublicKeyParameters(MLKEMParameters params, byte[] t, byte[] rho)
1717
{
1818
super(false, params);
19+
20+
validatePublicKey(params.getEngine(), getEncoded(t, rho));
21+
1922
this.t = Arrays.clone(t);
2023
this.rho = Arrays.clone(rho);
2124
}
2225

2326
public MLKEMPublicKeyParameters(MLKEMParameters params, byte[] encoding)
2427
{
2528
super(false, params);
29+
30+
validatePublicKey(params.getEngine(), encoding);
31+
2632
this.t = Arrays.copyOfRange(encoding, 0, encoding.length - MLKEMEngine.KyberSymBytes);
2733
this.rho = Arrays.copyOfRange(encoding, encoding.length - MLKEMEngine.KyberSymBytes, encoding.length);
2834
}
2935

36+
private static void validatePublicKey(MLKEMEngine engine, byte[] publicKeyInput)
37+
{
38+
// Input validation (6.2 ML-KEM Encaps)
39+
// length Check
40+
if (publicKeyInput.length != engine.getKyberIndCpaPublicKeyBytes())
41+
{
42+
throw new IllegalArgumentException("length check failed for ml-kem public key construction");
43+
}
44+
45+
// Modulus Check
46+
PolyVec polyVec = new PolyVec(engine);
47+
MLKEMIndCpa indCpa = engine.getIndCpa();
48+
49+
byte[] seed = indCpa.unpackPublicKey(polyVec, publicKeyInput);
50+
byte[] ek = indCpa.packPublicKey(polyVec, seed);
51+
if (!Arrays.areEqual(ek, publicKeyInput))
52+
{
53+
throw new IllegalArgumentException("modulus check failed for ml-kem public key construction");
54+
}
55+
}
56+
3057
public byte[] getEncoded()
3158
{
3259
return getEncoded(t, rho);

core/src/test/java/org/bouncycastle/pqc/crypto/test/MLKEMTest.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -414,20 +414,15 @@ public void testModulus() throws IOException
414414
byte[] key = Hex.decode(line);
415415
MLKEMParameters parameters = params[fileIndex];
416416

417-
MLKEMPublicKeyParameters pubParams = (MLKEMPublicKeyParameters) PublicKeyFactory.createKey(
418-
SubjectPublicKeyInfoFactory.createSubjectPublicKeyInfo(new MLKEMPublicKeyParameters(parameters, key)));
419-
420-
// KEM Enc
421-
SecureRandom random = new SecureRandom();
422-
MLKEMGenerator generator = new MLKEMGenerator(random);
423417
try
424418
{
425-
SecretWithEncapsulation secWenc = generator.generateEncapsulated(pubParams);
426-
byte[] generated_cipher_text = secWenc.getEncapsulation();
419+
MLKEMPublicKeyParameters pubParams = (MLKEMPublicKeyParameters)PublicKeyFactory.createKey(
420+
SubjectPublicKeyInfoFactory.createSubjectPublicKeyInfo(new MLKEMPublicKeyParameters(parameters, key)));
427421
fail();
428422
}
429-
catch (Exception ignored)
423+
catch (IllegalArgumentException e)
430424
{
425+
assertEquals("modulus check failed for ml-kem public key construction", e.getMessage());
431426
}
432427
}
433428
}

0 commit comments

Comments
 (0)