Skip to content

Commit 2551f8e

Browse files
Add Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT (#811)
* Added Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT - Updated SSHClient.onConnect() to check isKeyExchangeRequired() before calling doKex() - Added started timestamp in ThreadNameProvider for improved tracking * Moved KeepAliveThread State check after authentication to avoid test timing issues
1 parent 430cbfc commit 2551f8e

File tree

5 files changed

+27
-3
lines changed

5 files changed

+27
-3
lines changed

src/main/java/com/hierynomus/sshj/common/ThreadNameProvider.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public class ThreadNameProvider {
2929
public static void setThreadName(final Thread thread, final RemoteAddressProvider remoteAddressProvider) {
3030
final InetSocketAddress remoteSocketAddress = remoteAddressProvider.getRemoteSocketAddress();
3131
final String address = remoteSocketAddress == null ? DISCONNECTED : remoteSocketAddress.toString();
32-
final String threadName = String.format("sshj-%s-%s", thread.getClass().getSimpleName(), address);
32+
final long started = System.currentTimeMillis();
33+
final String threadName = String.format("sshj-%s-%s-%d", thread.getClass().getSimpleName(), address, started);
3334
thread.setName(threadName);
3435
}
3536
}

src/main/java/net/schmizz/sshj/SSHClient.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,12 @@ protected void onConnect()
810810
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
811811
keepAliveThread.start();
812812
}
813-
doKex();
813+
if (trans.isKeyExchangeRequired()) {
814+
log.debug("Initiating Key Exchange for new connection");
815+
doKex();
816+
} else {
817+
log.debug("Key Exchange already completed for new connection");
818+
}
814819
}
815820

816821
/**

src/main/java/net/schmizz/sshj/transport/Transport.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ void init(String host, int port, InputStream in, OutputStream out)
7171
void doKex()
7272
throws TransportException;
7373

74+
/**
75+
* Is Key Exchange required based on current transport status
76+
*
77+
* @return Key Exchange required status
78+
*/
79+
boolean isKeyExchangeRequired();
80+
7481
/** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */
7582
String getClientVersion();
7683

src/main/java/net/schmizz/sshj/transport/TransportImpl.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@ public void doKex()
254254
kexer.startKex(true);
255255
}
256256

257+
/**
258+
* Is Key Exchange required returns true when Key Exchange is not done and when Key Exchange is not ongoing
259+
*
260+
* @return Key Exchange required status
261+
*/
262+
@Override
263+
public boolean isKeyExchangeRequired() {
264+
return !kexer.isKexDone() && !kexer.isKexOngoing();
265+
}
266+
257267
public boolean isKexDone() {
258268
return kexer.isKexDone();
259269
}

src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ public void shouldStartThreadOnConnectAndInterruptOnDisconnect() throws IOExcept
5959
assertEquals(Thread.State.NEW, keepAlive.getState());
6060

6161
fixture.connectClient(sshClient);
62-
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());
6362

6463
assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials"));
6564

65+
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());
66+
6667
fixture.stopClient();
6768
Thread.sleep(STOP_SLEEP);
6869

0 commit comments

Comments
 (0)