Skip to content

Commit 60688ee

Browse files
committed
add BoundedKeepAliveProvider (#986)
1 parent 7f8f43c commit 60688ee

File tree

5 files changed

+314
-1
lines changed

5 files changed

+314
-1
lines changed
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package net.schmizz.keepalive;
2+
3+
import net.schmizz.sshj.Config;
4+
import net.schmizz.sshj.connection.ConnectionException;
5+
import net.schmizz.sshj.connection.ConnectionImpl;
6+
import net.schmizz.sshj.transport.TransportException;
7+
import org.slf4j.Logger;
8+
9+
import java.util.ArrayList;
10+
import java.util.Comparator;
11+
import java.util.List;
12+
import java.util.concurrent.PriorityBlockingQueue;
13+
import java.util.concurrent.atomic.AtomicInteger;
14+
import java.util.concurrent.locks.Condition;
15+
import java.util.concurrent.locks.ReentrantLock;
16+
17+
/**
18+
* This implementation manages all {@link KeepAlive}s using configured number of threads. It works like a
19+
* thread pool, thus {@link BoundedKeepAliveProvider#shutdown()} must be called to clean up resources.
20+
* <br>
21+
* This provider uses {@link KeepAliveRunner#doKeepAlive()} as delegate, so it supports maxKeepAliveCount
22+
* parameter. All instances provided by this provider have identical configuration.
23+
*/
24+
public class BoundedKeepAliveProvider extends KeepAliveProvider {
25+
26+
public int maxKeepAliveCount = 3;
27+
public int keepAliveInterval = 5;
28+
29+
protected final KeepAliveMonitor monitor;
30+
31+
32+
public BoundedKeepAliveProvider(Config config, int numberOfThreads) {
33+
this.monitor = new KeepAliveMonitor(config, numberOfThreads);
34+
}
35+
36+
public void setKeepAliveInterval(int interval) {
37+
keepAliveInterval = interval;
38+
}
39+
40+
public void setMaxKeepAliveCount(int count) {
41+
maxKeepAliveCount = count;
42+
}
43+
44+
@Override
45+
public KeepAlive provide(ConnectionImpl connection) {
46+
return new Impl(connection, "bounded-keepalive-impl");
47+
}
48+
49+
public void shutdown() throws InterruptedException {
50+
monitor.shutdown();
51+
}
52+
53+
class Impl extends KeepAlive {
54+
55+
private final KeepAliveRunner delegate;
56+
57+
protected Impl(ConnectionImpl conn, String name) {
58+
super(conn, name);
59+
this.delegate = new KeepAliveRunner(conn);
60+
61+
// take care here, some parameters are set to both delegate and this
62+
this.delegate.setMaxAliveCount(BoundedKeepAliveProvider.this.maxKeepAliveCount);
63+
super.keepAliveInterval = BoundedKeepAliveProvider.this.keepAliveInterval;
64+
}
65+
66+
@Override
67+
protected void doKeepAlive() throws TransportException, ConnectionException {
68+
delegate.doKeepAlive();
69+
}
70+
71+
@Override
72+
public void startKeepAlive() {
73+
monitor.register(this);
74+
}
75+
76+
}
77+
78+
protected static class KeepAliveMonitor {
79+
80+
private final int numberOfThreads;
81+
private final PriorityBlockingQueue<Wrapper> Q =
82+
new PriorityBlockingQueue<>(32, Comparator.comparingLong(w -> w.nextTimeMillis));
83+
private long idleSleepMillis = 100;
84+
private static final List<Thread> workerThreads = new ArrayList<>();
85+
volatile boolean started = false;
86+
private final Logger logger;
87+
88+
private final ReentrantLock lock = new ReentrantLock();
89+
private final Condition shutDown = lock.newCondition();
90+
private final AtomicInteger shutDownCnt = new AtomicInteger(0);
91+
92+
public KeepAliveMonitor(Config config, int numberOfThreads) {
93+
this.numberOfThreads = numberOfThreads;
94+
logger = config.getLoggerFactory().getLogger(KeepAliveMonitor.class);
95+
}
96+
97+
// made public for test
98+
public void register(KeepAlive keepAlive) {
99+
if (!started) {
100+
start();
101+
}
102+
Q.add(new Wrapper(keepAlive));
103+
}
104+
105+
public void setIdleSleepMillis(long idleSleepMillis) {
106+
this.idleSleepMillis = idleSleepMillis;
107+
}
108+
109+
void unregister(KeepAlive keepAlive) {
110+
Q.removeIf(w -> keepAlive == w.keepAlive);
111+
}
112+
113+
private void sleep() {
114+
sleep(idleSleepMillis);
115+
}
116+
117+
private void sleep(long millis) {
118+
try {
119+
Thread.sleep(millis);
120+
} catch (InterruptedException e) {
121+
Thread.currentThread().interrupt();
122+
}
123+
}
124+
125+
private synchronized void start() {
126+
if (started) {
127+
return;
128+
}
129+
130+
for (int i = 0; i < numberOfThreads; i++) {
131+
Thread t = new Thread(this::doStart);
132+
workerThreads.add(t);
133+
}
134+
workerThreads.forEach(Thread::start);
135+
started = true;
136+
}
137+
138+
139+
private void doStart() {
140+
while (!Thread.currentThread().isInterrupted()) {
141+
Wrapper wrapper;
142+
143+
if (Q.isEmpty() || (wrapper = Q.poll()) == null) {
144+
sleep();
145+
continue;
146+
}
147+
148+
long currentTimeMillis = System.currentTimeMillis();
149+
if (wrapper.nextTimeMillis > currentTimeMillis) {
150+
long sleepMillis = wrapper.nextTimeMillis - currentTimeMillis;
151+
logger.debug("{} millis until next check, sleep", sleepMillis);
152+
sleep(sleepMillis);
153+
}
154+
155+
try {
156+
wrapper.keepAlive.doKeepAlive();
157+
Q.add(wrapper.reschedule());
158+
} catch (Exception e) {
159+
// If we weren't interrupted, kill the transport, then this exception was unexpected.
160+
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
161+
if (!Thread.currentThread().isInterrupted()) {
162+
wrapper.keepAlive.conn.getTransport().die(e);
163+
}
164+
}
165+
}
166+
lock.lock();
167+
try {
168+
if (shutDownCnt.incrementAndGet() == numberOfThreads) {
169+
shutDown.signal();
170+
}
171+
} finally {
172+
lock.unlock();
173+
}
174+
}
175+
176+
private synchronized void shutdown() throws InterruptedException {
177+
if (workerThreads.isEmpty()) {
178+
return;
179+
}
180+
for (Thread t : workerThreads) {
181+
t.interrupt();
182+
}
183+
lock.lock();
184+
logger.info("waiting for all {} threads to finish", numberOfThreads);
185+
shutDown.await();
186+
}
187+
188+
private static class Wrapper {
189+
private final KeepAlive keepAlive;
190+
private final long nextTimeMillis;
191+
192+
private Wrapper(KeepAlive keepAlive) {
193+
this.keepAlive = keepAlive;
194+
this.nextTimeMillis = System.currentTimeMillis() + keepAlive.keepAliveInterval * 1000L;
195+
}
196+
197+
private Wrapper reschedule() {
198+
return new Wrapper(keepAlive);
199+
}
200+
}
201+
}
202+
}

src/main/java/net/schmizz/keepalive/KeepAlive.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,11 @@ public void run() {
8989
}
9090

9191
protected abstract void doKeepAlive() throws TransportException, ConnectionException;
92+
93+
/**
94+
* Start keep-alive loop. Implementations MUST NOT block current thread.
95+
*/
96+
public void startKeepAlive() {
97+
start();
98+
}
9299
}

src/main/java/net/schmizz/sshj/SSHClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ protected void onConnect()
808808
final KeepAlive keepAliveThread = conn.getKeepAlive();
809809
if (keepAliveThread.isEnabled()) {
810810
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
811-
keepAliveThread.start();
811+
keepAliveThread.startKeepAlive();
812812
}
813813
}
814814

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package com.hierynomus.sshj.keepalive;
2+
3+
import com.hierynomus.sshj.test.SshServerExtension;
4+
import net.schmizz.keepalive.BoundedKeepAliveProvider;
5+
import net.schmizz.keepalive.KeepAlive;
6+
import net.schmizz.sshj.DefaultConfig;
7+
import net.schmizz.sshj.SSHClient;
8+
import net.schmizz.sshj.connection.ConnectionException;
9+
import net.schmizz.sshj.connection.ConnectionImpl;
10+
import net.schmizz.sshj.transport.TransportException;
11+
import org.junit.jupiter.api.Assertions;
12+
import org.junit.jupiter.api.BeforeAll;
13+
import org.junit.jupiter.api.Test;
14+
import org.junit.jupiter.api.extension.RegisterExtension;
15+
16+
import java.io.IOException;
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
20+
class EventuallyFailKeepAlive extends KeepAlive {
21+
// they can survive first 2 checks, and fail at 3rd
22+
int failAfter = 2;
23+
volatile int current = 0;
24+
25+
protected EventuallyFailKeepAlive(ConnectionImpl conn, String name) {
26+
super(conn, name);
27+
setKeepAliveInterval(1);
28+
}
29+
30+
@Override
31+
protected void doKeepAlive() throws TransportException, ConnectionException {
32+
current++;
33+
if (current > failAfter) {
34+
throw new ConnectionException("failed");
35+
}
36+
}
37+
}
38+
39+
public class BoundedKeepAliveProviderTest {
40+
41+
static BoundedKeepAliveProvider kp;
42+
static final DefaultConfig defaultConfig = new DefaultConfig();
43+
44+
45+
@BeforeAll
46+
static void setUpBeforeClass() throws Exception {
47+
48+
kp = new BoundedKeepAliveProvider(defaultConfig, 2) {
49+
@Override
50+
public KeepAlive provide(ConnectionImpl connection) {
51+
return new EventuallyFailKeepAlive(connection, "test") {
52+
@Override
53+
public void startKeepAlive() {
54+
monitor.register(this);
55+
}
56+
};
57+
}
58+
};
59+
}
60+
61+
@RegisterExtension
62+
public SshServerExtension fixture = new SshServerExtension();
63+
64+
void testWithConnections(int numOfConnections) throws IOException, InterruptedException {
65+
List<SSHClient> clients = setupClients(numOfConnections);
66+
for (SSHClient client : clients) {
67+
fixture.connectClient(client);
68+
}
69+
// first two checks are ok
70+
Thread.sleep(2000);
71+
Assertions.assertTrue(clients.stream().allMatch(SSHClient::isConnected));
72+
73+
// wait for 3rd check to take place, we wait additional 100ms for it to finish
74+
Thread.sleep(1100);
75+
Assertions.assertTrue(clients.stream().noneMatch(SSHClient::isConnected));
76+
Assertions.assertEquals(0, fixture.getServer().getActiveSessions().size());
77+
}
78+
79+
@Test
80+
void testBoundedKeepAlive() throws IOException, InterruptedException {
81+
// 2 threads can handle 64 connections
82+
testWithConnections(64);
83+
}
84+
85+
private List<SSHClient> setupClients(int numOfConnections) {
86+
List<SSHClient> clients = new ArrayList<>();
87+
defaultConfig.setKeepAliveProvider(kp);
88+
89+
for (int i = 0; i < numOfConnections; i++) {
90+
final SSHClient sshClient = fixture.createClient(defaultConfig);
91+
clients.add(sshClient);
92+
}
93+
return clients;
94+
}
95+
}

src/test/java/com/hierynomus/sshj/test/SshServerExtension.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ public SSHClient setupClient(Config config) {
9797
return client;
9898
}
9999

100+
/**
101+
* create a new uncached client
102+
*/
103+
public SSHClient createClient(Config config) {
104+
SSHClient client = new SSHClient(config);
105+
client.addHostKeyVerifier(fingerprint);
106+
return client;
107+
}
108+
100109
public SSHClient getClient() {
101110
if (client != null) {
102111
return client;

0 commit comments

Comments
 (0)