3636import net .i2p .crypto .EncType ;
3737import net .i2p .crypto .HKDF ;
3838import net .i2p .crypto .HMAC256Generator ;
39+ import net .i2p .crypto .KeyFactory ;
3940import net .i2p .crypto .KeyPair ;
4041import net .i2p .crypto .SHA256Generator ;
4142import net .i2p .crypto .SigType ;
6263import net .i2p .data .i2np .I2NPMessageHandler ;
6364import net .i2p .data .i2np .UnknownI2NPMessage ;
6465import net .i2p .router .RouterContext ;
66+ import net .i2p .router .crypto .pqc .MLKEM ;
6567import static net .i2p .router .crypto .ratchet .RatchetPayload .*;
6668import net .i2p .router .transport .crypto .X25519KeyFactory ;
6769import net .i2p .util .HexDump ;
@@ -130,6 +132,7 @@ public class RatchetTest implements RatchetPayload.PayloadCallback {
130132
131133 private final X25519KeyFactory kf ;
132134 private final Elg2KeyFactory ekf ;
135+ private final PQKeyFactory hkf ;
133136 private final Elligator2 elg2 ;
134137 private final HKDF hkdf ;
135138
@@ -156,12 +159,15 @@ public RatchetTest(RouterContext ctx, String host, int port, boolean isAlice, bo
156159 _pkFile = pkfile ;
157160 _type = type ;
158161
159- if (type == EncType .ECIES_X25519 )
162+ if (type == EncType .ECIES_X25519 ) {
160163 kf = new X25519KeyFactory (_context );
161- else if (useNoise )
164+ hkf = null ;
165+ } else if (useNoise ) {
162166 kf = new TypeKeyFactory (_context , type );
163- else
167+ hkf = new PQKeyFactory (type );
168+ } else {
164169 throw new IllegalArgumentException ("PQ types require -n" );
170+ }
165171 kf .start ();
166172 ekf = new Elg2KeyFactory (_context );
167173 ekf .start ();
@@ -192,6 +198,36 @@ public KeyPair getKeys() {
192198 }
193199 }
194200
201+ /**
202+ * Make type 5-7 hybrid keys since we don't have a key factory yet
203+ */
204+ private static class PQKeyFactory implements KeyFactory {
205+ private final EncType etype ;
206+ public PQKeyFactory (EncType type ) {
207+ switch (type ) {
208+ case MLKEM512_X25519 :
209+ type = EncType .MLKEM512_X25519_INT ;
210+ break ;
211+ case MLKEM768_X25519 :
212+ type = EncType .MLKEM768_X25519_INT ;
213+ break ;
214+ case MLKEM1024_X25519 :
215+ type = EncType .MLKEM1024_X25519_INT ;
216+ break ;
217+ default :
218+ throw new IllegalArgumentException ("Unsupported type " + type );
219+ }
220+ etype = type ;
221+ }
222+ public KeyPair getKeys () {
223+ try {
224+ return MLKEM .getKeys (etype );
225+ } catch (GeneralSecurityException gse ) {
226+ throw new IllegalStateException (gse );
227+ }
228+ }
229+ }
230+
195231 private static String getNoisePattern (EncType type ) {
196232 switch (type ) {
197233 case ECIES_X25519 :
@@ -252,7 +288,7 @@ private void connect() throws Exception {
252288 byte [] epriv = null ;
253289 HandshakeState state = null ;
254290 if (_useNoise ) {
255- state = new HandshakeState (getNoisePattern (_type ), HandshakeState .INITIATOR , ekf );
291+ state = new HandshakeState (getNoisePattern (_type ), HandshakeState .INITIATOR , ekf , hkf );
256292 state .getRemotePublicKey ().setPublicKey (s , 0 );
257293 state .getLocalKeyPair ().setKeys (priv , 0 , as , 0 );
258294 System .out .println ("Before start" );
@@ -342,7 +378,19 @@ private void connect() throws Exception {
342378
343379 if (_useNoise ) {
344380 // encrypt X and write X and the options block
345- tmp = new byte [32 + 32 + 16 + payloadlen + 16 ];
381+ int tmplen = 32 + 32 + 16 + payloadlen + 16 ;
382+ switch (_type ) {
383+ case MLKEM512_X25519 :
384+ tmplen += EncType .MLKEM512_X25519_INT .getPubkeyLen () + 16 ;
385+ break ;
386+ case MLKEM768_X25519 :
387+ tmplen += EncType .MLKEM768_X25519_INT .getPubkeyLen () + 16 ;
388+ break ;
389+ case MLKEM1024_X25519 :
390+ tmplen += EncType .MLKEM1024_X25519_INT .getPubkeyLen () + 16 ;
391+ break ;
392+ }
393+ tmp = new byte [tmplen ];
346394 state .writeMessage (tmp , 0 , payload , 0 , payload .length );
347395 System .out .println (state .toString ());
348396 // overwrite eph. key with encoded key
@@ -424,8 +472,20 @@ private void connect() throws Exception {
424472 // System.out.println("Tag NOT FOUND in expected tagset");
425473
426474
427- tmp = new byte [48 ];
428- System .arraycopy (itmp , 8 , tmp , 0 , 48 );
475+ int tmplen = 48 ;
476+ switch (_type ) {
477+ case MLKEM512_X25519 :
478+ tmplen += EncType .MLKEM512_X25519_CT .getPubkeyLen () + 16 ;
479+ break ;
480+ case MLKEM768_X25519 :
481+ tmplen += EncType .MLKEM768_X25519_CT .getPubkeyLen () + 16 ;
482+ break ;
483+ case MLKEM1024_X25519 :
484+ tmplen += EncType .MLKEM1024_X25519_CT .getPubkeyLen () + 16 ;
485+ break ;
486+ }
487+ tmp = new byte [tmplen ];
488+ System .arraycopy (itmp , 8 , tmp , 0 , tmplen );
429489 System .out .println ("Got msg 2 frame part 1" );
430490 System .out .println (HexDump .dump (tmp ));
431491
@@ -443,13 +503,13 @@ private void connect() throws Exception {
443503 state .mixHash (tag , 0 , TAGLEN );
444504 System .out .println (state .toString ());
445505 try {
446- state .readMessage (tmp , 0 , 48 , ZEROLEN , 0 );
506+ state .readMessage (tmp , 0 , tmplen , ZEROLEN , 0 );
447507 } catch (GeneralSecurityException gse ) {
448508 System .out .println ("**************\n State at failure:" );
449509 System .out .println (state .toString ());
450510 throw new IOException ("Bad AEAD msg 2" , gse );
451511 }
452- System .out .println ("After Message 1 " );
512+ System .out .println ("After Message 2 " );
453513 System .out .println (state .toString ());
454514 } else {
455515 // KDF 2
@@ -758,7 +818,7 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
758818
759819 HandshakeState state = null ;
760820 if (_useNoise ) {
761- state = new HandshakeState (getNoisePattern (_type ), HandshakeState .RESPONDER , ekf );
821+ state = new HandshakeState (getNoisePattern (_type ), HandshakeState .RESPONDER , ekf , hkf );
762822 state .getLocalKeyPair ().setKeys (priv , 0 , s , 0 );
763823 System .out .println ("Before start" );
764824 System .out .println (state .toString ());
@@ -784,6 +844,17 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
784844 int len = tmp .length ;
785845 System .out .println (HexDump .dump (tmp ));
786846 int payloadlen = len - (32 + 32 + 16 + 16 );
847+ switch (_type ) {
848+ case MLKEM512_X25519 :
849+ payloadlen -= EncType .MLKEM512_X25519_INT .getPubkeyLen () + 16 ;
850+ break ;
851+ case MLKEM768_X25519 :
852+ payloadlen -= EncType .MLKEM768_X25519_INT .getPubkeyLen () + 16 ;
853+ break ;
854+ case MLKEM1024_X25519 :
855+ payloadlen -= EncType .MLKEM1024_X25519_INT .getPubkeyLen () + 16 ;
856+ break ;
857+ }
787858 byte [] payload = new byte [payloadlen ];
788859
789860 byte [] tmp2 = new byte [32 ];
@@ -906,7 +977,19 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
906977
907978 if (_useNoise ) {
908979 state .mixHash (tag , 0 , TAGLEN );
909- tmp = new byte [32 + 16 ]; // 48
980+ int tmplen = 32 + 16 ; // 48
981+ switch (_type ) {
982+ case MLKEM512_X25519 :
983+ tmplen += EncType .MLKEM512_X25519_CT .getPubkeyLen () + 16 ;
984+ break ;
985+ case MLKEM768_X25519 :
986+ tmplen += EncType .MLKEM768_X25519_CT .getPubkeyLen () + 16 ;
987+ break ;
988+ case MLKEM1024_X25519 :
989+ tmplen += EncType .MLKEM1024_X25519_CT .getPubkeyLen () + 16 ;
990+ break ;
991+ }
992+ tmp = new byte [tmplen ];
910993
911994 System .out .println (state .toString ());
912995 state .writeMessage (tmp , 0 , ZEROLEN , 0 , 0 );
0 commit comments