Skip to content

Commit 81e87a4

Browse files
authored
Add unit tests of strict key exchange extension (#918)
1 parent a262f51 commit 81e87a4

File tree

2 files changed

+356
-0
lines changed

2 files changed

+356
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Copyright (C)2009 - SSHJ Contributors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package net.schmizz.sshj.transport;
17+
18+
import java.math.BigInteger;
19+
import java.util.Collections;
20+
import java.util.List;
21+
22+
import net.schmizz.sshj.DefaultConfig;
23+
import net.schmizz.sshj.common.DisconnectReason;
24+
import net.schmizz.sshj.common.Factory;
25+
import net.schmizz.sshj.common.Message;
26+
import net.schmizz.sshj.common.SSHPacket;
27+
import net.schmizz.sshj.transport.kex.KeyExchange;
28+
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
29+
import org.junit.jupiter.api.BeforeEach;
30+
import org.junit.jupiter.api.Test;
31+
import org.mockito.ArgumentCaptor;
32+
import org.mockito.Mockito;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
36+
import static org.mockito.ArgumentMatchers.any;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.never;
39+
import static org.mockito.Mockito.verify;
40+
import static org.mockito.Mockito.when;
41+
42+
class KeyExchangerStrictKeyExchangeTest {
43+
44+
private TransportImpl transport;
45+
private DefaultConfig config;
46+
private KeyExchanger keyExchanger;
47+
48+
@BeforeEach
49+
void setUp() throws Exception {
50+
KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS);
51+
transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS);
52+
config = new DefaultConfig() {
53+
@Override
54+
protected void initKeyExchangeFactories() {
55+
setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() {
56+
@Override
57+
public KeyExchange create() {
58+
return kex;
59+
}
60+
61+
@Override
62+
public String getName() {
63+
return "mock-kex";
64+
}
65+
}));
66+
}
67+
};
68+
when(transport.getConfig()).thenReturn(config);
69+
when(transport.getServerID()).thenReturn("some server id");
70+
when(transport.getClientID()).thenReturn("some client id");
71+
when(kex.next(any(), any())).thenReturn(true);
72+
when(kex.getH()).thenReturn(new byte[0]);
73+
when(kex.getK()).thenReturn(BigInteger.ZERO);
74+
when(kex.getHash().digest()).thenReturn(new byte[10]);
75+
76+
keyExchanger = new KeyExchanger(transport);
77+
keyExchanger.addHostKeyVerifier(new PromiscuousVerifier());
78+
}
79+
80+
@Test
81+
void initialConditions() {
82+
assertThat(keyExchanger.isKexDone()).isFalse();
83+
assertThat(keyExchanger.isKexOngoing()).isFalse();
84+
assertThat(keyExchanger.isStrictKex()).isFalse();
85+
assertThat(keyExchanger.isInitialKex()).isTrue();
86+
}
87+
88+
@Test
89+
void startInitialKex() throws Exception {
90+
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
91+
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
92+
93+
keyExchanger.startKex(false);
94+
95+
assertThat(keyExchanger.isKexDone()).isFalse();
96+
assertThat(keyExchanger.isKexOngoing()).isTrue();
97+
assertThat(keyExchanger.isStrictKex()).isFalse();
98+
assertThat(keyExchanger.isInitialKex()).isTrue();
99+
100+
SSHPacket sshPacket = sshPacketCaptor.getValue();
101+
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
102+
assertThat(kex).endsWith("[email protected]");
103+
}
104+
105+
@Test
106+
void receiveKexInitWithoutServerFlag() throws Exception {
107+
keyExchanger.startKex(false);
108+
109+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
110+
111+
assertThat(keyExchanger.isKexDone()).isFalse();
112+
assertThat(keyExchanger.isKexOngoing()).isTrue();
113+
assertThat(keyExchanger.isStrictKex()).isFalse();
114+
assertThat(keyExchanger.isInitialKex()).isTrue();
115+
}
116+
117+
@Test
118+
void finishNonStrictKex() throws Exception {
119+
keyExchanger.startKex(false);
120+
121+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
122+
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
123+
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
124+
125+
assertThat(keyExchanger.isKexDone()).isTrue();
126+
assertThat(keyExchanger.isKexOngoing()).isFalse();
127+
assertThat(keyExchanger.isStrictKex()).isFalse();
128+
assertThat(keyExchanger.isInitialKex()).isFalse();
129+
130+
verify(transport.getEncoder(), never()).resetSequenceNumber();
131+
verify(transport.getDecoder(), never()).resetSequenceNumber();
132+
}
133+
134+
@Test
135+
void receiveKexInitWithServerFlag() throws Exception {
136+
keyExchanger.startKex(false);
137+
138+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
139+
140+
assertThat(keyExchanger.isKexDone()).isFalse();
141+
assertThat(keyExchanger.isKexOngoing()).isTrue();
142+
assertThat(keyExchanger.isStrictKex()).isTrue();
143+
assertThat(keyExchanger.isInitialKex()).isTrue();
144+
}
145+
146+
@Test
147+
void strictKexInitIsNotFirstPacket() throws Exception {
148+
when(transport.getDecoder().getSequenceNumber()).thenReturn(1L);
149+
keyExchanger.startKex(false);
150+
151+
assertThatExceptionOfType(TransportException.class).isThrownBy(
152+
() -> keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true))
153+
).satisfies(e -> {
154+
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
155+
assertThat(e.getMessage()).isEqualTo("SSH_MSG_KEXINIT was not first package during strict key exchange");
156+
});
157+
}
158+
159+
@Test
160+
void finishStrictKex() throws Exception {
161+
keyExchanger.startKex(false);
162+
163+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
164+
verify(transport.getEncoder(), never()).resetSequenceNumber();
165+
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
166+
verify(transport.getEncoder()).resetSequenceNumber();
167+
verify(transport.getDecoder(), never()).resetSequenceNumber();
168+
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
169+
verify(transport.getDecoder()).resetSequenceNumber();
170+
171+
assertThat(keyExchanger.isKexDone()).isTrue();
172+
assertThat(keyExchanger.isKexOngoing()).isFalse();
173+
assertThat(keyExchanger.isStrictKex()).isTrue();
174+
assertThat(keyExchanger.isInitialKex()).isFalse();
175+
}
176+
177+
@Test
178+
void noClientFlagInSecondStrictKex() throws Exception {
179+
keyExchanger.startKex(false);
180+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
181+
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
182+
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
183+
184+
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
185+
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
186+
when(transport.isAuthenticated()).thenReturn(true);
187+
188+
keyExchanger.startKex(false);
189+
190+
assertThat(keyExchanger.isKexDone()).isFalse();
191+
assertThat(keyExchanger.isKexOngoing()).isTrue();
192+
assertThat(keyExchanger.isStrictKex()).isTrue();
193+
assertThat(keyExchanger.isInitialKex()).isFalse();
194+
195+
SSHPacket sshPacket = sshPacketCaptor.getValue();
196+
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
197+
assertThat(kex).doesNotContain("[email protected]");
198+
}
199+
200+
@Test
201+
void serverFlagIsIgnoredInSecondKex() throws Exception {
202+
keyExchanger.startKex(false);
203+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
204+
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
205+
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
206+
207+
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
208+
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
209+
when(transport.isAuthenticated()).thenReturn(true);
210+
211+
keyExchanger.startKex(false);
212+
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
213+
214+
assertThat(keyExchanger.isKexDone()).isFalse();
215+
assertThat(keyExchanger.isKexOngoing()).isTrue();
216+
assertThat(keyExchanger.isStrictKex()).isFalse();
217+
assertThat(keyExchanger.isInitialKex()).isFalse();
218+
219+
SSHPacket sshPacket = sshPacketCaptor.getValue();
220+
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
221+
assertThat(kex).doesNotContain("[email protected]");
222+
}
223+
224+
private SSHPacket getKexInitPacket(boolean withServerFlag) {
225+
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), true).getPacket();
226+
if (withServerFlag) {
227+
int finalWpos = kexinitPacket.wpos();
228+
kexinitPacket.wpos(22);
229+
kexinitPacket.putString("mock-kex,[email protected]");
230+
kexinitPacket.wpos(finalWpos);
231+
}
232+
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
233+
return kexinitPacket;
234+
}
235+
236+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (C)2009 - SSHJ Contributors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package net.schmizz.sshj.transport;
17+
18+
import java.lang.reflect.Field;
19+
20+
import net.schmizz.sshj.Config;
21+
import net.schmizz.sshj.DefaultConfig;
22+
import net.schmizz.sshj.common.DisconnectReason;
23+
import net.schmizz.sshj.common.Message;
24+
import net.schmizz.sshj.common.SSHPacket;
25+
import org.junit.jupiter.api.BeforeEach;
26+
import org.junit.jupiter.api.Test;
27+
import org.junit.jupiter.params.ParameterizedTest;
28+
import org.junit.jupiter.params.provider.EnumSource;
29+
import org.junit.jupiter.params.provider.EnumSource.Mode;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
import static org.assertj.core.api.Assertions.assertThatCode;
33+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
34+
import static org.mockito.Mockito.mock;
35+
import static org.mockito.Mockito.verify;
36+
import static org.mockito.Mockito.when;
37+
38+
class TransportImplStrictKeyExchangeTest {
39+
40+
private final Config config = new DefaultConfig();
41+
private final Transport transport = new TransportImpl(config);
42+
private final KeyExchanger kexer = mock(KeyExchanger.class);
43+
private final Decoder decoder = mock(Decoder.class);
44+
45+
@BeforeEach
46+
void setUp() throws Exception {
47+
Field kexerField = TransportImpl.class.getDeclaredField("kexer");
48+
kexerField.setAccessible(true);
49+
kexerField.set(transport, kexer);
50+
Field decoderField = TransportImpl.class.getDeclaredField("decoder");
51+
decoderField.setAccessible(true);
52+
decoderField.set(transport, decoder);
53+
}
54+
55+
@Test
56+
void throwExceptionOnWrapDuringInitialKex() {
57+
when(kexer.isInitialKex()).thenReturn(true);
58+
when(decoder.isSequenceNumberAtMax()).thenReturn(true);
59+
60+
assertThatExceptionOfType(TransportException.class).isThrownBy(
61+
() -> transport.handle(Message.KEXINIT, new SSHPacket(Message.KEXINIT))
62+
).satisfies(e -> {
63+
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
64+
assertThat(e.getMessage()).isEqualTo("Sequence number of decoder is about to wrap during initial key exchange");
65+
});
66+
}
67+
68+
@ParameterizedTest
69+
@EnumSource(value = Message.class, mode = Mode.EXCLUDE, names = {
70+
"DISCONNECT", "KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
71+
})
72+
void forbidUnexpectedPacketsDuringStrictKeyExchange(Message message) {
73+
when(kexer.isInitialKex()).thenReturn(true);
74+
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
75+
when(kexer.isStrictKex()).thenReturn(true);
76+
77+
assertThatExceptionOfType(TransportException.class).isThrownBy(
78+
() -> transport.handle(message, new SSHPacket(message))
79+
).satisfies(e -> {
80+
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
81+
assertThat(e.getMessage()).isEqualTo("Unexpected packet type during initial strict key exchange");
82+
});
83+
}
84+
85+
@ParameterizedTest
86+
@EnumSource(value = Message.class, mode = Mode.INCLUDE, names = {
87+
"KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
88+
})
89+
void expectedPacketsDuringStrictKeyExchangeAreHandled(Message message) throws Exception {
90+
when(kexer.isInitialKex()).thenReturn(true);
91+
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
92+
when(kexer.isStrictKex()).thenReturn(true);
93+
SSHPacket sshPacket = new SSHPacket(message);
94+
95+
assertThatCode(
96+
() -> transport.handle(message, sshPacket)
97+
).doesNotThrowAnyException();
98+
99+
verify(kexer).handle(message, sshPacket);
100+
}
101+
102+
@Test
103+
void disconnectIsAllowedDuringStrictKeyExchange() {
104+
when(kexer.isInitialKex()).thenReturn(true);
105+
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
106+
when(kexer.isStrictKex()).thenReturn(true);
107+
108+
SSHPacket sshPacket = new SSHPacket();
109+
sshPacket.putUInt32(DisconnectReason.SERVICE_NOT_AVAILABLE.toInt());
110+
sshPacket.putString("service is down for maintenance");
111+
112+
assertThatExceptionOfType(TransportException.class).isThrownBy(
113+
() -> transport.handle(Message.DISCONNECT, sshPacket)
114+
).satisfies(e -> {
115+
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.SERVICE_NOT_AVAILABLE);
116+
assertThat(e.getMessage()).isEqualTo("service is down for maintenance");
117+
});
118+
}
119+
120+
}

0 commit comments

Comments
 (0)