Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ public static List<TransportVersion> getAllVersions() {
return VersionsHolder.ALL_VERSIONS;
}

/**
* @return whether the given {@code id} corresponds to a known {@link TransportVersion}.
*/
public static boolean isKnownVersionId(int id) {
return VersionsHolder.ALL_VERSIONS_MAP.containsKey(id);
}

public static TransportVersion fromString(String str) {
return TransportVersion.fromId(Integer.parseInt(str));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@

import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.UpdateForV9;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.EOFException;
Expand Down Expand Up @@ -126,6 +128,8 @@ final class TransportHandshaker {
* [3] Parent task ID should be empty; see org.elasticsearch.tasks.TaskId.writeTo for its structure.
*/

private static final Logger logger = LogManager.getLogger(TransportHandshaker.class);

static final TransportVersion V8_HANDSHAKE_VERSION = TransportVersion.fromId(7_17_00_99);
static final TransportVersion V9_HANDSHAKE_VERSION = TransportVersion.fromId(8_800_00_0);
static final Set<TransportVersion> ALLOWED_HANDSHAKE_VERSIONS = Set.of(V8_HANDSHAKE_VERSION, V9_HANDSHAKE_VERSION);
Expand Down Expand Up @@ -159,7 +163,7 @@ void sendHandshake(
ActionListener<TransportVersion> listener
) {
numHandshakes.inc();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, listener);
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, channel, listener);
pendingHandshakes.put(requestId, handler);
channel.addCloseListener(
ActionListener.running(() -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))
Expand All @@ -185,9 +189,9 @@ void sendHandshake(
}

void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
final HandshakeRequest handshakeRequest;
try {
// Must read the handshake request to exhaust the stream
new HandshakeRequest(stream);
handshakeRequest = new HandshakeRequest(stream);
} catch (Exception e) {
assert ignoreDeserializationErrors : e;
throw e;
Expand All @@ -206,9 +210,44 @@ void handleHandshake(TransportChannel channel, long requestId, StreamInput strea
assert ignoreDeserializationErrors : exception;
throw exception;
}
ensureCompatibleVersion(version, handshakeRequest.transportVersion, handshakeRequest.releaseVersion, channel);
channel.sendResponse(new HandshakeResponse(this.version, Build.current().version()));
}

static void ensureCompatibleVersion(
TransportVersion localTransportVersion,
TransportVersion remoteTransportVersion,
String releaseVersion,
Object channel
) {
if (TransportVersion.isCompatible(remoteTransportVersion)) {
if (remoteTransportVersion.onOrAfter(localTransportVersion)) {
// Remote is newer than us, so we will be using our transport protocol and it's up to the other end to decide whether it
// knows how to do that.
return;
}
if (TransportVersion.isKnownVersionId(remoteTransportVersion.id())) {
// Remote is older than us, so we will be using its transport protocol, which we can only do if and only if its protocol
// version is known to us.
return;
}
}

final var message = Strings.format(
"""
Rejecting unreadable transport handshake from remote node with version [%s/%s] received on [%s] since this node has \
version [%s/%s] which has an incompatible wire format.""",
releaseVersion,
remoteTransportVersion,
channel,
Build.current().version(),
localTransportVersion
);
logger.warn(message);
throw new IllegalStateException(message);

}

TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
return pendingHandshakes.remove(requestId);
}
Expand All @@ -224,11 +263,13 @@ long getNumHandshakes() {
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {

private final long requestId;
private final TcpChannel channel;
private final ActionListener<TransportVersion> listener;
private final AtomicBoolean isDone = new AtomicBoolean(false);

private HandshakeResponseHandler(long requestId, ActionListener<TransportVersion> listener) {
private HandshakeResponseHandler(long requestId, TcpChannel channel, ActionListener<TransportVersion> listener) {
this.requestId = requestId;
this.channel = channel;
this.listener = listener;
}

Expand All @@ -245,20 +286,13 @@ public Executor executor() {
@Override
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
TransportVersion responseVersion = response.transportVersion;
if (TransportVersion.isCompatible(responseVersion) == false) {
listener.onFailure(
new IllegalStateException(
"Received message from unsupported version: ["
+ responseVersion
+ "] minimal compatible version is: ["
+ TransportVersions.MINIMUM_COMPATIBLE
+ "]"
)
);
} else {
listener.onResponse(TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion()));
}
ActionListener.completeWith(listener, () -> {
ensureCompatibleVersion(version, response.getTransportVersion(), response.getReleaseVersion(), channel);
final var resultVersion = TransportVersion.min(TransportHandshaker.this.version, response.getTransportVersion());
assert TransportVersion.current().before(version) // simulating a newer-version transport service for test purposes
|| TransportVersion.isKnownVersionId(resultVersion.id()) : "negotiated unknown version " + resultVersion;
return resultVersion;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ public void testLogsSlowInboundProcessing() throws Exception {
);
BytesStreamOutput byteData = new BytesStreamOutput();
TaskId.EMPTY_TASK_ID.writeTo(byteData);
// simulate bytes of a transport handshake: vInt transport version then release version string
TransportVersion.writeVersion(remoteVersion, byteData);
byteData.writeString(randomIdentifier());
final InboundMessage requestMessage = new InboundMessage(
requestHeader,
ReleasableBytesReference.wrap(byteData.bytes()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
*/
package org.elasticsearch.transport;

import org.apache.logging.log4j.Level;
import org.elasticsearch.Build;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
Expand All @@ -19,13 +23,17 @@
import org.elasticsearch.core.UpdateForV10;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.threadpool.TestThreadPool;

import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -93,6 +101,40 @@ public void testHandshakeRequestAndResponse() throws IOException {
assertEquals(TransportVersion.current(), versionFuture.actionGet());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testIncompatibleHandshakeRequest() throws IOException {
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(
getRandomIncompatibleTransportVersion(),
randomIdentifier()
);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
bytesStreamOutput.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput();
input.setTransportVersion(HANDSHAKE_REQUEST_VERSION);
final TestTransportChannel channel = new TestTransportChannel(ActionListener.running(() -> fail("should not complete")));

MockLog.assertThatLogger(
() -> assertThat(
expectThrows(IllegalStateException.class, () -> handshaker.handleHandshake(channel, randomNonNegativeLong(), input))
.getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString("[" + handshakeRequest.releaseVersion + "/" + handshakeRequest.transportVersion + "]"),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);
}

public void testHandshakeResponseFromOlderNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand All @@ -108,6 +150,59 @@ public void testHandshakeResponseFromOlderNode() throws Exception {
assertEquals(remoteVersion, versionFuture.result());
}

@TestLogging(reason = "testing WARN logging", value = "org.elasticsearch.transport.TransportHandshaker:WARN")
public void testHandshakeResponseFromOlderNodeWithPatchedProtocol() {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
handshaker.sendHandshake(reqId, node, channel, SAFE_AWAIT_TIMEOUT, versionFuture);
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);

assertFalse(versionFuture.isDone());

final var handshakeResponse = new TransportHandshaker.HandshakeResponse(
randomValueOtherThanMany(
v -> TransportVersion.isKnownVersionId(v.id()),
TransportHandshakerTests::getRandomIncompatibleTransportVersion
),
randomIdentifier()
);

MockLog.assertThatLogger(
() -> handler.handleResponse(handshakeResponse),
TransportHandshaker.class,
new MockLog.SeenEventExpectation(
"warning",
TransportHandshaker.class.getCanonicalName(),
Level.WARN,
"Rejecting unreadable transport handshake * incompatible wire format."
)
);

assertTrue(versionFuture.isDone());
assertThat(
expectThrows(ExecutionException.class, IllegalStateException.class, versionFuture::result).getMessage(),
allOf(
containsString("Rejecting unreadable transport handshake"),
containsString("[" + handshakeResponse.getReleaseVersion() + "/" + handshakeResponse.getTransportVersion() + "]"),
containsString("[" + Build.current().version() + "/" + TransportVersion.current() + "]"),
containsString("which has an incompatible wire format")
)
);
}

private static TransportVersion getRandomIncompatibleTransportVersion() {
return new TransportVersion(
randomBoolean()
// either older than MINIMUM_COMPATIBLE
? between(1, TransportVersions.MINIMUM_COMPATIBLE.id() - 1)
// or between MINIMUM_COMPATIBLE and current but not known
: randomValueOtherThanMany(
TransportVersion::isKnownVersionId,
() -> between(TransportVersions.MINIMUM_COMPATIBLE.id(), TransportVersion.current().id())
)
);
}

public void testHandshakeResponseFromNewerNode() throws Exception {
final PlainActionFuture<TransportVersion> versionFuture = new PlainActionFuture<>();
final long reqId = randomNonNegativeLong();
Expand Down