Skip to content

Commit d53e7a4

Browse files
committed
Support derandomized key generation for ML-KEM
Signed-off-by: Spencer Wilson <[email protected]>
1 parent 1d7530e commit d53e7a4

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

src/main/c/KeyEncapsulation.c

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ JNIEXPORT jobject JNICALL Java_org_openquantumsafe_KeyEncapsulation_get_1KEM_1de
8585
jfieldID _length_shared_secret = (*env)->GetFieldID(env, cls, "length_shared_secret", "J");
8686
(*env)->SetLongField(env, _nativeKED, _length_shared_secret, (jlong) kem->length_shared_secret);
8787

88+
// long length_keypair_seed;
89+
jfieldID _length_keypair_seed = (*env)->GetFieldID(env, cls, "length_keypair_seed", "J");
90+
(*env)->SetLongField(env, _nativeKED, _length_keypair_seed, (jlong) kem->length_keypair_seed);
91+
8892
return _nativeKED;
8993
}
9094

@@ -110,6 +114,30 @@ JNIEXPORT jint JNICALL Java_org_openquantumsafe_KeyEncapsulation_generate_1keypa
110114
return (rv_ == OQS_SUCCESS) ? 0 : -1;
111115
}
112116

117+
/*
118+
* Class: org_openquantumsafe_KeyEncapsulation
119+
* Method: generate_keypair
120+
* Signature: ([B[B)I
121+
*/
122+
JNIEXPORT jint JNICALL Java_org_openquantumsafe_KeyEncapsulation_generate_1keypair_1derand
123+
(JNIEnv *env, jobject obj, jbyteArray jpublic_key, jbyteArray jsecret_key, jbyteArray jseed)
124+
{
125+
jbyte *public_key_native = (*env)->GetByteArrayElements(env, jpublic_key, 0);
126+
jbyte *secret_key_native = (*env)->GetByteArrayElements(env, jsecret_key, 0);
127+
jbyte *seed_native = (*env)->GetByteArrayElements(env, jseed, 0);
128+
129+
// Get pointer to KEM
130+
OQS_KEM *kem = (OQS_KEM *) getHandle(env, obj, "native_kem_handle_");
131+
132+
// Invoke liboqs KEM keypair generation function
133+
OQS_STATUS rv_ = OQS_KEM_keypair_derand(kem, (uint8_t*) public_key_native, (uint8_t*) secret_key_native, (uint8_t*) seed_native);
134+
135+
(*env)->ReleaseByteArrayElements(env, jpublic_key, public_key_native, 0);
136+
(*env)->ReleaseByteArrayElements(env, jsecret_key, secret_key_native, 0);
137+
(*env)->ReleaseByteArrayElements(env, jseed, seed_native, JNI_ABORT);
138+
return (rv_ == OQS_SUCCESS) ? 0 : -1;
139+
}
140+
113141
/*
114142
* Class: org_openquantumsafe_KeyEncapsulation
115143
* Method: encap_secret

src/main/java/org/openquantumsafe/KeyEncapsulation.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class KeyEncapsulationDetails {
2020
long length_secret_key;
2121
long length_ciphertext;
2222
long length_shared_secret;
23+
long length_keypair_seed;
2324

2425
/**
2526
* \brief Print KEM algorithm details
@@ -33,7 +34,9 @@ void printKeyEncapsulation() {
3334
"\n Length public key (bytes): " + this.length_public_key +
3435
"\n Length secret key (bytes): " + this.length_secret_key +
3536
"\n Length ciphertext (bytes): " + this.length_ciphertext +
36-
"\n Length shared secret (bytes): " + this.length_shared_secret
37+
"\n Length shared secret (bytes): " + this.length_shared_secret +
38+
"\n Length keypair seed (bytes): "
39+
+ ((this.length_keypair_seed > 0) ? this.length_keypair_seed : "N/A")
3740
);
3841
}
3942

@@ -114,6 +117,18 @@ public KeyEncapsulation(String alg_name, byte[] secret_key)
114117
*/
115118
private native int generate_keypair(byte[] public_key, byte[] secret_key);
116119

120+
/**
121+
* \brief Wrapper for OQS_API OQS_STATUS OQS_KEM_keypair_derand(const OQS_KEM *kem,
122+
* uint8_t *public_key, uint8_t *secret_key,
123+
* const uint8_t *seed);
124+
* \param Public key
125+
* \param Secret key
126+
* \param Seed
127+
* \return Status
128+
*/
129+
private native int generate_keypair_derand(byte[] public_key,
130+
byte[] secret_key, byte[] seed);
131+
117132
/**
118133
* \brief Wrapper for OQS_API OQS_STATUS OQS_KEM_encaps(const OQS_KEM *kem,
119134
* uint8_t *ciphertext,
@@ -159,6 +174,27 @@ public byte[] generate_keypair() throws RuntimeException {
159174
return this.public_key_;
160175
}
161176

177+
/**
178+
* \brief Invoke native generate_keypair_derand method using the PK and SK lengths
179+
* from alg_details_. Check return value and if != 0 throw Exception.
180+
*/
181+
public byte[] generate_keypair(byte[] seed) throws RuntimeException {
182+
if (seed.length != alg_details_.length_keypair_seed) {
183+
throw new RuntimeException("Incorrect seed length");
184+
}
185+
186+
int rv_ = generate_keypair_derand(this.public_key_, this.secret_key_, seed);
187+
if (rv_ != 0) throw new RuntimeException("Cannot generate keypair from seed");
188+
return this.public_key_;
189+
}
190+
191+
/**
192+
* \brief Return seed length
193+
*/
194+
public long get_keypair_seed_length() {
195+
return alg_details_.length_keypair_seed;
196+
}
197+
162198
/**
163199
* \brief Return public key
164200
*/

src/test/java/org/openquantumsafe/KEMTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
99

1010
import java.util.ArrayList;
11+
import java.util.Arrays;
1112
import java.util.stream.Stream;
1213

1314
public class KEMTest {
@@ -56,6 +57,43 @@ public void testAllKEMs(String kem_name) {
5657
System.out.println(sb.toString());
5758
}
5859

60+
/**
61+
* Test KEMs with derandomized keypair generation.
62+
*/
63+
@ParameterizedTest(name = "Testing {arguments}")
64+
@MethodSource("getDerandSupportedKEMsAsStream")
65+
public void testKEMsWithDerand(String kem_name) {
66+
StringBuilder sb = new StringBuilder();
67+
sb.append(kem_name);
68+
sb.append(" (derand)");
69+
sb.append(String.format("%1$" + (40 - kem_name.length() - 9) + "s", ""));
70+
71+
// Create client and server
72+
KeyEncapsulation client = new KeyEncapsulation(kem_name);
73+
KeyEncapsulation server = new KeyEncapsulation(kem_name);
74+
75+
// Generate seed
76+
byte[] seed = Rand.randombytes(client.get_keypair_seed_length());
77+
78+
// Generate client key pair
79+
byte[] client_public_key = client.generate_keypair(seed);
80+
81+
// Server: encapsulate secret with client's public key
82+
Pair<byte[], byte[]> server_pair = server.encap_secret(client_public_key);
83+
byte[] ciphertext = server_pair.getLeft();
84+
byte[] shared_secret_server = server_pair.getRight();
85+
86+
// Client: decapsulate
87+
byte[] shared_secret_client = client.decap_secret(ciphertext);
88+
89+
// Check if equal
90+
assertArrayEquals(shared_secret_client, shared_secret_server, kem_name);
91+
92+
// If successful print KEM name, otherwise an exception will be thrown
93+
sb.append("\033[0;32m").append("PASSED").append("\033[0m");
94+
System.out.println(sb.toString());
95+
}
96+
5997
/**
6098
* Test the MechanismNotSupported Exception
6199
*/
@@ -71,4 +109,12 @@ private static Stream<String> getEnabledKEMsAsStream() {
71109
return enabled_kems.parallelStream();
72110
}
73111

112+
/**
113+
* Method to convert the list of derand-supported KEMs to a stream for input to testAllSigs
114+
*/
115+
private static Stream<String> getDerandSupportedKEMsAsStream() {
116+
return Arrays.asList(
117+
"ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"
118+
).parallelStream();
119+
}
74120
}

0 commit comments

Comments
 (0)