@@ -196,9 +196,10 @@ public void init(SecureRandom random)
196196 this .random = random ;
197197 }
198198
199- public byte [][] generateKemKeyPair ()
199+ //Internal functions are deterministic. No randomness is sampled inside them
200+ public byte [][] generateKemKeyPairInternal (byte [] d , byte [] z )
200201 {
201- byte [][] indCpaKeyPair = indCpa .generateKeyPair ();
202+ byte [][] indCpaKeyPair = indCpa .generateKeyPair (d );
202203
203204 byte [] s = new byte [KyberIndCpaSecretKeyBytes ];
204205
@@ -208,41 +209,19 @@ public byte[][] generateKemKeyPair()
208209
209210 symmetric .hash_h (hashedPublicKey , indCpaKeyPair [0 ], 0 );
210211
211- byte [] z = new byte [KyberSymBytes ];
212- random .nextBytes (z );
213212
214213 byte [] outputPublicKey = new byte [KyberIndCpaPublicKeyBytes ];
215214 System .arraycopy (indCpaKeyPair [0 ], 0 , outputPublicKey , 0 , KyberIndCpaPublicKeyBytes );
216215 return new byte [][]{ Arrays .copyOfRange (outputPublicKey , 0 , outputPublicKey .length - 32 ), Arrays .copyOfRange (outputPublicKey , outputPublicKey .length - 32 , outputPublicKey .length ), s , hashedPublicKey , z };
217216 }
218217
219- public byte [][] kemEncrypt (byte [] publicKeyInput )
218+ public byte [][] kemEncryptInternal (byte [] publicKeyInput , byte [] randBytes )
220219 {
221- // Input validation (6.2 ML-KEM Encaps)
222- // Type Check
223- if (publicKeyInput .length != KyberIndCpaPublicKeyBytes )
224- {
225- throw new IllegalArgumentException ("Input validation Error: Type check failed for ml-kem encapsulation" );
226- }
227- // Modulus Check
228- PolyVec polyVec = new PolyVec (this );
229- byte [] seed = indCpa .unpackPublicKey (polyVec , publicKeyInput );
230- byte [] ek = indCpa .packPublicKey (polyVec , seed );
231- if (!Arrays .areEqual (ek , publicKeyInput ))
232- {
233- throw new IllegalArgumentException ("Input validation: Modulus check failed for ml-kem encapsulation" );
234- }
235-
236-
237220 byte [] outputCipherText ;
238221
239222 byte [] buf = new byte [2 * KyberSymBytes ];
240223 byte [] kr = new byte [2 * KyberSymBytes ];
241224
242- byte [] randBytes = new byte [KyberSymBytes ];
243-
244- random .nextBytes (randBytes );
245-
246225 System .arraycopy (randBytes , 0 , buf , 0 , KyberSymBytes );
247226
248227 // SHA3-256 Public Key
@@ -252,33 +231,32 @@ public byte[][] kemEncrypt(byte[] publicKeyInput)
252231 symmetric .hash_g (kr , buf );
253232
254233 // IndCpa Encryption
255- outputCipherText = indCpa .encrypt (Arrays .copyOfRange (buf , 0 , KyberSymBytes ), publicKeyInput , Arrays .copyOfRange (kr , 32 , kr .length ));
234+ outputCipherText = indCpa .encrypt (publicKeyInput , Arrays .copyOfRange (buf , 0 , KyberSymBytes ), Arrays .copyOfRange (kr , 32 , kr .length ));
256235
257236 byte [] outputSharedSecret = new byte [sessionKeyLength ];
258237
259238 System .arraycopy (kr , 0 , outputSharedSecret , 0 , outputSharedSecret .length );
260-
239+
261240 byte [][] outBuf = new byte [2 ][];
262241 outBuf [0 ] = outputSharedSecret ;
263242 outBuf [1 ] = outputCipherText ;
264-
265243 return outBuf ;
266244 }
267245
268- public byte [] kemDecrypt (byte [] cipherText , byte [] secretKey )
246+ public byte [] kemDecryptInternal (byte [] secretKey , byte [] cipherText )
269247 {
270248 byte [] buf = new byte [2 * KyberSymBytes ],
271- kr = new byte [2 * KyberSymBytes ];
249+ kr = new byte [2 * KyberSymBytes ];
272250
273251 byte [] publicKey = Arrays .copyOfRange (secretKey , KyberIndCpaSecretKeyBytes , secretKey .length );
274252
275- System .arraycopy (indCpa .decrypt (cipherText , secretKey ), 0 , buf , 0 , KyberSymBytes );
253+ System .arraycopy (indCpa .decrypt (secretKey , cipherText ), 0 , buf , 0 , KyberSymBytes );
276254
277255 System .arraycopy (secretKey , KyberSecretKeyBytes - 2 * KyberSymBytes , buf , KyberSymBytes , KyberSymBytes );
278256
279257 symmetric .hash_g (kr , buf );
280258
281- byte [] cmp = indCpa .encrypt (Arrays .copyOfRange (buf , 0 , KyberSymBytes ), publicKey , Arrays .copyOfRange (kr , KyberSymBytes , kr .length ));
259+ byte [] cmp = indCpa .encrypt (publicKey , Arrays .copyOfRange (buf , 0 , KyberSymBytes ), Arrays .copyOfRange (kr , KyberSymBytes , kr .length ));
282260
283261 boolean fail = !(Arrays .constantTimeAreEqual (cipherText , cmp ));
284262
@@ -289,6 +267,42 @@ public byte[] kemDecrypt(byte[] cipherText, byte[] secretKey)
289267 return Arrays .copyOfRange (kr , 0 , sessionKeyLength );
290268 }
291269
270+ public byte [][] generateKemKeyPair ()
271+ {
272+ byte [] d = new byte [KyberSymBytes ];
273+ byte [] z = new byte [KyberSymBytes ];
274+ random .nextBytes (d );
275+ random .nextBytes (z );
276+
277+ return generateKemKeyPairInternal (d , z );
278+ }
279+
280+ public byte [][] kemEncrypt (byte [] publicKeyInput , byte [] randBytes )
281+ {
282+ //TODO: do input validation elsewhere?
283+ // Input validation (6.2 ML-KEM Encaps)
284+ // Type Check
285+ if (publicKeyInput .length != KyberIndCpaPublicKeyBytes )
286+ {
287+ throw new IllegalArgumentException ("Input validation Error: Type check failed for ml-kem encapsulation" );
288+ }
289+ // Modulus Check
290+ PolyVec polyVec = new PolyVec (this );
291+ byte [] seed = indCpa .unpackPublicKey (polyVec , publicKeyInput );
292+ byte [] ek = indCpa .packPublicKey (polyVec , seed );
293+ if (!Arrays .areEqual (ek , publicKeyInput ))
294+ {
295+ throw new IllegalArgumentException ("Input validation: Modulus check failed for ml-kem encapsulation" );
296+ }
297+
298+ return kemEncryptInternal (publicKeyInput , randBytes );
299+ }
300+ public byte [] kemDecrypt (byte [] secretKey , byte [] cipherText )
301+ {
302+ //TODO: do input validation
303+ return kemDecryptInternal (secretKey , cipherText );
304+ }
305+
292306 private void cmov (byte [] r , byte [] x , int xlen , boolean b )
293307 {
294308 if (b )
0 commit comments