Skip to content

Commit cc4a451

Browse files
committed
RatchetTest: more mods to test PQ, WIP
1 parent e4ed9e3 commit cc4a451

File tree

1 file changed

+94
-11
lines changed

1 file changed

+94
-11
lines changed

java-utils/RatchetTest.java

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import net.i2p.crypto.EncType;
3737
import net.i2p.crypto.HKDF;
3838
import net.i2p.crypto.HMAC256Generator;
39+
import net.i2p.crypto.KeyFactory;
3940
import net.i2p.crypto.KeyPair;
4041
import net.i2p.crypto.SHA256Generator;
4142
import net.i2p.crypto.SigType;
@@ -62,6 +63,7 @@
6263
import net.i2p.data.i2np.I2NPMessageHandler;
6364
import net.i2p.data.i2np.UnknownI2NPMessage;
6465
import net.i2p.router.RouterContext;
66+
import net.i2p.router.crypto.pqc.MLKEM;
6567
import static net.i2p.router.crypto.ratchet.RatchetPayload.*;
6668
import net.i2p.router.transport.crypto.X25519KeyFactory;
6769
import net.i2p.util.HexDump;
@@ -130,6 +132,7 @@ public class RatchetTest implements RatchetPayload.PayloadCallback {
130132

131133
private final X25519KeyFactory kf;
132134
private final Elg2KeyFactory ekf;
135+
private final PQKeyFactory hkf;
133136
private final Elligator2 elg2;
134137
private final HKDF hkdf;
135138

@@ -156,12 +159,15 @@ public RatchetTest(RouterContext ctx, String host, int port, boolean isAlice, bo
156159
_pkFile = pkfile;
157160
_type = type;
158161

159-
if (type == EncType.ECIES_X25519)
162+
if (type == EncType.ECIES_X25519) {
160163
kf = new X25519KeyFactory(_context);
161-
else if (useNoise)
164+
hkf = null;
165+
} else if (useNoise) {
162166
kf = new TypeKeyFactory(_context, type);
163-
else
167+
hkf = new PQKeyFactory(type);
168+
} else {
164169
throw new IllegalArgumentException("PQ types require -n");
170+
}
165171
kf.start();
166172
ekf = new Elg2KeyFactory(_context);
167173
ekf.start();
@@ -192,6 +198,36 @@ public KeyPair getKeys() {
192198
}
193199
}
194200

201+
/**
202+
* Make type 5-7 hybrid keys since we don't have a key factory yet
203+
*/
204+
private static class PQKeyFactory implements KeyFactory {
205+
private final EncType etype;
206+
public PQKeyFactory(EncType type) {
207+
switch(type) {
208+
case MLKEM512_X25519:
209+
type = EncType.MLKEM512_X25519_INT;
210+
break;
211+
case MLKEM768_X25519:
212+
type = EncType.MLKEM768_X25519_INT;
213+
break;
214+
case MLKEM1024_X25519:
215+
type = EncType.MLKEM1024_X25519_INT;
216+
break;
217+
default:
218+
throw new IllegalArgumentException("Unsupported type " + type);
219+
}
220+
etype = type;
221+
}
222+
public KeyPair getKeys() {
223+
try {
224+
return MLKEM.getKeys(etype);
225+
} catch (GeneralSecurityException gse) {
226+
throw new IllegalStateException(gse);
227+
}
228+
}
229+
}
230+
195231
private static String getNoisePattern(EncType type) {
196232
switch(type) {
197233
case ECIES_X25519:
@@ -252,7 +288,7 @@ private void connect() throws Exception {
252288
byte[] epriv = null;
253289
HandshakeState state = null;
254290
if (_useNoise) {
255-
state = new HandshakeState(getNoisePattern(_type), HandshakeState.INITIATOR, ekf);
291+
state = new HandshakeState(getNoisePattern(_type), HandshakeState.INITIATOR, ekf, hkf);
256292
state.getRemotePublicKey().setPublicKey(s, 0);
257293
state.getLocalKeyPair().setKeys(priv, 0, as, 0);
258294
System.out.println("Before start");
@@ -342,7 +378,19 @@ private void connect() throws Exception {
342378

343379
if (_useNoise) {
344380
// encrypt X and write X and the options block
345-
tmp = new byte[32 + 32 + 16 + payloadlen + 16];
381+
int tmplen = 32 + 32 + 16 + payloadlen + 16;
382+
switch(_type) {
383+
case MLKEM512_X25519:
384+
tmplen += EncType.MLKEM512_X25519_INT.getPubkeyLen() + 16;
385+
break;
386+
case MLKEM768_X25519:
387+
tmplen += EncType.MLKEM768_X25519_INT.getPubkeyLen() + 16;
388+
break;
389+
case MLKEM1024_X25519:
390+
tmplen += EncType.MLKEM1024_X25519_INT.getPubkeyLen() + 16;
391+
break;
392+
}
393+
tmp = new byte[tmplen];
346394
state.writeMessage(tmp, 0, payload, 0, payload.length);
347395
System.out.println(state.toString());
348396
// overwrite eph. key with encoded key
@@ -424,8 +472,20 @@ private void connect() throws Exception {
424472
// System.out.println("Tag NOT FOUND in expected tagset");
425473

426474

427-
tmp = new byte[48];
428-
System.arraycopy(itmp, 8, tmp, 0, 48);
475+
int tmplen = 48;
476+
switch(_type) {
477+
case MLKEM512_X25519:
478+
tmplen += EncType.MLKEM512_X25519_CT.getPubkeyLen() + 16;
479+
break;
480+
case MLKEM768_X25519:
481+
tmplen += EncType.MLKEM768_X25519_CT.getPubkeyLen() + 16;
482+
break;
483+
case MLKEM1024_X25519:
484+
tmplen += EncType.MLKEM1024_X25519_CT.getPubkeyLen() + 16;
485+
break;
486+
}
487+
tmp = new byte[tmplen];
488+
System.arraycopy(itmp, 8, tmp, 0, tmplen);
429489
System.out.println("Got msg 2 frame part 1");
430490
System.out.println(HexDump.dump(tmp));
431491

@@ -443,13 +503,13 @@ private void connect() throws Exception {
443503
state.mixHash(tag, 0, TAGLEN);
444504
System.out.println(state.toString());
445505
try {
446-
state.readMessage(tmp, 0, 48, ZEROLEN, 0);
506+
state.readMessage(tmp, 0, tmplen, ZEROLEN, 0);
447507
} catch (GeneralSecurityException gse) {
448508
System.out.println("**************\nState at failure:");
449509
System.out.println(state.toString());
450510
throw new IOException("Bad AEAD msg 2", gse);
451511
}
452-
System.out.println("After Message 1");
512+
System.out.println("After Message 2");
453513
System.out.println(state.toString());
454514
} else {
455515
// KDF 2
@@ -758,7 +818,7 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
758818

759819
HandshakeState state = null;
760820
if (_useNoise) {
761-
state = new HandshakeState(getNoisePattern(_type), HandshakeState.RESPONDER, ekf);
821+
state = new HandshakeState(getNoisePattern(_type), HandshakeState.RESPONDER, ekf, hkf);
762822
state.getLocalKeyPair().setKeys(priv, 0, s, 0);
763823
System.out.println("Before start");
764824
System.out.println(state.toString());
@@ -784,6 +844,17 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
784844
int len = tmp.length;
785845
System.out.println(HexDump.dump(tmp));
786846
int payloadlen = len - (32 + 32 + 16 + 16);
847+
switch(_type) {
848+
case MLKEM512_X25519:
849+
payloadlen -= EncType.MLKEM512_X25519_INT.getPubkeyLen() + 16;
850+
break;
851+
case MLKEM768_X25519:
852+
payloadlen -= EncType.MLKEM768_X25519_INT.getPubkeyLen() + 16;
853+
break;
854+
case MLKEM1024_X25519:
855+
payloadlen -= EncType.MLKEM1024_X25519_INT.getPubkeyLen() + 16;
856+
break;
857+
}
787858
byte[] payload = new byte[payloadlen];
788859

789860
byte[] tmp2 = new byte[32];
@@ -906,7 +977,19 @@ private void runConnection(Socket socket, byte[] s, byte[] inithash, byte[] priv
906977

907978
if (_useNoise) {
908979
state.mixHash(tag, 0, TAGLEN);
909-
tmp = new byte[32 + 16]; // 48
980+
int tmplen = 32 + 16; // 48
981+
switch(_type) {
982+
case MLKEM512_X25519:
983+
tmplen += EncType.MLKEM512_X25519_CT.getPubkeyLen() + 16;
984+
break;
985+
case MLKEM768_X25519:
986+
tmplen += EncType.MLKEM768_X25519_CT.getPubkeyLen() + 16;
987+
break;
988+
case MLKEM1024_X25519:
989+
tmplen += EncType.MLKEM1024_X25519_CT.getPubkeyLen() + 16;
990+
break;
991+
}
992+
tmp = new byte[tmplen];
910993

911994
System.out.println(state.toString());
912995
state.writeMessage(tmp, 0, ZEROLEN, 0, 0);

0 commit comments

Comments
 (0)