Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ public static List<TransportVersion> getAllVersions() {
return VersionsHolder.ALL_VERSIONS;
}

/**
* @return whether this is a known {@link TransportVersion}, i.e. one declared in {@link TransportVersions}. Other versions may exist
* in the wild (they're sent over the wire by numeric ID) but we don't know how to communicate using such versions.
*/
public boolean isKnown() {
return VersionsHolder.ALL_VERSIONS_MAP.containsKey(id);
}

public static TransportVersion fromString(String str) {
return TransportVersion.fromId(Integer.parseInt(str));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@

import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.UpdateForV9;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.EOFException;
Expand Down Expand Up @@ -126,6 +128,8 @@ final class TransportHandshaker {
* [3] Parent task ID should be empty; see org.elasticsearch.tasks.TaskId.writeTo for its structure.
*/

private static final Logger logger = LogManager.getLogger(TransportHandshaker.class);

static final TransportVersion V8_HANDSHAKE_VERSION = TransportVersion.fromId(7_17_00_99);
static final TransportVersion V9_HANDSHAKE_VERSION = TransportVersion.fromId(8_800_00_0);
static final Set<TransportVersion> ALLOWED_HANDSHAKE_VERSIONS = Set.of(V8_HANDSHAKE_VERSION, V9_HANDSHAKE_VERSION);
Expand Down Expand Up @@ -159,7 +163,7 @@ void sendHandshake(
ActionListener<TransportVersion> listener
) {
numHandshakes.inc();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, listener);
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, channel, listener);
pendingHandshakes.put(requestId, handler);
channel.addCloseListener(
ActionListener.running(() -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))
Expand All @@ -185,9 +189,9 @@ void sendHandshake(
}

void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
final HandshakeRequest handshakeRequest;
try {
// Must read the handshake request to exhaust the stream
new HandshakeRequest(stream);
handshakeRequest = new HandshakeRequest(stream);
} catch (Exception e) {
assert ignoreDeserializationErrors : e;
throw e;
Expand All @@ -206,9 +210,44 @@ void handleHandshake(TransportChannel channel, long requestId, StreamInput strea
assert ignoreDeserializationErrors : exception;
throw exception;
}
ensureCompatibleVersion(version, handshakeRequest.transportVersion, handshakeRequest.releaseVersion, channel);
channel.sendResponse(new HandshakeResponse(this.version, Build.current().version()));
}

static void ensureCompatibleVersion(
TransportVersion localTransportVersion,
TransportVersion remoteTransportVersion,
String releaseVersion,
Object channel
) {
if (TransportVersion.isCompatible(remoteTransportVersion)) {
if (remoteTransportVersion.onOrAfter(localTransportVersion)) {
// Remote is newer than us, so we will be using our transport protocol and it's up to the other end to decide whether it
// knows how to do that.
return;
}
if (remoteTransportVersion.isKnown()) {
// Remote is older than us, so we will be using its transport protocol, which we can only do if and only if its protocol
// version is known to us.
return;
}
}

final var message = Strings.format(
"""
Rejecting unreadable transport handshake from remote node with version [%s/%s] received on [%s] since this node has \
version [%s/%s] which has an incompatible wire format.""",
releaseVersion,
remoteTransportVersion,
channel,
Build.current().version(),
localTransportVersion
);
logger.warn(message);
throw new IllegalStateException(message);

}

TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
return pendingHandshakes.remove(requestId);
}
Expand All @@ -224,11 +263,13 @@ long getNumHandshakes() {
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {

private final long requestId;
private final TcpChannel channel;
private final ActionListener<TransportVersion> listener;
private final AtomicBoolean isDone = new AtomicBoolean(false);

private HandshakeResponseHandler(long requestId, ActionListener<TransportVersion> listener) {
private HandshakeResponseHandler(long requestId, TcpChannel channel, ActionListener<TransportVersion> listener) {
this.requestId = requestId;
this.channel = channel;
this.listener = listener;
}

Expand All @@ -245,20 +286,13 @@ public Executor executor() {
@Override
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
TransportVersion responseVersion = response.transportVersion;
if (TransportVersion.isCompatible(responseVersion) == false) {
listener.onFailure(
new IllegalStateException(
"Received message from unsupported version: ["
+ responseVersion
+ "] minimal compatible version is: ["
+ TransportVersions.MINIMUM_COMPATIBLE
+ "]"
)
);
} else {
listener.onResponse(TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion()));
}
ActionListener.completeWith(listener, () -> {
ensureCompatibleVersion(version, response.getTransportVersion(), response.getReleaseVersion(), channel);
final var resultVersion = TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion());
assert TransportVersion.current().before(version) // simulating a newer-version transport service for test purposes
|| resultVersion.isKnown() : "negotiated unknown version " + resultVersion;
return resultVersion;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ public void testLogsSlowInboundProcessing() throws Exception {
);
BytesStreamOutput byteData = new BytesStreamOutput();
TaskId.EMPTY_TASK_ID.writeTo(byteData);
// simulate bytes of a transport handshake: vInt transport version then release version string
TransportVersion.writeVersion(remoteVersion, byteData);
byteData.writeString(randomIdentifier());
final InboundMessage requestMessage = new InboundMessage(
requestHeader,
ReleasableBytesReference.wrap(byteData.bytes()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
*/
package org.elasticsearch.transport;

import org.apache.logging.log4j.Level;
import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
Expand All @@ -19,13 +23,17 @@
import org.elasticsearch.core.UpdateForV10;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.threadpool.TestThreadPool;

import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -93,6 +101,40 @@ public void testHandshakeRequestAndResponse() throws IOException {
assertEquals(TransportVersion.current(), versionFuture.actionGet());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testIncompatibleHandshakeRequest() throws IOException {
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(
getRandomIncompatibleTransportVersion(),
randomIdentifier()
);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
bytesStreamOutput.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput();
input.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
final TestTransportChannel channel = new TestTransportChannel(ActionListener.running(() -> fail("should not complete")));

MockLog.assertThatLogger(
() -> assertThat(
expectThrows(IllegalStateException.class, () -> handshaker.handleHandshake(channel, randomNonNegativeLong(), input))
.getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString("[" + handshakeRequest.releaseVersion + "/" + handshakeRequest.transportVersion + "]"),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);
}

public void testHandshakeResponseFromOlderNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand All @@ -108,6 +150,54 @@ public void testHandshakeResponseFromOlderNode() throws Exception {
assertEquals(remoteVersion, versionFuture.result());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testHandshakeResponseFromOlderNodeWithPatchedProtocol() {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
handshaker.sendHandshake(reqId, node, channel, SAFE_AWAIT_TIMEOUT, versionFuture);
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);

assertFalse(versionFuture.isDone());

final var handshakeResponse = new TransportHandshaker.HandshakeResponse(
getRandomIncompatibleTransportVersion(),
randomIdentifier()
);

MockLog.assertThatLogger(
() -> handler.handleResponse(handshakeResponse),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);

assertTrue(versionFuture.isDone());
assertThat(
expectThrows(ExecutionException.class, IllegalStateException.class, versionFuture::result).getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString("[" + handshakeResponse.getReleaseVersion() + "/" + handshakeResponse.getTransportVersion() + "]"),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
);
}

private static TransportVersion getRandomIncompatibleTransportVersion() {
return randomBoolean()
// either older than MINIMUM_COMPATIBLE
? new TransportVersion(between(1, TransportVersions.MINIMUM_COMPATIBLE.id() - 1))
// or between MINIMUM_COMPATIBLE and current but not known
: randomValueOtherThanMany(
TransportVersion::isKnown,
() -> new TransportVersion(between(TransportVersions.MINIMUM_COMPATIBLE.id(), TransportVersion.current().id()))
);
}

public void testHandshakeResponseFromNewerNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand Down