Skip to content

Commit fdd1142

Browse files
Add Check for PQCPrivateKey in Decapsulator and add tests for invalid keys.
Signed-off-by: John Peck <140550562+johnpeck-us-ibm@users.noreply.github.com>
1 parent 89f3c70 commit fdd1142

File tree

3 files changed

+133
-29
lines changed

3 files changed

+133
-29
lines changed

src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@
1212
import com.ibm.crypto.plus.provider.ock.OJPKEM;
1313
import java.security.InvalidAlgorithmParameterException;
1414
import java.security.InvalidKeyException;
15+
import java.security.KeyFactory;
1516
import java.security.PrivateKey;
1617
import java.security.ProviderException;
1718
import java.security.PublicKey;
1819
import java.security.SecureRandom;
1920
import java.security.spec.AlgorithmParameterSpec;
21+
import java.security.spec.EncodedKeySpec;
22+
import java.security.spec.PKCS8EncodedKeySpec;
23+
import java.security.spec.X509EncodedKeySpec;
24+
import java.util.Arrays;
2025
import javax.crypto.DecapsulateException;
2126
import javax.crypto.KEM;
2227
import javax.crypto.KEMSpi;
@@ -49,7 +54,6 @@ private int getEncapsulationLength() {
4954
return size;
5055
}
5156

52-
5357
/*
5458
* spec - The AlgorithmParameterSpec is not used and should be null. If not null
5559
* it will be ignored.
@@ -60,14 +64,28 @@ private int getEncapsulationLength() {
6064
public KEMSpi.EncapsulatorSpi engineNewEncapsulator(PublicKey publicKey,
6165
AlgorithmParameterSpec spec, SecureRandom secureRandom)
6266
throws InvalidAlgorithmParameterException, InvalidKeyException {
63-
if (publicKey == null || !(publicKey instanceof PQCPublicKey) ) {
64-
throw new InvalidKeyException("unsupported key");
67+
68+
PublicKey pubKey = publicKey;
69+
if (pubKey == null) {
70+
throw new InvalidKeyException("Key is null.");
71+
}
72+
73+
if (!(pubKey instanceof PQCPublicKey)) {
74+
// Try and convert this key to a usage PQCPublicKey
75+
try {
76+
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
77+
EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(publicKey.getEncoded());
78+
pubKey = kf.generatePublic(publicKeySpec);
79+
80+
} catch (Exception e) {
81+
throw new InvalidKeyException("unsupported key", e);
82+
}
6583
}
6684

6785
if (spec != null) {
6886
throw new InvalidAlgorithmParameterException("no spec needed");
6987
}
70-
return new MLKEMEncapsulator(publicKey, spec, null);
88+
return new MLKEMEncapsulator(pubKey, spec, null);
7189
}
7290

7391
class MLKEMEncapsulator implements KEMSpi.EncapsulatorSpi {
@@ -129,14 +147,32 @@ public KEMSpi.DecapsulatorSpi engineNewDecapsulator(PrivateKey privateKey,
129147
AlgorithmParameterSpec spec)
130148
throws InvalidAlgorithmParameterException, InvalidKeyException {
131149

132-
if (privateKey == null) {
133-
throw new InvalidKeyException("unsupported key");
150+
PrivateKey privKey = privateKey;
151+
152+
if (privKey == null) {
153+
throw new InvalidKeyException("Key is null.");
154+
}
155+
156+
if (!(privKey instanceof PQCPrivateKey)) {
157+
// Try and convert this key to a usage PQCPrivateKey
158+
byte[] encoding = null;
159+
try {
160+
KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName());
161+
encoding = privateKey.getEncoded();
162+
PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encoding);
163+
privKey = kf.generatePrivate(privateKeySpec);
164+
} catch (Exception e) {
165+
throw new InvalidKeyException("unsupported key", e);
166+
} finally {
167+
Arrays.fill(encoding, (byte) 0);
168+
}
169+
134170
}
135171

136172
if (spec != null) {
137173
throw new InvalidAlgorithmParameterException("no spec needed");
138174
}
139-
return new MLKEMDecapsulator(privateKey, null);
175+
return new MLKEMDecapsulator(privKey, null);
140176
}
141177

142178
/*

src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package ibm.jceplus.junit.base;
1010

11+
import java.security.InvalidKeyException;
1112
import java.security.KeyFactory;
1213
import java.security.KeyPair;
1314
import java.security.KeyPairGenerator;
@@ -20,6 +21,7 @@
2021
import org.junit.jupiter.params.ParameterizedTest;
2122
import org.junit.jupiter.params.provider.CsvSource;
2223
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
24+
import static org.junit.jupiter.api.Assertions.assertTrue;
2325
import static org.junit.jupiter.api.Assertions.fail;
2426

2527
public class BaseTestKEM extends BaseTestJunit5 {
@@ -35,8 +37,6 @@ public void testKEM(String Algorithm) throws Exception {
3537
KEM kem = KEM.getInstance(Algorithm, getProviderName());
3638

3739
KeyPair pqcKeyPair = generateKeyPair(Algorithm);
38-
pqcKeyPair.getPublic();
39-
pqcKeyPair.getPrivate();
4040

4141
KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic());
4242
KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES");
@@ -56,8 +56,6 @@ public void testKEMEmptyNoToFrom(String Algorithm) throws Exception {
5656
KEM kem = KEM.getInstance(Algorithm, getProviderName());
5757

5858
KeyPair pqcKeyPair = generateKeyPair(Algorithm);
59-
pqcKeyPair.getPublic();
60-
pqcKeyPair.getPrivate();
6159

6260
KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic());
6361
KEM.Encapsulated enc = encr.encapsulate();
@@ -78,8 +76,6 @@ public void testKEMError(String Algorithm) throws Exception {
7876
KEM kem = KEM.getInstance(Algorithm, getProviderName());
7977

8078
KeyPair pqcKeyPair = generateKeyPair(Algorithm);
81-
pqcKeyPair.getPublic();
82-
pqcKeyPair.getPrivate();
8379

8480
KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic());
8581
for (int i =0; i < 4; i++) {
@@ -161,8 +157,6 @@ public void testKEMSmallerSecret(String Algorithm) throws Exception {
161157
KEM kem = KEM.getInstance(Algorithm, getProviderName());
162158

163159
KeyPair pqcKeyPair = generateKeyPair(Algorithm);
164-
pqcKeyPair.getPublic();
165-
pqcKeyPair.getPrivate();
166160

167161
KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic());
168162
KEM.Encapsulated enc = encr.encapsulate(0, 16, "AES");
@@ -175,6 +169,47 @@ public void testKEMSmallerSecret(String Algorithm) throws Exception {
175169
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
176170
}
177171

172+
@ParameterizedTest
173+
@CsvSource({"ML-KEM", "ML-KEM-512", "ML_KEM_768", "ML_KEM_1024"})
174+
public void testKEMKeys(String Algorithm) throws Exception {
175+
176+
KEM kem = KEM.getInstance(Algorithm, getProviderName());
177+
178+
KeyPair pqcKeyPair = generateKeyPair("RSA");
179+
180+
try {
181+
kem.newEncapsulator(pqcKeyPair.getPublic());
182+
fail("testKEMKeys failed - RSA Public key did not cause an Invalid Key Excepton.");
183+
} catch (InvalidKeyException ike) {
184+
assertTrue(ike.getMessage().equals("unsupported key"));
185+
}
186+
187+
try {
188+
kem.newDecapsulator(pqcKeyPair.getPrivate());
189+
fail("testKEMKeys failed - RSA Private key did not cause an Invalid Key Excepton.");
190+
} catch (InvalidKeyException ike) {
191+
assertTrue(ike.getMessage().equals("unsupported key"));
192+
}
193+
194+
// Test null keys
195+
PublicKey pub = null;
196+
PrivateKey priv = null;
197+
198+
try {
199+
kem.newEncapsulator(pub);
200+
fail("testKEMKeys failed - NULL Public key did not cause an Invalid Key Excepton.");
201+
} catch (InvalidKeyException ike) {
202+
assertTrue(ike.getMessage().equals("Key is null."));
203+
}
204+
205+
try {
206+
kem.newDecapsulator(priv);
207+
fail("testKEMKeys failed - NULL Private key did not cause an Invalid Key Excepton.");
208+
} catch (InvalidKeyException ike) {
209+
assertTrue(ike.getMessage().equals("Key is null."));
210+
}
211+
}
212+
178213
protected KeyPair generateKeyPair(String Algorithm) throws Exception {
179214
pqcKeyPairGen = KeyPairGenerator.getInstance(Algorithm, getProviderName());
180215

src/test/java/ibm/jceplus/junit/base/BaseTestPQCKeyInterop.java

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.jupiter.api.Test;
2525
import org.junit.jupiter.params.ParameterizedTest;
2626
import org.junit.jupiter.params.provider.CsvSource;
27+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
2728
import static org.junit.jupiter.api.Assertions.assertTrue;
2829
import static org.junit.jupiter.api.Assertions.fail;
2930

@@ -74,6 +75,38 @@ public void testPQCKeyGenKEM_PlusToInterop() throws Exception {
7475
assertTrue(same);
7576
}
7677

78+
@Test
79+
public void testPQCKeyGenKEMAutoKeyConvertion() throws Exception {
80+
String pqcAlgorithm = "ML-KEM-512";
81+
82+
if (getProviderName().equals("OpenJCEPlusFIPS") ||
83+
getInteropProviderName().equals(Utils.PROVIDER_BC)) {
84+
//This is not in the FIPS provider yet and Boucy Castle does not support this test.
85+
return;
86+
}
87+
88+
KEM kemInterop = KEM.getInstance(pqcAlgorithm, getProviderName());
89+
90+
KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(pqcAlgorithm, getInteropProviderName());
91+
KeyPair keyPair = generateKeyPair(keyPairGen);
92+
93+
PublicKey publicKey = keyPair.getPublic();
94+
PrivateKey privateKey = keyPair.getPrivate();
95+
96+
KEM.Encapsulator encr = kemInterop.newEncapsulator(publicKey);
97+
KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES");
98+
if (enc == null){
99+
System.out.println("enc = null");
100+
fail("KEMPlusCreatesInteropGet failed no enc.");
101+
}
102+
SecretKey keyE = enc.key();
103+
104+
KEM.Decapsulator decr = kemInterop.newDecapsulator(privateKey);
105+
SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES");
106+
107+
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
108+
}
109+
77110
@Test
78111
public void testPQCKeyGenKEM_Interop() throws Exception {
79112
String pqcAlgorithm = "ML-KEM-512";
@@ -337,7 +370,7 @@ public void testSignInteropKeysPlusSignVerify(String algorithm) {
337370
assertTrue(verifyingPlus.verify(signedBytesInterop), "Signature verification failed");
338371
} catch (Exception ex) {
339372
ex.printStackTrace();
340-
assertTrue(false, "SignInteropAndVerifyPlus failed");
373+
fail("SignInteropAndVerifyPlus failed");
341374
}
342375
}
343376

@@ -372,7 +405,7 @@ public void testSignPlusKeysInteropSignVerify(String algorithm) {
372405
assertTrue(verifyingPlus.verify(signedBytesInterop), "Signature verification failed");
373406
} catch (Exception ex) {
374407
ex.printStackTrace();
375-
assertTrue(false, "SignInteropAndVerifyPlus failed");
408+
fail("SignInteropAndVerifyPlus failed");
376409
}
377410
}
378411

@@ -408,7 +441,7 @@ public void testSignPlusAndVerifyInterop(String algorithm) {
408441
assertTrue(verifyingPlus.verify(signedBytesPlus), "Signature verification failed");
409442
} catch (Exception ex) {
410443
ex.printStackTrace();
411-
assertTrue(false, "SignPlusAndVerifyInterop failed");
444+
fail("SignPlusAndVerifyInterop failed");
412445
}
413446
}
414447

@@ -440,17 +473,17 @@ public void testKEMPlusKeyInteropAll(String Algorithm) {
440473
KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES");
441474
if (enc == null){
442475
System.out.println("enc = null");
443-
assertTrue(false, "KEMPlusCreatesInteropGet failed no enc.");
476+
fail("KEMPlusCreatesInteropGet failed no enc.");
444477
}
445478
SecretKey keyE = enc.key();
446479

447480
KEM.Decapsulator decr = kemInterop.newDecapsulator(privateKeyInterop);
448481
SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES");
449482

450-
assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match");
483+
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
451484
} catch (Exception ex) {
452485
ex.printStackTrace();
453-
assertTrue(false, "KEMPlusCreatesInteropGet failed");
486+
fail("KEMPlusCreatesInteropGet failed");
454487
}
455488
}
456489

@@ -482,17 +515,17 @@ public void testKEMInteropKeyPlusAll(String Algorithm) {
482515
KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES");
483516
if (enc == null){
484517
System.out.println("enc = null");
485-
assertTrue(false, "KEMPlusCreatesInteropGet failed no enc.");
518+
fail("KEMPlusCreatesInteropGet failed no enc.");
486519
}
487520
SecretKey keyE = enc.key();
488521

489522
KEM.Decapsulator decr = kemPlus.newDecapsulator(privateKeyPlus);
490523
SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES");
491524

492-
assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match");
525+
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
493526
} catch (Exception ex) {
494527
ex.printStackTrace();
495-
assertTrue(false, "KEMPlusCreatesInteropGet failed");
528+
fail("KEMPlusCreatesInteropGet failed");
496529
}
497530
}
498531

@@ -522,17 +555,17 @@ public void testKEMPlusCreatesInteropGet(String Algorithm) {
522555
KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES");
523556
if (enc == null){
524557
System.out.println("enc = null");
525-
assertTrue(false, "KEMPlusCreatesInteropGet failed no enc.");
558+
fail("KEMPlusCreatesInteropGet failed no enc.");
526559
}
527560
SecretKey keyE = enc.key();
528561

529562
KEM.Decapsulator decr = kemPlus.newDecapsulator(privateKeyPlus);
530563
SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES");
531564

532-
assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match");
565+
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
533566
} catch (Exception ex) {
534567
ex.printStackTrace();
535-
assertTrue(false, "KEMPlusCreatesInteropGet failed");
568+
fail("KEMPlusCreatesInteropGet failed");
536569
}
537570
}
538571

@@ -566,10 +599,10 @@ public void testKEMInteropCreatesPlusGet(String Algorithm) {
566599

567600
SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES");
568601

569-
assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match");
602+
assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match");
570603
} catch (Exception ex) {
571604
ex.printStackTrace();
572-
assertTrue(false, "KEMInteropCreatesPlusGet failed");
605+
fail("KEMInteropCreatesPlusGet failed");
573606
}
574607
}
575608

0 commit comments

Comments
 (0)