1818import java .util .ArrayList ;
1919import java .util .List ;
2020import java .util .stream .Collectors ;
21+ import java .util .stream .Stream ;
2122
2223import ch .qos .logback .classic .Logger ;
2324import ch .qos .logback .classic .spi .ILoggingEvent ;
2425import ch .qos .logback .core .read .ListAppender ;
2526import com .hierynomus .sshj .SshdContainer ;
27+ import net .schmizz .keepalive .KeepAlive ;
28+ import net .schmizz .keepalive .KeepAliveProvider ;
29+ import net .schmizz .sshj .Config ;
30+ import net .schmizz .sshj .DefaultConfig ;
2631import net .schmizz .sshj .SSHClient ;
32+ import net .schmizz .sshj .common .Message ;
33+ import net .schmizz .sshj .common .SSHPacket ;
34+ import net .schmizz .sshj .connection .ConnectionImpl ;
35+ import net .schmizz .sshj .transport .TransportException ;
2736import org .junit .jupiter .api .AfterEach ;
2837import org .junit .jupiter .api .BeforeEach ;
29- import org .junit .jupiter .api .Test ;
38+ import org .junit .jupiter .params .ParameterizedTest ;
39+ import org .junit .jupiter .params .provider .Arguments ;
40+ import org .junit .jupiter .params .provider .MethodSource ;
3041import org .slf4j .LoggerFactory ;
3142import org .testcontainers .junit .jupiter .Container ;
3243import org .testcontainers .junit .jupiter .Testcontainers ;
@@ -62,14 +73,27 @@ private void setUpLogger(String className) {
6273 watchedLoggers .add (logger );
6374 }
6475
65- @ Test
66- void strictKeyExchange () throws Throwable {
67- try (SSHClient client = sshd .getConnectedClient ()) {
76+ private static Stream <Arguments > strictKeyExchange () {
77+ Config defaultConfig = new DefaultConfig ();
78+ Config heartbeaterConfig = new DefaultConfig ();
79+ heartbeaterConfig .setKeepAliveProvider (new KeepAliveProvider () {
80+ @ Override
81+ public KeepAlive provide (ConnectionImpl connection ) {
82+ return new HotLoopHeartbeater (connection );
83+ }
84+ });
85+ return Stream .of (defaultConfig , heartbeaterConfig ).map (Arguments ::of );
86+ }
87+
88+ @ MethodSource
89+ @ ParameterizedTest
90+ void strictKeyExchange (Config config ) throws Throwable {
91+ try (SSHClient client = sshd .getConnectedClient (config )) {
6892 client .authPublickey ("sshj" , "src/itest/resources/keyfiles/id_rsa_opensshv1" );
6993 assertTrue (client .isAuthenticated ());
7094 }
7195 List <String > keyExchangerLogs = getLogs ("KeyExchanger" );
72- assertThat (keyExchangerLogs ).containsSequence (
96+ assertThat (keyExchangerLogs ).contains (
7397 "Initiating key exchange" ,
7498 "Sending SSH_MSG_KEXINIT" ,
7599 "Received SSH_MSG_KEXINIT" ,
@@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable {
78102 List <String > decoderLogs = getLogs ("Decoder" ).stream ()
79103 .map (log -> log .split (":" )[0 ])
80104 .collect (Collectors .toList ());
81- assertThat (decoderLogs ).containsExactly (
105+ assertThat (decoderLogs ).startsWith (
82106 "Received packet #0" ,
83107 "Received packet #1" ,
84108 "Received packet #2" ,
@@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable {
90114 List <String > encoderLogs = getLogs ("Encoder" ).stream ()
91115 .map (log -> log .split (":" )[0 ])
92116 .collect (Collectors .toList ());
93- assertThat (encoderLogs ).containsExactly (
117+ assertThat (encoderLogs ).startsWith (
94118 "Encoding packet #0" ,
95119 "Encoding packet #1" ,
96120 "Encoding packet #2" ,
@@ -108,4 +132,22 @@ private List<String> getLogs(String className) {
108132 .collect (Collectors .toList ());
109133 }
110134
135+ private static class HotLoopHeartbeater extends KeepAlive {
136+
137+ HotLoopHeartbeater (ConnectionImpl conn ) {
138+ super (conn , "sshj-Heartbeater" );
139+ }
140+
141+ @ Override
142+ public boolean isEnabled () {
143+ return true ;
144+ }
145+
146+ @ Override
147+ protected void doKeepAlive () throws TransportException {
148+ conn .getTransport ().write (new SSHPacket (Message .IGNORE ));
149+ }
150+
151+ }
152+
111153}
0 commit comments