diff --git a/server/src/main/java/org/elasticsearch/TransportVersion.java b/server/src/main/java/org/elasticsearch/TransportVersion.java index 6e7b0814156e1..e4f20b64a7a3d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersion.java +++ b/server/src/main/java/org/elasticsearch/TransportVersion.java @@ -98,6 +98,15 @@ public static TransportVersion current() { return CurrentHolder.CURRENT; } + /** + * @return whether this is a known {@link TransportVersion}, i.e. one declared in {@link TransportVersions} or which dates back to + * before 8.9.0 when they matched the release versions exactly and there was no branching or patching. Other versions may exist + * in the wild (they're sent over the wire by numeric ID) but we don't know how to communicate using such versions. + */ + public boolean isKnown() { + return before(TransportVersions.V_8_9_X) || TransportVersions.VERSION_IDS.containsKey(id); + } + public static TransportVersion fromString(String str) { return TransportVersion.fromId(Integer.parseInt(str)); } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java b/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java index 3683b89c922a2..daf5a18db3dcd 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java @@ -11,7 +11,6 @@ import org.elasticsearch.Build; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesReference; @@ -19,7 +18,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.threadpool.ThreadPool; import java.io.EOFException; @@ -157,6 +159,8 @@ final class TransportHandshaker { * [3] Parent task ID should be empty; see org.elasticsearch.tasks.TaskId.writeTo for its structure. */ + private static final Logger logger = LogManager.getLogger(TransportHandshaker.class); + static final TransportVersion V7_HANDSHAKE_VERSION = TransportVersion.fromId(6_08_00_99); static final TransportVersion V8_HANDSHAKE_VERSION = TransportVersion.fromId(7_17_00_99); static final TransportVersion V9_HANDSHAKE_VERSION = TransportVersion.fromId(8_800_00_0); @@ -195,7 +199,7 @@ void sendHandshake( ActionListener listener ) { numHandshakes.inc(); - final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, listener); + final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, channel, listener); pendingHandshakes.put(requestId, handler); channel.addCloseListener( ActionListener.running(() -> handler.handleLocalException(new TransportException("handshake failed because connection reset"))) @@ -221,9 +225,9 @@ void sendHandshake( } void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException { + final HandshakeRequest handshakeRequest; try { - // Must read the handshake request to exhaust the stream - new HandshakeRequest(stream); + handshakeRequest = new HandshakeRequest(stream); } catch (Exception e) { assert ignoreDeserializationErrors : e; throw e; @@ -242,9 +246,44 @@ void handleHandshake(TransportChannel channel, long requestId, StreamInput strea assert ignoreDeserializationErrors : exception; throw exception; } + ensureCompatibleVersion(version, handshakeRequest.transportVersion, handshakeRequest.releaseVersion, channel); channel.sendResponse(new HandshakeResponse(this.version, Build.current().version())); } + static void ensureCompatibleVersion( + TransportVersion localTransportVersion, + TransportVersion remoteTransportVersion, + String releaseVersion, + Object channel + ) { + if (TransportVersion.isCompatible(remoteTransportVersion)) { + if (remoteTransportVersion.onOrAfter(localTransportVersion)) { + // Remote is newer than us, so we will be using our transport protocol and it's up to the other end to decide whether it + // knows how to do that. + return; + } + if (remoteTransportVersion.isKnown()) { + // Remote is older than us, so we will be using its transport protocol, which we can only do if and only if its protocol + // version is known to us. + return; + } + } + + final var message = Strings.format( + """ + Rejecting unreadable transport handshake from remote node with version [%s/%s] received on [%s] since this node has \ + version [%s/%s] which has an incompatible wire format.""", + releaseVersion, + remoteTransportVersion, + channel, + Build.current().version(), + localTransportVersion + ); + logger.warn(message); + throw new IllegalStateException(message); + + } + TransportResponseHandler removeHandlerForHandshake(long requestId) { return pendingHandshakes.remove(requestId); } @@ -260,11 +299,13 @@ long getNumHandshakes() { private class HandshakeResponseHandler implements TransportResponseHandler { private final long requestId; + private final TcpChannel channel; private final ActionListener listener; private final AtomicBoolean isDone = new AtomicBoolean(false); - private HandshakeResponseHandler(long requestId, ActionListener listener) { + private HandshakeResponseHandler(long requestId, TcpChannel channel, ActionListener listener) { this.requestId = requestId; + this.channel = channel; this.listener = listener; } @@ -281,20 +322,13 @@ public Executor executor() { @Override public void handleResponse(HandshakeResponse response) { if (isDone.compareAndSet(false, true)) { - TransportVersion responseVersion = response.transportVersion; - if (TransportVersion.isCompatible(responseVersion) == false) { - listener.onFailure( - new IllegalStateException( - "Received message from unsupported version: [" - + responseVersion - + "] minimal compatible version is: [" - + TransportVersions.MINIMUM_COMPATIBLE - + "]" - ) - ); - } else { - listener.onResponse(TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion())); - } + ActionListener.completeWith(listener, () -> { + ensureCompatibleVersion(version, response.getTransportVersion(), response.getReleaseVersion(), channel); + final var resultVersion = TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion()); + assert TransportVersion.current().before(version) // simulating a newer-version transport service for test purposes + || resultVersion.isKnown() : "negotiated unknown version " + resultVersion; + return resultVersion; + }); } } diff --git a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java index fba130001c5af..97536839e5e20 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java @@ -8,7 +8,11 @@ */ package org.elasticsearch.transport; +import org.apache.logging.log4j.Level; +import org.elasticsearch.Build; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; @@ -17,13 +21,17 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.MockLog; import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.threadpool.TestThreadPool; import java.io.IOException; import java.util.Collections; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -90,6 +98,42 @@ public void testHandshakeRequestAndResponse() throws IOException { assertEquals(TransportVersion.current(), versionFuture.actionGet()); } + @TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN") + public void testIncompatibleHandshakeRequest() throws IOException { + TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest( + getRandomIncompatibleTransportVersion(), + randomIdentifier() + ); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setTransportVersion(HANDSHAKE_REQUEST_VERSION); + handshakeRequest.writeTo(bytesStreamOutput); + StreamInput input = bytesStreamOutput.bytes().streamInput(); + input.setTransportVersion(HANDSHAKE_REQUEST_VERSION); + final TestTransportChannel channel = new TestTransportChannel(ActionListener.running(() -> fail("should not complete"))); + + MockLog.assertThatLogger( + () -> assertThat( + expectThrows(IllegalStateException.class, () -> handshaker.handleHandshake(channel, randomNonNegativeLong(), input)) + .getMessage(), + allOf( + containsString("Rejecting unreadable transport handshake"), + containsString( + "[" + handshakeRequest.transportVersion.toReleaseVersion() + "/" + handshakeRequest.transportVersion + "]" + ), + containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"), + containsString("which has an incompatible wire format") + ) + ), + TransportHandshaker.class, + new MockLog.SeenEventExpectation( + "warning", + TransportHandshaker.class.getCanonicalName(), + Level.WARN, + "Rejecting unreadable transport handshake * incompatible wire format." + ) + ); + } + public void testHandshakeResponseFromOlderNode() throws Exception { final PlainActionFuture versionFuture = new PlainActionFuture<>(); final long reqId = randomNonNegativeLong(); @@ -105,6 +149,54 @@ public void testHandshakeResponseFromOlderNode() throws Exception { assertEquals(remoteVersion, versionFuture.result()); } + @TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN") + public void testHandshakeResponseFromOlderNodeWithPatchedProtocol() { + final PlainActionFuture versionFuture = new PlainActionFuture<>(); + final long reqId = randomNonNegativeLong(); + handshaker.sendHandshake(reqId, node, channel, SAFE_AWAIT_TIMEOUT, versionFuture); + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + + assertFalse(versionFuture.isDone()); + + final var handshakeResponse = new TransportHandshaker.HandshakeResponse( + getRandomIncompatibleTransportVersion(), + randomIdentifier() + ); + + MockLog.assertThatLogger( + () -> handler.handleResponse(handshakeResponse), + TransportHandshaker.class, + new MockLog.SeenEventExpectation( + "warning", + TransportHandshaker.class.getCanonicalName(), + Level.WARN, + "Rejecting unreadable transport handshake * incompatible wire format." + ) + ); + + assertTrue(versionFuture.isDone()); + assertThat( + expectThrows(ExecutionException.class, IllegalStateException.class, versionFuture::result).getMessage(), + allOf( + containsString("Rejecting unreadable transport handshake"), + containsString("[" + handshakeResponse.getReleaseVersion() + "/" + handshakeResponse.getTransportVersion() + "]"), + containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"), + containsString("which has an incompatible wire format") + ) + ); + } + + private static TransportVersion getRandomIncompatibleTransportVersion() { + return randomBoolean() + // either older than MINIMUM_COMPATIBLE + ? new TransportVersion(between(1, TransportVersions.MINIMUM_COMPATIBLE.id() - 1)) + // or between MINIMUM_COMPATIBLE and current but not known + : randomValueOtherThanMany( + TransportVersion::isKnown, + () -> new TransportVersion(between(TransportVersions.MINIMUM_COMPATIBLE.id(), TransportVersion.current().id())) + ); + } + public void testHandshakeResponseFromNewerNode() throws Exception { final PlainActionFuture versionFuture = new PlainActionFuture<>(); final long reqId = randomNonNegativeLong();