Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/135873.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135873
summary: Convert `BytesTransportResponse` when proxying response from/to local node
area: "Network"
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,13 @@ static void registerNodeSearchAction(
}
}
);
TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
NODE_SEARCH_ACTION_NAME,
true,
NodeQueryResponse::new,
namedWriteableRegistry
);
}

private static void releaseLocalContext(
Expand Down Expand Up @@ -845,7 +851,10 @@ void onShardDone() {
out.close();
}
}
ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(out.moveToBytesReference()));
ActionListener.respondAndRelease(
channelListener,
new BytesTransportResponse(out.moveToBytesReference(), out.getTransportVersion())
);
}

private void maybeFreeContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -384,7 +385,11 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

public static void registerRequestHandler(TransportService transportService, SearchService searchService) {
public static void registerRequestHandler(
TransportService transportService,
SearchService searchService,
NamedWriteableRegistry namedWriteableRegistry
) {
final TransportRequestHandler<ScrollFreeContextRequest> freeContextHandler = (request, channel, task) -> {
logger.trace("releasing search context [{}]", request.id());
boolean freed = searchService.freeReaderContext(request.id());
Expand All @@ -401,7 +406,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
FREE_CONTEXT_SCROLL_ACTION_NAME,
false,
SearchFreeContextResponse::readFrom
SearchFreeContextResponse::readFrom,
namedWriteableRegistry
);

// TODO: remove this handler once the lowest compatible version stops using it
Expand All @@ -411,7 +417,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
OriginalIndices.readOriginalIndices(in);
return res;
}, freeContextHandler);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::readFrom);
TransportActionProxy.registerProxyAction(
transportService,
FREE_CONTEXT_ACTION_NAME,
false,
SearchFreeContextResponse::readFrom,
namedWriteableRegistry
);

transportService.registerRequestHandler(
CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
Expand All @@ -426,7 +438,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
false,
(in) -> ActionResponse.Empty.INSTANCE
(in) -> ActionResponse.Empty.INSTANCE,
namedWriteableRegistry
);

transportService.registerRequestHandler(
Expand All @@ -435,7 +448,7 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardSearchRequest::new,
(request, channel, task) -> searchService.executeDfsPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel))
);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new, namedWriteableRegistry);

transportService.registerRequestHandler(
QUERY_ACTION_NAME,
Expand All @@ -451,7 +464,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
QUERY_ACTION_NAME,
true,
(request) -> ((ShardSearchRequest) request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new
(request) -> ((ShardSearchRequest) request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
Expand All @@ -465,7 +479,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_ID_ACTION_NAME,
true,
QuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_SCROLL_ACTION_NAME,
Expand All @@ -478,7 +498,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_SCROLL_ACTION_NAME,
true,
ScrollQuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_FETCH_SCROLL_ACTION_NAME,
Expand All @@ -490,7 +516,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
new ChannelActionListener<>(channel)
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_FETCH_SCROLL_ACTION_NAME,
true,
ScrollQueryFetchSearchResult::new,
namedWriteableRegistry
);

final TransportRequestHandler<RankFeatureShardRequest> rankShardFeatureRequest = (request, channel, task) -> searchService
.executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
Expand All @@ -500,7 +532,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
RankFeatureShardRequest::new,
rankShardFeatureRequest
);
TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new);
TransportActionProxy.registerProxyAction(
transportService,
RANK_FEATURE_SHARD_ACTION_NAME,
true,
RankFeatureResult::new,
namedWriteableRegistry
);

final TransportRequestHandler<ShardFetchRequest> shardFetchRequestHandler = (request, channel, task) -> searchService
.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
Expand All @@ -510,7 +548,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardFetchRequest::new,
shardFetchRequestHandler
);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
FETCH_ID_SCROLL_ACTION_NAME,
true,
FetchSearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
FETCH_ID_ACTION_NAME,
Expand All @@ -520,15 +564,27 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardFetchSearchRequest::new,
shardFetchRequestHandler
);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
FETCH_ID_ACTION_NAME,
true,
FetchSearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_CAN_MATCH_NODE_NAME,
transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION),
CanMatchNodeRequest::new,
(request, channel, task) -> searchService.canMatch(request, new ChannelActionListener<>(channel))
);
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NODE_NAME, true, CanMatchNodeResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_CAN_MATCH_NODE_NAME,
true,
CanMatchNodeResponse::new,
namedWriteableRegistry
);
}

private static Executor buildFreeContextExecutor(TransportService transportService) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ public TransportOpenPointInTimeAction(
ShardOpenReaderRequest::new,
new ShardOpenReaderRequestHandler()
);
TransportActionProxy.registerProxyAction(transportService, OPEN_SHARD_READER_CONTEXT_NAME, false, ShardOpenReaderResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
OPEN_SHARD_READER_CONTEXT_NAME,
false,
ShardOpenReaderResponse::new,
namedWriteableRegistry
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public TransportSearchAction(
this.searchPhaseController = searchPhaseController;
this.searchTransportService = searchTransportService;
this.remoteClusterService = searchTransportService.getRemoteClusterService();
SearchTransportService.registerRequestHandler(transportService, searchService);
SearchTransportService.registerRequestHandler(transportService, searchService, namedWriteableRegistry);
SearchQueryThenFetchAsyncAction.registerNodeSearchAction(
searchTransportService,
searchService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,41 @@

package org.elasticsearch.transport;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Objects;

/**
* A specialized, bytes only response, that can potentially be optimized on the network layer.
*/
public class BytesTransportResponse extends TransportResponse implements BytesTransportMessage {

private final ReleasableBytesReference bytes;
private final TransportVersion version;

public BytesTransportResponse(ReleasableBytesReference bytes) {
public BytesTransportResponse(ReleasableBytesReference bytes, TransportVersion version) {
this.bytes = bytes;
this.version = Objects.requireNonNull(version);
}

/**
* Does the binary response need conversion before being sent to the provided target version?
*/
public boolean mustConvertResponseForVersion(TransportVersion targetVersion) {
return version.equals(targetVersion) == false;
}

/**
* Returns a {@link StreamInput} configured to read the underlying bytes that this response holds.
*/
public StreamInput streamInput() throws IOException {
StreamInput streamInput = bytes.streamInput();
streamInput.setTransportVersion(version);
return streamInput;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.elasticsearch.transport;

import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -18,6 +20,7 @@
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Function;
Expand All @@ -36,15 +39,18 @@ private static class ProxyRequestHandler<T extends ProxyRequest<TransportRequest
private final TransportService service;
private final String action;
private final Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction;
private final NamedWriteableRegistry namedWriteableRegistry;

ProxyRequestHandler(
TransportService service,
String action,
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction,
NamedWriteableRegistry namedWriteableRegistry
) {
this.service = service;
this.action = action;
this.responseFunction = responseFunction;
this.namedWriteableRegistry = namedWriteableRegistry;
}

@Override
Expand All @@ -62,6 +68,20 @@ public Executor executor() {

@Override
public void handleResponse(TransportResponse response) {
// This is a short term solution to ensure data node responses for batched search go back to the coordinating
// node in the expected format when a proxy data node proxies the request to itself. The response would otherwise
// be sent directly via DirectResponseChannel, skipping the read and write step that this handler normally performs.
if (response instanceof BytesTransportResponse btr && btr.mustConvertResponseForVersion(channel.getVersion())) {
try {
NamedWriteableAwareStreamInput in = new NamedWriteableAwareStreamInput(
btr.streamInput(),
namedWriteableRegistry
);
response = responseFunction.apply(wrappedRequest).read(in);
} catch (IOException e) {
throw new UncheckedIOException(e);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I throw or send the exception as a response? I am assuming that exceptions thrown are already caught and sent as responses, but good to double check.

}
}
channel.sendResponse(response);
}

Expand All @@ -73,7 +93,7 @@ public void handleException(TransportException exp) {
@Override
public TransportResponse read(StreamInput in) throws IOException {
if (in.getTransportVersion().equals(channel.getVersion()) && in.supportReadAllToReleasableBytesReference()) {
return new BytesTransportResponse(in.readAllToReleasableBytesReference());
return new BytesTransportResponse(in.readAllToReleasableBytesReference(), channel.getVersion());
} else {
return responseFunction.apply(wrappedRequest).read(in);
}
Expand Down Expand Up @@ -144,7 +164,9 @@ public static void registerProxyActionWithDynamicResponseType(
TransportService service,
String action,
boolean cancellable,
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction,
NamedWriteableRegistry namedWriteableRegistry

) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(
Expand All @@ -155,7 +177,7 @@ public static void registerProxyActionWithDynamicResponseType(
in -> cancellable
? new CancellableProxyRequest<>(in, requestHandler::newRequest)
: new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, responseFunction)
new ProxyRequestHandler<>(service, action, responseFunction, namedWriteableRegistry)
);
}

Expand All @@ -167,9 +189,10 @@ public static void registerProxyAction(
TransportService service,
String action,
boolean cancellable,
Writeable.Reader<? extends TransportResponse> reader
Writeable.Reader<? extends TransportResponse> reader,
NamedWriteableRegistry namedWriteableRegistry
) {
registerProxyActionWithDynamicResponseType(service, action, cancellable, request -> reader);
registerProxyActionWithDynamicResponseType(service, action, cancellable, request -> reader, namedWriteableRegistry);
}

private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
Expand Down
Loading