Skip to content

Commit 90099bb

Browse files
Updated SSHClient to interrupt KeepAlive Thread when disconnecting (#506) (#752)
- Changed KeepAlive.setKeepAliveInterval() to avoid starting Thread - Updated SSHClient.onConnect() to start KeepAlive Thread when enabled - Updated SSHClient.disconnect() to interrupt KeepAlive Thread - Updated KeepAliveThreadTerminationTest to verify state of KeepAlive Thread Co-authored-by: Jeroen van Erp <[email protected]>
1 parent ce0a7d5 commit 90099bb

File tree

3 files changed

+61
-51
lines changed

3 files changed

+61
-51
lines changed

src/main/java/net/schmizz/keepalive/KeepAlive.java

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import net.schmizz.sshj.transport.TransportException;
2121
import org.slf4j.Logger;
2222

23+
import java.util.concurrent.TimeUnit;
24+
2325
public abstract class KeepAlive extends Thread {
2426
protected final Logger log;
2527
protected final ConnectionImpl conn;
@@ -33,50 +35,40 @@ protected KeepAlive(ConnectionImpl conn, String name) {
3335
setDaemon(true);
3436
}
3537

38+
public boolean isEnabled() {
39+
return keepAliveInterval > 0;
40+
}
41+
3642
public synchronized int getKeepAliveInterval() {
3743
return keepAliveInterval;
3844
}
3945

4046
public synchronized void setKeepAliveInterval(int keepAliveInterval) {
4147
this.keepAliveInterval = keepAliveInterval;
42-
if (keepAliveInterval > 0 && getState() == State.NEW) {
43-
start();
44-
}
45-
notify();
46-
}
47-
48-
synchronized protected int getPositiveInterval()
49-
throws InterruptedException {
50-
while (keepAliveInterval <= 0) {
51-
wait();
52-
}
53-
return keepAliveInterval;
5448
}
5549

5650
@Override
5751
public void run() {
58-
log.debug("Starting {}, sending keep-alive every {} seconds", getClass().getSimpleName(), keepAliveInterval);
52+
log.debug("{} Started with interval [{} seconds]", getClass().getSimpleName(), keepAliveInterval);
5953
try {
6054
while (!isInterrupted()) {
61-
final int hi = getPositiveInterval();
55+
final int interval = getKeepAliveInterval();
6256
if (conn.getTransport().isRunning()) {
63-
log.debug("Sending keep-alive since {} seconds elapsed", hi);
57+
log.debug("{} Sending after interval [{} seconds]", getClass().getSimpleName(), interval);
6458
doKeepAlive();
6559
}
66-
Thread.sleep(hi * 1000);
60+
TimeUnit.SECONDS.sleep(interval);
6761
}
6862
} catch (InterruptedException e) {
69-
// Interrupt signal may be catched when sleeping.
63+
log.trace("{} Interrupted while sleeping", getClass().getSimpleName());
7064
} catch (Exception e) {
7165
// If we weren't interrupted, kill the transport, then this exception was unexpected.
7266
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
7367
if (!isInterrupted()) {
7468
conn.getTransport().die(e);
7569
}
7670
}
77-
78-
log.debug("Stopping {}", getClass().getSimpleName());
79-
71+
log.debug("{} Stopped", getClass().getSimpleName());
8072
}
8173

8274
protected abstract void doKeepAlive() throws TransportException, ConnectionException;

src/main/java/net/schmizz/sshj/SSHClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package net.schmizz.sshj;
1717

18+
import net.schmizz.keepalive.KeepAlive;
1819
import net.schmizz.sshj.common.*;
1920
import net.schmizz.sshj.connection.Connection;
2021
import net.schmizz.sshj.connection.ConnectionException;
@@ -424,6 +425,7 @@ public void authGssApiWithMic(String username, LoginContext context, Oid support
424425
@Override
425426
public void disconnect()
426427
throws IOException {
428+
conn.getKeepAlive().interrupt();
427429
for (LocalPortForwarder forwarder : forwarders) {
428430
try {
429431
forwarder.close();
@@ -791,6 +793,10 @@ protected void onConnect()
791793
throws IOException {
792794
super.onConnect();
793795
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
796+
final KeepAlive keepAliveThread = conn.getKeepAlive();
797+
if (keepAliveThread.isEnabled()) {
798+
keepAliveThread.start();
799+
}
794800
doKex();
795801
}
796802

src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,54 +15,66 @@
1515
*/
1616
package com.hierynomus.sshj.keepalive;
1717

18-
import com.hierynomus.sshj.test.KnownFailingTests;
19-
import com.hierynomus.sshj.test.SlowTests;
2018
import com.hierynomus.sshj.test.SshFixture;
19+
import net.schmizz.keepalive.KeepAlive;
2120
import net.schmizz.keepalive.KeepAliveProvider;
2221
import net.schmizz.sshj.DefaultConfig;
2322
import net.schmizz.sshj.SSHClient;
2423
import net.schmizz.sshj.userauth.UserAuthException;
2524
import org.junit.Rule;
2625
import org.junit.Test;
27-
import org.junit.experimental.categories.Category;
2826

2927
import java.io.IOException;
30-
import java.lang.management.ManagementFactory;
31-
import java.lang.management.ThreadInfo;
32-
import java.lang.management.ThreadMXBean;
3328

34-
import static org.junit.Assert.fail;
29+
import static org.junit.Assert.assertEquals;
30+
import static org.junit.Assert.assertFalse;
31+
import static org.junit.Assert.assertThrows;
32+
import static org.junit.Assert.assertTrue;
3533

3634
public class KeepAliveThreadTerminationTest {
3735

36+
private static final int KEEP_ALIVE_SECONDS = 1;
37+
38+
private static final long STOP_SLEEP = 1500;
39+
3840
@Rule
3941
public SshFixture fixture = new SshFixture();
4042

4143
@Test
42-
@Category({SlowTests.class, KnownFailingTests.class})
43-
public void shouldCorrectlyTerminateThreadOnDisconnect() throws IOException, InterruptedException {
44-
DefaultConfig defaultConfig = new DefaultConfig();
44+
public void shouldNotStartThreadOnSetKeepAliveInterval() {
45+
final SSHClient sshClient = setupClient();
46+
47+
final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive();
48+
assertTrue(keepAlive.isDaemon());
49+
assertFalse(keepAlive.isAlive());
50+
assertEquals(Thread.State.NEW, keepAlive.getState());
51+
}
52+
53+
@Test
54+
public void shouldStartThreadOnConnectAndInterruptOnDisconnect() throws IOException, InterruptedException {
55+
final SSHClient sshClient = setupClient();
56+
57+
final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive();
58+
assertTrue(keepAlive.isDaemon());
59+
assertEquals(Thread.State.NEW, keepAlive.getState());
60+
61+
fixture.connectClient(sshClient);
62+
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());
63+
64+
assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials"));
65+
66+
fixture.stopClient();
67+
Thread.sleep(STOP_SLEEP);
68+
69+
assertFalse(keepAlive.isAlive());
70+
assertEquals(Thread.State.TERMINATED, keepAlive.getState());
71+
}
72+
73+
private SSHClient setupClient() {
74+
final DefaultConfig defaultConfig = new DefaultConfig();
4575
defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
46-
for (int i = 0; i < 10; i++) {
47-
SSHClient sshClient = fixture.setupClient(defaultConfig);
48-
fixture.connectClient(sshClient);
49-
sshClient.getConnection().getKeepAlive().setKeepAliveInterval(1);
50-
try {
51-
sshClient.authPassword("bad", "credentials");
52-
fail("Should not auth.");
53-
} catch (UserAuthException e) {
54-
// OK
55-
}
56-
fixture.stopClient();
57-
Thread.sleep(2000);
58-
}
59-
60-
ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
61-
for (long l : threadMXBean.getAllThreadIds()) {
62-
ThreadInfo threadInfo = threadMXBean.getThreadInfo(l);
63-
if (threadInfo.getThreadName().equals("keep-alive") && threadInfo.getThreadState() != Thread.State.TERMINATED) {
64-
fail("Found alive keep-alive thread in state " + threadInfo.getThreadState());
65-
}
66-
}
76+
final SSHClient sshClient = fixture.setupClient(defaultConfig);
77+
sshClient.getConnection().getKeepAlive().setKeepAliveInterval(KEEP_ALIVE_SECONDS);
78+
return sshClient;
6779
}
6880
}

0 commit comments

Comments
 (0)