Skip to content

Commit 43cbec4

Browse files
committed
Validate transport handshake from known version
With parallel releases on multiple branches, it's possible that an older branch sees a transport version update that is not known to a numerically newer but chronologically older version. In that case the two nodes cannot communicate with each other, so with this commit we reject such connection attempts at the version negotiation stage. Backport of elastic#121747 to 8.x.
1 parent c9f0d93 commit 43cbec4

File tree

4 files changed

+153
-19
lines changed

4 files changed

+153
-19
lines changed

server/src/main/java/org/elasticsearch/TransportVersion.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ public static TransportVersion current() {
9898
return CurrentHolder.CURRENT;
9999
}
100100

101+
/**
102+
* @return whether this is a known {@link TransportVersion}, i.e. one declared in {@link TransportVersions}. Other versions may exist
103+
* in the wild (they're sent over the wire by numeric ID) but we don't know how to communicate using such versions.
104+
*/
105+
public boolean isKnown() {
106+
return VersionsHolder.ALL_VERSIONS_MAP.containsKey(id);
107+
}
108+
101109
public static TransportVersion fromString(String str) {
102110
return TransportVersion.fromId(Integer.parseInt(str));
103111
}

server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111

1212
import org.elasticsearch.Build;
1313
import org.elasticsearch.TransportVersion;
14-
import org.elasticsearch.TransportVersions;
1514
import org.elasticsearch.action.ActionListener;
1615
import org.elasticsearch.cluster.node.DiscoveryNode;
1716
import org.elasticsearch.common.bytes.BytesReference;
1817
import org.elasticsearch.common.io.stream.BytesStreamOutput;
1918
import org.elasticsearch.common.io.stream.StreamInput;
2019
import org.elasticsearch.common.io.stream.StreamOutput;
2120
import org.elasticsearch.common.metrics.CounterMetric;
21+
import org.elasticsearch.core.Strings;
2222
import org.elasticsearch.core.TimeValue;
23+
import org.elasticsearch.logging.LogManager;
24+
import org.elasticsearch.logging.Logger;
2325
import org.elasticsearch.threadpool.ThreadPool;
2426

2527
import java.io.EOFException;
@@ -157,6 +159,8 @@ final class TransportHandshaker {
157159
* [3] Parent task ID should be empty; see org.elasticsearch.tasks.TaskId.writeTo for its structure.
158160
*/
159161

162+
private static final Logger logger = LogManager.getLogger(TransportHandshaker.class);
163+
160164
static final TransportVersion V7_HANDSHAKE_VERSION = TransportVersion.fromId(6_08_00_99);
161165
static final TransportVersion V8_HANDSHAKE_VERSION = TransportVersion.fromId(7_17_00_99);
162166
static final TransportVersion V9_HANDSHAKE_VERSION = TransportVersion.fromId(8_800_00_0);
@@ -195,7 +199,7 @@ void sendHandshake(
195199
ActionListener<TransportVersion> listener
196200
) {
197201
numHandshakes.inc();
198-
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, listener);
202+
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, channel, listener);
199203
pendingHandshakes.put(requestId, handler);
200204
channel.addCloseListener(
201205
ActionListener.running(() -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))
@@ -221,9 +225,9 @@ void sendHandshake(
221225
}
222226

223227
void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
228+
final HandshakeRequest handshakeRequest;
224229
try {
225-
// Must read the handshake request to exhaust the stream
226-
new HandshakeRequest(stream);
230+
handshakeRequest = new HandshakeRequest(stream);
227231
} catch (Exception e) {
228232
assert ignoreDeserializationErrors : e;
229233
throw e;
@@ -242,9 +246,44 @@ void handleHandshake(TransportChannel channel, long requestId, StreamInput strea
242246
assert ignoreDeserializationErrors : exception;
243247
throw exception;
244248
}
249+
ensureCompatibleVersion(version, handshakeRequest.transportVersion, handshakeRequest.releaseVersion, channel);
245250
channel.sendResponse(new HandshakeResponse(this.version, Build.current().version()));
246251
}
247252

253+
/**
 * Decides whether a handshake with a remote node may proceed, given the transport version the remote node reported.
 * <p>
 * The handshake is accepted (the method returns normally) when the remote version passes
 * {@link TransportVersion#isCompatible} and additionally either is on or after {@code localTransportVersion}, or is
 * reported as known by {@link TransportVersion#isKnown()}. In every other case a message describing both ends is
 * logged at WARN level and thrown as an {@link IllegalStateException}.
 *
 * @param localTransportVersion  the transport version of this node
 * @param remoteTransportVersion the transport version reported by the remote node
 * @param releaseVersion         the release version string reported by the remote node; only used in the message
 * @param channel                the channel the handshake arrived on; only formatted into the message
 * @throws IllegalStateException if the remote version is rejected
 */
static void ensureCompatibleVersion(
254+
TransportVersion localTransportVersion,
255+
TransportVersion remoteTransportVersion,
256+
String releaseVersion,
257+
Object channel
258+
) {
259+
if (TransportVersion.isCompatible(remoteTransportVersion)) {
260+
if (remoteTransportVersion.onOrAfter(localTransportVersion)) {
261+
// Remote is at least as new as us, so we will be using our transport protocol and it's up to the other end to decide
262+
// whether it knows how to do that.
263+
return;
264+
}
265+
if (remoteTransportVersion.isKnown()) {
266+
// Remote is older than us, so we will be using its transport protocol, which we can do if and only if its protocol
267+
// version is known to us.
268+
return;
269+
}
270+
}
271+
272+
// Reaching this point means the remote version failed the compatibility check, or is older than ours and not known to us.
// The same message is both logged and used as the exception message.
final var message = Strings.format(
273+
"""
274+
Rejecting unreadable transport handshake from remote node with version [%s/%s] received on [%s] since this node has \
275+
version [%s/%s] which has an incompatible wire format.""",
276+
releaseVersion,
277+
remoteTransportVersion,
278+
channel,
279+
Build.current().version(),
280+
localTransportVersion
281+
);
282+
logger.warn(message);
283+
throw new IllegalStateException(message);
284+
285+
}
286+
248287
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
249288
return pendingHandshakes.remove(requestId);
250289
}
@@ -260,11 +299,13 @@ long getNumHandshakes() {
260299
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {
261300

262301
private final long requestId;
302+
private final TcpChannel channel;
263303
private final ActionListener<TransportVersion> listener;
264304
private final AtomicBoolean isDone = new AtomicBoolean(false);
265305

266-
private HandshakeResponseHandler(long requestId, ActionListener<TransportVersion> listener) {
306+
private HandshakeResponseHandler(long requestId, TcpChannel channel, ActionListener<TransportVersion> listener) {
267307
this.requestId = requestId;
308+
this.channel = channel;
268309
this.listener = listener;
269310
}
270311

@@ -281,20 +322,13 @@ public Executor executor() {
281322
@Override
282323
public void handleResponse(HandshakeResponse response) {
283324
if (isDone.compareAndSet(false, true)) {
284-
TransportVersion responseVersion = response.transportVersion;
285-
if (TransportVersion.isCompatible(responseVersion) == false) {
286-
listener.onFailure(
287-
new IllegalStateException(
288-
"Received message from unsupported version: ["
289-
+ responseVersion
290-
+ "] minimal compatible version is: ["
291-
+ TransportVersions.MINIMUM_COMPATIBLE
292-
+ "]"
293-
)
294-
);
295-
} else {
296-
listener.onResponse(TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion()));
297-
}
325+
ActionListener.completeWith(listener, () -> {
326+
ensureCompatibleVersion(version, response.getTransportVersion(), response.getReleaseVersion(), channel);
327+
final var resultVersion = TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion());
328+
assert TransportVersion.current().before(version) // simulating a newer-version transport service for test purposes
329+
|| resultVersion.isKnown() : "negotiated unknown version " + resultVersion;
330+
return resultVersion;
331+
});
298332
}
299333
}
300334

server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ public void testLogsSlowInboundProcessing() throws Exception {
290290
);
291291
BytesStreamOutput byteData = new BytesStreamOutput();
292292
TaskId.EMPTY_TASK_ID.writeTo(byteData);
293+
// simulate bytes of a transport handshake: vInt transport version then release version string
293294
TransportVersion.writeVersion(remoteVersion, byteData);
295+
byteData.writeString(randomIdentifier());
294296
final InboundMessage requestMessage = new InboundMessage(
295297
requestHeader,
296298
ReleasableBytesReference.wrap(byteData.bytes()),

server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
*/
99
package org.elasticsearch.transport;
1010

11+
import org.apache.logging.log4j.Level;
12+
import org.elasticsearch.Build;
1113
import org.elasticsearch.TransportVersion;
14+
import org.elasticsearch.TransportVersions;
1215
import org.elasticsearch.Version;
16+
import org.elasticsearch.action.ActionListener;
1317
import org.elasticsearch.action.support.PlainActionFuture;
1418
import org.elasticsearch.cluster.node.DiscoveryNode;
1519
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
@@ -18,13 +22,17 @@
1822
import org.elasticsearch.core.TimeValue;
1923
import org.elasticsearch.tasks.TaskId;
2024
import org.elasticsearch.test.ESTestCase;
25+
import org.elasticsearch.test.MockLog;
2126
import org.elasticsearch.test.TransportVersionUtils;
27+
import org.elasticsearch.test.junit.annotations.TestLogging;
2228
import org.elasticsearch.threadpool.TestThreadPool;
2329

2430
import java.io.IOException;
2531
import java.util.Collections;
32+
import java.util.concurrent.ExecutionException;
2633
import java.util.concurrent.TimeUnit;
2734

35+
import static org.hamcrest.Matchers.allOf;
2836
import static org.hamcrest.Matchers.containsString;
2937
import static org.mockito.Mockito.doThrow;
3038
import static org.mockito.Mockito.mock;
@@ -91,6 +99,40 @@ public void testHandshakeRequestAndResponse() throws IOException {
9199
assertEquals(TransportVersion.current(), versionFuture.actionGet());
92100
}
93101

102+
@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
103+
public void testIncompatibleHandshakeRequest() throws IOException {
104+
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(
105+
getRandomIncompatibleTransportVersion(),
106+
randomIdentifier()
107+
);
108+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
109+
bytesStreamOutput.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
110+
handshakeRequest.writeTo(bytesStreamOutput);
111+
StreamInput input = bytesStreamOutput.bytes().streamInput();
112+
input.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
113+
final TestTransportChannel channel = new TestTransportChannel(ActionListener.running(() -> fail("should not complete")));
114+
115+
MockLog.assertThatLogger(
116+
() -> assertThat(
117+
expectThrows(IllegalStateException.class, () -> handshaker.handleHandshake(channel, randomNonNegativeLong(), input))
118+
.getMessage(),
119+
allOf(
120+
containsString("Rejecting unreadable transport handshake"),
121+
containsString("[" + handshakeRequest.releaseVersion + "/" + handshakeRequest.transportVersion + "]"),
122+
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
123+
containsString("which has an incompatible wire format")
124+
)
125+
),
126+
TransportHandshaker.class,
127+
new MockLog.SeenEventExpectation(
128+
"warning",
129+
TransportHandshaker.class.getCanonicalName(),
130+
Level.WARN,
131+
"Rejecting unreadable transport handshake * incompatible wire format."
132+
)
133+
);
134+
}
135+
94136
public void testHandshakeResponseFromOlderNode() throws Exception {
95137
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
96138
final long reqId = randomNonNegativeLong();
@@ -106,6 +148,54 @@ public void testHandshakeResponseFromOlderNode() throws Exception {
106148
assertEquals(remoteVersion, versionFuture.result());
107149
}
108150

151+
@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
152+
public void testHandshakeResponseFromOlderNodeWithPatchedProtocol() {
153+
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
154+
final long reqId = randomNonNegativeLong();
155+
handshaker.sendHandshake(reqId, node, channel, SAFE_AWAIT_TIMEOUT, versionFuture);
156+
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
157+
158+
assertFalse(versionFuture.isDone());
159+
160+
final var handshakeResponse = new TransportHandshaker.HandshakeResponse(
161+
getRandomIncompatibleTransportVersion(),
162+
randomIdentifier()
163+
);
164+
165+
MockLog.assertThatLogger(
166+
() -> handler.handleResponse(handshakeResponse),
167+
TransportHandshaker.class,
168+
new MockLog.SeenEventExpectation(
169+
"warning",
170+
TransportHandshaker.class.getCanonicalName(),
171+
Level.WARN,
172+
"Rejecting unreadable transport handshake * incompatible wire format."
173+
)
174+
);
175+
176+
assertTrue(versionFuture.isDone());
177+
assertThat(
178+
expectThrows(ExecutionException.class, IllegalStateException.class, versionFuture::result).getMessage(),
179+
allOf(
180+
containsString("Rejecting unreadable transport handshake"),
181+
containsString("[" + handshakeResponse.getReleaseVersion() + "/" + handshakeResponse.getTransportVersion() + "]"),
182+
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
183+
containsString("which has an incompatible wire format")
184+
)
185+
);
186+
}
187+
188+
/**
 * Produces a random transport version for the rejection tests in this class. Chooses at random between an ID strictly
 * below {@code TransportVersions.MINIMUM_COMPATIBLE} and an ID in the range from {@code MINIMUM_COMPATIBLE} to
 * {@code TransportVersion.current()} for which {@link TransportVersion#isKnown()} is false.
 * NOTE(review): assumes {@code between} is inclusive at both ends and that {@code randomValueOtherThanMany} retries the
 * supplier until the predicate is false -- confirm against the test framework.
 */
private static TransportVersion getRandomIncompatibleTransportVersion() {
189+
return randomBoolean()
190+
// either older than MINIMUM_COMPATIBLE
191+
? new TransportVersion(between(1, TransportVersions.MINIMUM_COMPATIBLE.id() - 1))
192+
// or between MINIMUM_COMPATIBLE and current but not known
193+
: randomValueOtherThanMany(
194+
TransportVersion::isKnown,
195+
() -> new TransportVersion(between(TransportVersions.MINIMUM_COMPATIBLE.id(), TransportVersion.current().id()))
196+
);
197+
}
198+
109199
public void testHandshakeResponseFromNewerNode() throws Exception {
110200
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
111201
final long reqId = randomNonNegativeLong();

0 commit comments

Comments
 (0)