Skip to content
9 changes: 9 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ public static TransportVersion current() {
return CurrentHolder.CURRENT;
}

/**
* @return whether this is a known {@link TransportVersion}, i.e. one declared in {@link TransportVersions} or which dates back to
* before 8.9.0 when they matched the release versions exactly and there was no branching or patching. Other versions may exist
* in the wild (they're sent over the wire by numeric ID) but we don't know how to communicate using such versions.
*/
public boolean isKnown() {
return before(TransportVersions.V_8_9_X) || TransportVersions.VERSION_IDS.containsKey(id);
}

public static TransportVersion fromString(String str) {
return TransportVersion.fromId(Integer.parseInt(str));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@

import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.EOFException;
Expand Down Expand Up @@ -157,6 +159,8 @@ final class TransportHandshaker {
* [3] Parent task ID should be empty; see org.elasticsearch.tasks.TaskId.writeTo for its structure.
*/

private static final Logger logger = LogManager.getLogger(TransportHandshaker.class);

static final TransportVersion V7_HANDSHAKE_VERSION = TransportVersion.fromId(6_08_00_99);
static final TransportVersion V8_HANDSHAKE_VERSION = TransportVersion.fromId(7_17_00_99);
static final TransportVersion V9_HANDSHAKE_VERSION = TransportVersion.fromId(8_800_00_0);
Expand Down Expand Up @@ -195,7 +199,7 @@ void sendHandshake(
ActionListener<TransportVersion> listener
) {
numHandshakes.inc();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, listener);
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, channel, listener);
pendingHandshakes.put(requestId, handler);
channel.addCloseListener(
ActionListener.running(() -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))
Expand All @@ -221,9 +225,9 @@ void sendHandshake(
}

void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
final HandshakeRequest handshakeRequest;
try {
// Must read the handshake request to exhaust the stream
new HandshakeRequest(stream);
handshakeRequest = new HandshakeRequest(stream);
} catch (Exception e) {
assert ignoreDeserializationErrors : e;
throw e;
Expand All @@ -242,9 +246,44 @@ void handleHandshake(TransportChannel channel, long requestId, StreamInput strea
assert ignoreDeserializationErrors : exception;
throw exception;
}
ensureCompatibleVersion(version, handshakeRequest.transportVersion, handshakeRequest.releaseVersion, channel);
channel.sendResponse(new HandshakeResponse(this.version, Build.current().version()));
}

static void ensureCompatibleVersion(
TransportVersion localTransportVersion,
TransportVersion remoteTransportVersion,
String releaseVersion,
Object channel
) {
if (TransportVersion.isCompatible(remoteTransportVersion)) {
if (remoteTransportVersion.onOrAfter(localTransportVersion)) {
// Remote is newer than us, so we will be using our transport protocol and it's up to the other end to decide whether it
// knows how to do that.
return;
}
if (remoteTransportVersion.isKnown()) {
// Remote is older than us, so we will be using its transport protocol, which we can only do if and only if its protocol
// version is known to us.
return;
}
}

final var message = Strings.format(
"""
Rejecting unreadable transport handshake from remote node with version [%s/%s] received on [%s] since this node has \
version [%s/%s] which has an incompatible wire format.""",
releaseVersion,
remoteTransportVersion,
channel,
Build.current().version(),
localTransportVersion
);
logger.warn(message);
throw new IllegalStateException(message);

}

TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
return pendingHandshakes.remove(requestId);
}
Expand All @@ -260,11 +299,13 @@ long getNumHandshakes() {
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {

private final long requestId;
private final TcpChannel channel;
private final ActionListener<TransportVersion> listener;
private final AtomicBoolean isDone = new AtomicBoolean(false);

private HandshakeResponseHandler(long requestId, ActionListener<TransportVersion> listener) {
private HandshakeResponseHandler(long requestId, TcpChannel channel, ActionListener<TransportVersion> listener) {
this.requestId = requestId;
this.channel = channel;
this.listener = listener;
}

Expand All @@ -281,20 +322,13 @@ public Executor executor() {
@Override
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
TransportVersion responseVersion = response.transportVersion;
if (TransportVersion.isCompatible(responseVersion) == false) {
listener.onFailure(
new IllegalStateException(
"Received message from unsupported version: ["
+ responseVersion
+ "] minimal compatible version is: ["
+ TransportVersions.MINIMUM_COMPATIBLE
+ "]"
)
);
} else {
listener.onResponse(TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion()));
}
ActionListener.completeWith(listener, () -> {
ensureCompatibleVersion(version, response.getTransportVersion(), response.getReleaseVersion(), channel);
final var resultVersion = TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion());
assert TransportVersion.current().before(version) // simulating a newer-version transport service for test purposes
|| resultVersion.isKnown() : "negotiated unknown version " + resultVersion;
return resultVersion;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
*/
package org.elasticsearch.transport;

import org.apache.logging.log4j.Level;
import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
Expand All @@ -17,13 +21,17 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.threadpool.TestThreadPool;

import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -90,6 +98,42 @@ public void testHandshakeRequestAndResponse() throws IOException {
assertEquals(TransportVersion.current(), versionFuture.actionGet());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testIncompatibleHandshakeRequest() throws IOException {
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(
getRandomIncompatibleTransportVersion(),
randomIdentifier()
);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
bytesStreamOutput.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput();
input.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
final TestTransportChannel channel = new TestTransportChannel(ActionListener.running(() -> fail("should not complete")));

MockLog.assertThatLogger(
() -> assertThat(
expectThrows(IllegalStateException.class, () -> handshaker.handleHandshake(channel, randomNonNegativeLong(), input))
.getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString(
"[" + handshakeRequest.transportVersion.toReleaseVersion() + "/" + handshakeRequest.transportVersion + "]"
),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);
}

public void testHandshakeResponseFromOlderNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand All @@ -105,6 +149,54 @@ public void testHandshakeResponseFromOlderNode() throws Exception {
assertEquals(remoteVersion, versionFuture.result());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testHandshakeResponseFromOlderNodeWithPatchedProtocol() {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
handshaker.sendHandshake(reqId, node, channel, SAFE_AWAIT_TIMEOUT, versionFuture);
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);

assertFalse(versionFuture.isDone());

final var handshakeResponse = new TransportHandshaker.HandshakeResponse(
getRandomIncompatibleTransportVersion(),
randomIdentifier()
);

MockLog.assertThatLogger(
() -> handler.handleResponse(handshakeResponse),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);

assertTrue(versionFuture.isDone());
assertThat(
expectThrows(ExecutionException.class, IllegalStateException.class, versionFuture::result).getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString("[" + handshakeResponse.getReleaseVersion() + "/" + handshakeResponse.getTransportVersion() + "]"),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
);
}

private static TransportVersion getRandomIncompatibleTransportVersion() {
return randomBoolean()
// either older than MINIMUM_COMPATIBLE
? new TransportVersion(between(1, TransportVersions.MINIMUM_COMPATIBLE.id() - 1))
// or between MINIMUM_COMPATIBLE and current but not known
: randomValueOtherThanMany(
TransportVersion::isKnown,
() -> new TransportVersion(between(TransportVersions.MINIMUM_COMPATIBLE.id(), TransportVersion.current().id()))
);
}

public void testHandshakeResponseFromNewerNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand Down