Skip to content

Commit 81d77d2

Browse files
Don't send keep alive signals before kex is done (#934)
Otherwise, they could interfere with strict key exchange. Co-authored-by: Jeroen van Erp <[email protected]>
1 parent 70af58d commit 81d77d2

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,26 @@
1818
import java.util.ArrayList;
1919
import java.util.List;
2020
import java.util.stream.Collectors;
21+
import java.util.stream.Stream;
2122

2223
import ch.qos.logback.classic.Logger;
2324
import ch.qos.logback.classic.spi.ILoggingEvent;
2425
import ch.qos.logback.core.read.ListAppender;
2526
import com.hierynomus.sshj.SshdContainer;
27+
import net.schmizz.keepalive.KeepAlive;
28+
import net.schmizz.keepalive.KeepAliveProvider;
29+
import net.schmizz.sshj.Config;
30+
import net.schmizz.sshj.DefaultConfig;
2631
import net.schmizz.sshj.SSHClient;
32+
import net.schmizz.sshj.common.Message;
33+
import net.schmizz.sshj.common.SSHPacket;
34+
import net.schmizz.sshj.connection.ConnectionImpl;
35+
import net.schmizz.sshj.transport.TransportException;
2736
import org.junit.jupiter.api.AfterEach;
2837
import org.junit.jupiter.api.BeforeEach;
29-
import org.junit.jupiter.api.Test;
38+
import org.junit.jupiter.params.ParameterizedTest;
39+
import org.junit.jupiter.params.provider.Arguments;
40+
import org.junit.jupiter.params.provider.MethodSource;
3041
import org.slf4j.LoggerFactory;
3142
import org.testcontainers.junit.jupiter.Container;
3243
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -62,14 +73,27 @@ private void setUpLogger(String className) {
6273
watchedLoggers.add(logger);
6374
}
6475

65-
@Test
66-
void strictKeyExchange() throws Throwable {
67-
try (SSHClient client = sshd.getConnectedClient()) {
76+
private static Stream<Arguments> strictKeyExchange() {
77+
Config defaultConfig = new DefaultConfig();
78+
Config heartbeaterConfig = new DefaultConfig();
79+
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
80+
@Override
81+
public KeepAlive provide(ConnectionImpl connection) {
82+
return new HotLoopHeartbeater(connection);
83+
}
84+
});
85+
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
86+
}
87+
88+
@MethodSource
89+
@ParameterizedTest
90+
void strictKeyExchange(Config config) throws Throwable {
91+
try (SSHClient client = sshd.getConnectedClient(config)) {
6892
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
6993
assertTrue(client.isAuthenticated());
7094
}
7195
List<String> keyExchangerLogs = getLogs("KeyExchanger");
72-
assertThat(keyExchangerLogs).containsSequence(
96+
assertThat(keyExchangerLogs).contains(
7397
"Initiating key exchange",
7498
"Sending SSH_MSG_KEXINIT",
7599
"Received SSH_MSG_KEXINIT",
@@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable {
78102
List<String> decoderLogs = getLogs("Decoder").stream()
79103
.map(log -> log.split(":")[0])
80104
.collect(Collectors.toList());
81-
assertThat(decoderLogs).containsExactly(
105+
assertThat(decoderLogs).startsWith(
82106
"Received packet #0",
83107
"Received packet #1",
84108
"Received packet #2",
@@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable {
90114
List<String> encoderLogs = getLogs("Encoder").stream()
91115
.map(log -> log.split(":")[0])
92116
.collect(Collectors.toList());
93-
assertThat(encoderLogs).containsExactly(
117+
assertThat(encoderLogs).startsWith(
94118
"Encoding packet #0",
95119
"Encoding packet #1",
96120
"Encoding packet #2",
@@ -108,4 +132,22 @@ private List<String> getLogs(String className) {
108132
.collect(Collectors.toList());
109133
}
110134

135+
private static class HotLoopHeartbeater extends KeepAlive {
136+
137+
HotLoopHeartbeater(ConnectionImpl conn) {
138+
super(conn, "sshj-Heartbeater");
139+
}
140+
141+
@Override
142+
public boolean isEnabled() {
143+
return true;
144+
}
145+
146+
@Override
147+
protected void doKeepAlive() throws TransportException {
148+
conn.getTransport().write(new SSHPacket(Message.IGNORE));
149+
}
150+
151+
}
152+
111153
}

src/main/java/net/schmizz/sshj/SSHClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,12 +804,12 @@ protected void onConnect()
804804
throws IOException {
805805
super.onConnect();
806806
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
807+
doKex();
807808
final KeepAlive keepAliveThread = conn.getKeepAlive();
808809
if (keepAliveThread.isEnabled()) {
809810
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
810811
keepAliveThread.start();
811812
}
812-
doKex();
813813
}
814814

815815
/**

0 commit comments

Comments
 (0)