From fdd11425c6dea9eb0321794537063b7f4dadbd57 Mon Sep 17 00:00:00 2001 From: John Peck <140550562+johnpeck-us-ibm@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:41:28 -0600 Subject: [PATCH] Add Check for PQCPrivateKey in Decapsulator and add tests for invalid keys. Signed-off-by: John Peck <140550562+johnpeck-us-ibm@users.noreply.github.com> --- .../ibm/crypto/plus/provider/MLKEMImpl.java | 50 ++++++++++++--- .../ibm/jceplus/junit/base/BaseTestKEM.java | 51 +++++++++++++--- .../junit/base/BaseTestPQCKeyInterop.java | 61 ++++++++++++++----- 3 files changed, 133 insertions(+), 29 deletions(-) diff --git a/src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java b/src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java index d4f1fd0ce..0cf39c1b3 100644 --- a/src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java +++ b/src/main/java/com/ibm/crypto/plus/provider/MLKEMImpl.java @@ -12,11 +12,16 @@ import com.ibm.crypto.plus.provider.ock.OJPKEM; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; +import java.security.KeyFactory; import java.security.PrivateKey; import java.security.ProviderException; import java.security.PublicKey; import java.security.SecureRandom; import java.security.spec.AlgorithmParameterSpec; +import java.security.spec.EncodedKeySpec; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.X509EncodedKeySpec; +import java.util.Arrays; import javax.crypto.DecapsulateException; import javax.crypto.KEM; import javax.crypto.KEMSpi; @@ -49,7 +54,6 @@ private int getEncapsulationLength() { return size; } - /* * spec - The AlgorithmParameterSpec is not used and should be null. If not null * it will be ignored. @@ -60,14 +64,28 @@ private int getEncapsulationLength() { public KEMSpi.EncapsulatorSpi engineNewEncapsulator(PublicKey publicKey, AlgorithmParameterSpec spec, SecureRandom secureRandom) throws InvalidAlgorithmParameterException, InvalidKeyException { - if (publicKey == null || !(publicKey instanceof PQCPublicKey) ) { - throw new InvalidKeyException("unsupported key"); + + PublicKey pubKey = publicKey; + if (pubKey == null) { + throw new InvalidKeyException("Key is null."); + } + + if (!(pubKey instanceof PQCPublicKey)) { + // Try and convert this key to a usage PQCPublicKey + try { + KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName()); + EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(publicKey.getEncoded()); + pubKey = kf.generatePublic(publicKeySpec); + + } catch (Exception e) { + throw new InvalidKeyException("unsupported key", e); + } } if (spec != null) { throw new InvalidAlgorithmParameterException("no spec needed"); } - return new MLKEMEncapsulator(publicKey, spec, null); + return new MLKEMEncapsulator(pubKey, spec, null); } class MLKEMEncapsulator implements KEMSpi.EncapsulatorSpi { @@ -129,14 +147,32 @@ public KEMSpi.DecapsulatorSpi engineNewDecapsulator(PrivateKey privateKey, AlgorithmParameterSpec spec) throws InvalidAlgorithmParameterException, InvalidKeyException { - if (privateKey == null) { - throw new InvalidKeyException("unsupported key"); + PrivateKey privKey = privateKey; + + if (privKey == null) { + throw new InvalidKeyException("Key is null."); + } + + if (!(privKey instanceof PQCPrivateKey)) { + // Try and convert this key to a usage PQCPrivateKey + byte[] encoding = null; + try { + KeyFactory kf = KeyFactory.getInstance(this.alg, this.provider.getName()); + encoding = privateKey.getEncoded(); + PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encoding); + privKey = kf.generatePrivate(privateKeySpec); + } catch (Exception e) { + throw new InvalidKeyException("unsupported key", e); + } finally { + Arrays.fill(encoding, (byte) 0); + } + } if (spec != null) { throw new InvalidAlgorithmParameterException("no spec needed"); } - return new MLKEMDecapsulator(privateKey, null); + return new MLKEMDecapsulator(privKey, null); } /* diff --git a/src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java b/src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java index d9c98cc1c..6fce1c62d 100644 --- a/src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java +++ b/src/test/java/ibm/jceplus/junit/base/BaseTestKEM.java @@ -8,6 +8,7 @@ package ibm.jceplus.junit.base; +import java.security.InvalidKeyException; import java.security.KeyFactory; import java.security.KeyPair; import java.security.KeyPairGenerator; @@ -20,6 +21,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class BaseTestKEM extends BaseTestJunit5 { @@ -35,8 +37,6 @@ public void testKEM(String Algorithm) throws Exception { KEM kem = KEM.getInstance(Algorithm, getProviderName()); KeyPair pqcKeyPair = generateKeyPair(Algorithm); - pqcKeyPair.getPublic(); - pqcKeyPair.getPrivate(); KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic()); KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES"); @@ -56,8 +56,6 @@ public void testKEMEmptyNoToFrom(String Algorithm) throws Exception { KEM kem = KEM.getInstance(Algorithm, getProviderName()); KeyPair pqcKeyPair = generateKeyPair(Algorithm); - pqcKeyPair.getPublic(); - pqcKeyPair.getPrivate(); KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic()); KEM.Encapsulated enc = encr.encapsulate(); @@ -78,8 +76,6 @@ public void testKEMError(String Algorithm) throws Exception { KEM kem = KEM.getInstance(Algorithm, getProviderName()); KeyPair pqcKeyPair = generateKeyPair(Algorithm); - pqcKeyPair.getPublic(); - pqcKeyPair.getPrivate(); KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic()); for (int i =0; i < 4; i++) { @@ -161,8 +157,6 @@ public void testKEMSmallerSecret(String Algorithm) throws Exception { KEM kem = KEM.getInstance(Algorithm, getProviderName()); KeyPair pqcKeyPair = generateKeyPair(Algorithm); - pqcKeyPair.getPublic(); - pqcKeyPair.getPrivate(); KEM.Encapsulator encr = kem.newEncapsulator(pqcKeyPair.getPublic()); KEM.Encapsulated enc = encr.encapsulate(0, 16, "AES"); @@ -175,6 +169,47 @@ public void testKEMSmallerSecret(String Algorithm) throws Exception { assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); } + @ParameterizedTest + @CsvSource({"ML-KEM", "ML-KEM-512", "ML_KEM_768", "ML_KEM_1024"}) + public void testKEMKeys(String Algorithm) throws Exception { + + KEM kem = KEM.getInstance(Algorithm, getProviderName()); + + KeyPair pqcKeyPair = generateKeyPair("RSA"); + + try { + kem.newEncapsulator(pqcKeyPair.getPublic()); + fail("testKEMKeys failed - RSA Public key did not cause an Invalid Key Excepton."); + } catch (InvalidKeyException ike) { + assertTrue(ike.getMessage().equals("unsupported key")); + } + + try { + kem.newDecapsulator(pqcKeyPair.getPrivate()); + fail("testKEMKeys failed - RSA Private key did not cause an Invalid Key Excepton."); + } catch (InvalidKeyException ike) { + assertTrue(ike.getMessage().equals("unsupported key")); + } + + // Test null keys + PublicKey pub = null; + PrivateKey priv = null; + + try { + kem.newEncapsulator(pub); + fail("testKEMKeys failed - NULL Public key did not cause an Invalid Key Excepton."); + } catch (InvalidKeyException ike) { + assertTrue(ike.getMessage().equals("Key is null.")); + } + + try { + kem.newDecapsulator(priv); + fail("testKEMKeys failed - NULL Private key did not cause an Invalid Key Excepton."); + } catch (InvalidKeyException ike) { + assertTrue(ike.getMessage().equals("Key is null.")); + } + } + protected KeyPair generateKeyPair(String Algorithm) throws Exception { pqcKeyPairGen = KeyPairGenerator.getInstance(Algorithm, getProviderName()); diff --git a/src/test/java/ibm/jceplus/junit/base/BaseTestPQCKeyInterop.java b/src/test/java/ibm/jceplus/junit/base/BaseTestPQCKeyInterop.java index 456a924dd..8725d540d 100644 --- a/src/test/java/ibm/jceplus/junit/base/BaseTestPQCKeyInterop.java +++ b/src/test/java/ibm/jceplus/junit/base/BaseTestPQCKeyInterop.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -74,6 +75,38 @@ public void testPQCKeyGenKEM_PlusToInterop() throws Exception { assertTrue(same); } + @Test + public void testPQCKeyGenKEMAutoKeyConvertion() throws Exception { + String pqcAlgorithm = "ML-KEM-512"; + + if (getProviderName().equals("OpenJCEPlusFIPS") || + getInteropProviderName().equals(Utils.PROVIDER_BC)) { + //This is not in the FIPS provider yet and Boucy Castle does not support this test. + return; + } + + KEM kemInterop = KEM.getInstance(pqcAlgorithm, getProviderName()); + + KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(pqcAlgorithm, getInteropProviderName()); + KeyPair keyPair = generateKeyPair(keyPairGen); + + PublicKey publicKey = keyPair.getPublic(); + PrivateKey privateKey = keyPair.getPrivate(); + + KEM.Encapsulator encr = kemInterop.newEncapsulator(publicKey); + KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES"); + if (enc == null){ + System.out.println("enc = null"); + fail("KEMPlusCreatesInteropGet failed no enc."); + } + SecretKey keyE = enc.key(); + + KEM.Decapsulator decr = kemInterop.newDecapsulator(privateKey); + SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES"); + + assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); + } + @Test public void testPQCKeyGenKEM_Interop() throws Exception { String pqcAlgorithm = "ML-KEM-512"; @@ -337,7 +370,7 @@ public void testSignInteropKeysPlusSignVerify(String algorithm) { assertTrue(verifyingPlus.verify(signedBytesInterop), "Signature verification failed"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "SignInteropAndVerifyPlus failed"); + fail("SignInteropAndVerifyPlus failed"); } } @@ -372,7 +405,7 @@ public void testSignPlusKeysInteropSignVerify(String algorithm) { assertTrue(verifyingPlus.verify(signedBytesInterop), "Signature verification failed"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "SignInteropAndVerifyPlus failed"); + fail("SignInteropAndVerifyPlus failed"); } } @@ -408,7 +441,7 @@ public void testSignPlusAndVerifyInterop(String algorithm) { assertTrue(verifyingPlus.verify(signedBytesPlus), "Signature verification failed"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "SignPlusAndVerifyInterop failed"); + fail("SignPlusAndVerifyInterop failed"); } } @@ -440,17 +473,17 @@ public void testKEMPlusKeyInteropAll(String Algorithm) { KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES"); if (enc == null){ System.out.println("enc = null"); - assertTrue(false, "KEMPlusCreatesInteropGet failed no enc."); + fail("KEMPlusCreatesInteropGet failed no enc."); } SecretKey keyE = enc.key(); KEM.Decapsulator decr = kemInterop.newDecapsulator(privateKeyInterop); SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES"); - assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match"); + assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "KEMPlusCreatesInteropGet failed"); + fail("KEMPlusCreatesInteropGet failed"); } } @@ -482,17 +515,17 @@ public void testKEMInteropKeyPlusAll(String Algorithm) { KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES"); if (enc == null){ System.out.println("enc = null"); - assertTrue(false, "KEMPlusCreatesInteropGet failed no enc."); + fail("KEMPlusCreatesInteropGet failed no enc."); } SecretKey keyE = enc.key(); KEM.Decapsulator decr = kemPlus.newDecapsulator(privateKeyPlus); SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES"); - assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match"); + assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "KEMPlusCreatesInteropGet failed"); + fail("KEMPlusCreatesInteropGet failed"); } } @@ -522,17 +555,17 @@ public void testKEMPlusCreatesInteropGet(String Algorithm) { KEM.Encapsulated enc = encr.encapsulate(0, 32, "AES"); if (enc == null){ System.out.println("enc = null"); - assertTrue(false, "KEMPlusCreatesInteropGet failed no enc."); + fail("KEMPlusCreatesInteropGet failed no enc."); } SecretKey keyE = enc.key(); KEM.Decapsulator decr = kemPlus.newDecapsulator(privateKeyPlus); SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES"); - assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match"); + assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "KEMPlusCreatesInteropGet failed"); + fail("KEMPlusCreatesInteropGet failed"); } } @@ -566,10 +599,10 @@ public void testKEMInteropCreatesPlusGet(String Algorithm) { SecretKey keyD = decr.decapsulate(enc.encapsulation(), 0, 32, "AES"); - assertTrue(Arrays.equals(keyE.getEncoded(), keyD.getEncoded()), "Secrets do NOT match"); + assertArrayEquals(keyE.getEncoded(), keyD.getEncoded(), "Secrets do NOT match"); } catch (Exception ex) { ex.printStackTrace(); - assertTrue(false, "KEMInteropCreatesPlusGet failed"); + fail("KEMInteropCreatesPlusGet failed"); } }