diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
index 8351e2bcf7f42..83767a1f52521 100644
--- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
+++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
@@ -66,7 +66,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase {
     protected static final float DEFAULT_INDEX_BOOST = 1.0f;
     private final Logger logger;
-    private final NamedWriteableRegistry namedWriteableRegistry;
+    protected final NamedWriteableRegistry namedWriteableRegistry;
     protected final SearchTransportService searchTransportService;
     private final Executor executor;
     private final ActionListener<SearchResponse> listener;
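Editor's note on the visibility change above: widening `namedWriteableRegistry` to `protected` is what lets the batched-query response handler in `SearchQueryThenFetchAsyncAction` (next file) decode raw byte payloads itself. A minimal sketch of that decode pattern, using only calls that appear in this diff; the helper itself is hypothetical and not part of the PR:

```java
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;

import java.io.IOException;

class BatchedResponseDecoding {
    // Hypothetical helper: decompress a raw response payload, then wrap it so named
    // writeables (e.g. aggregations inside a QuerySearchResult) can be resolved
    // against the registry the subclass now inherits.
    static StreamInput wrap(BytesReference bytes, NamedWriteableRegistry registry) throws IOException {
        StreamInput decompressed = CompressorFactory.COMPRESSOR.threadLocalStreamInput(bytes.streamInput());
        return new NamedWriteableAwareStreamInput(decompressed, registry);
    }
}
```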
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java
index 0db16c2960dd7..665da77a3c50b 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java
@@ -23,7 +23,11 @@
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;
+import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -31,8 +35,8 @@
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
-import org.elasticsearch.core.RefCounted;
-import org.elasticsearch.core.SimpleRefCounted;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Streams;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchPhaseResult;
@@ -51,7 +55,7 @@
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.AbstractTransportRequest;
-import org.elasticsearch.transport.LeakTracker;
+import org.elasticsearch.transport.BytesTransportResponse;
 import org.elasticsearch.transport.SendRequestTransportException;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.TransportActionProxy;
@@ -59,6 +63,7 @@
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportResponseHandler;
+import org.elasticsearch.transport.TransportService;

 import java.io.IOException;
 import java.util.ArrayList;
@@ -198,91 +203,6 @@ protected SearchPhase getNextPhase() {
         return nextPhase(client, this, results, null);
     }

-    /**
-     * Response to a query phase request, holding per-shard results that have been partially reduced as well as
-     * the partial reduce result.
-     */
-    public static final class NodeQueryResponse extends TransportResponse {
-
-        private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted());
-
-        private final Object[] results;
-        private final SearchPhaseController.TopDocsStats topDocsStats;
-        private final QueryPhaseResultConsumer.MergeResult mergeResult;
-
-        NodeQueryResponse(StreamInput in) throws IOException {
-            this.results = in.readArray(i -> i.readBoolean() ? new QuerySearchResult(i) : i.readException(), Object[]::new);
-            this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in);
-            this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in);
-        }
-
-        NodeQueryResponse(
-            QueryPhaseResultConsumer.MergeResult mergeResult,
-            Object[] results,
-            SearchPhaseController.TopDocsStats topDocsStats
-        ) {
-            this.results = results;
-            for (Object result : results) {
-                if (result instanceof QuerySearchResult r) {
-                    r.incRef();
-                }
-            }
-            this.mergeResult = mergeResult;
-            this.topDocsStats = topDocsStats;
-            assert Arrays.stream(results).noneMatch(Objects::isNull) : Arrays.toString(results);
-        }
-
-        // public for tests
-        public Object[] getResults() {
-            return results;
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeArray((o, v) -> {
-                if (v instanceof Exception e) {
-                    o.writeBoolean(false);
-                    o.writeException(e);
-                } else {
-                    o.writeBoolean(true);
-                    assert v instanceof QuerySearchResult : v;
-                    ((QuerySearchResult) v).writeTo(o);
-                }
-            }, results);
-            mergeResult.writeTo(out);
-            topDocsStats.writeTo(out);
-        }
-
-        @Override
-        public void incRef() {
-            refCounted.incRef();
-        }
-
-        @Override
-        public boolean tryIncRef() {
-            return refCounted.tryIncRef();
-        }
-
-        @Override
-        public boolean hasReferences() {
-            return refCounted.hasReferences();
-        }
-
-        @Override
-        public boolean decRef() {
-            if (refCounted.decRef()) {
-                for (int i = 0; i < results.length; i++) {
-                    if (results[i] instanceof QuerySearchResult r) {
-                        r.decRef();
-                    }
-                    results[i] = null;
-                }
-                return true;
-            }
-            return false;
-        }
-    }
-
     /**
      * Request for starting the query phase for multiple shards.
      */
@@ -465,60 +385,82 @@ protected void doRun(Map<SearchShardIterator, Integer> shardIndexMap) {
                 return;
             }
             searchTransportService.transportService()
-                .sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, request, task, new TransportResponseHandler<NodeQueryResponse>() {
-                    @Override
-                    public NodeQueryResponse read(StreamInput in) throws IOException {
-                        return new NodeQueryResponse(in);
-                    }
-
-                    @Override
-                    public Executor executor() {
-                        return EsExecutors.DIRECT_EXECUTOR_SERVICE;
-                    }
-
-                    @Override
-                    public void handleResponse(NodeQueryResponse response) {
-                        if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) {
-                            queryPhaseResultConsumer.addBatchedPartialResult(response.topDocsStats, response.mergeResult);
-                        }
-                        for (int i = 0; i < response.results.length; i++) {
-                            var s = request.shards.get(i);
-                            int shardIdx = s.shardIndex;
-                            final SearchShardTarget target = new SearchShardTarget(routing.nodeId(), s.shardId, routing.clusterAlias());
-                            switch (response.results[i]) {
-                                case Exception e -> onShardFailure(shardIdx, target, shardIterators[shardIdx], e);
-                                case SearchPhaseResult q -> {
-                                    q.setShardIndex(shardIdx);
-                                    q.setSearchShardTarget(target);
-                                    onShardResult(q);
-                                }
-                                case null, default -> {
-                                    assert false : "impossible [" + response.results[i] + "]";
-                                }
-                            }
-                        }
-                    }
-
-                    @Override
-                    public void handleException(TransportException e) {
-                        Exception cause = (Exception) ExceptionsHelper.unwrapCause(e);
-                        if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) {
-                            // two possible special cases here where we do not want to fail the phase:
-                            // failure to send out the request -> handle things the same way a shard would fail with unbatched execution
-                            // as this could be a transient failure and partial results we may have are still valid
-                            // cancellation of the whole batched request on the remote -> maybe we timed out or so, partial results may
-                            // still be valid
-                            onNodeQueryFailure(e, request, routing);
-                        } else {
-                            // Remote failure that wasn't due to networking or cancellation means that the data node was unable to reduce
-                            // its local results. Failure to reduce always fails the phase without exception so we fail the phase here.
-                            if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) {
-                                queryPhaseResultConsumer.failure.compareAndSet(null, cause);
-                            }
-                            onPhaseFailure(getName(), "", cause);
-                        }
-                    }
-                });
+                .sendChildRequest(
+                    connection,
+                    NODE_SEARCH_ACTION_NAME,
+                    request,
+                    task,
+                    new TransportResponseHandler<BytesTransportResponse>() {
+                        @Override
+                        public BytesTransportResponse read(StreamInput in) throws IOException {
+                            return new BytesTransportResponse(in);
+                        }
+
+                        @Override
+                        public Executor executor() {
+                            return EsExecutors.DIRECT_EXECUTOR_SERVICE;
+                        }
+
+                        @Override
+                        public void handleResponse(BytesTransportResponse bytesTransportResponse) {
+                            try (
+                                var decompressedIn = CompressorFactory.COMPRESSOR.threadLocalStreamInput(
+                                    bytesTransportResponse.bytes().streamInput()
+                                )
+                            ) {
+                                var in = new NamedWriteableAwareStreamInput(decompressedIn, namedWriteableRegistry);
+                                in.setTransportVersion(bytesTransportResponse.version());
+                                var mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in);
+                                var topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in);
+                                if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) {
+                                    queryPhaseResultConsumer.addBatchedPartialResult(topDocsStats, mergeResult);
+                                }
+                                final int shardCount = request.shards.size();
+                                for (int i = 0; i < shardCount; i++) {
+                                    var s = request.shards.get(i);
+                                    int shardIdx = s.shardIndex;
+                                    final SearchShardTarget target = new SearchShardTarget(
+                                        routing.nodeId(),
+                                        s.shardId,
+                                        routing.clusterAlias()
+                                    );
+                                    if (in.readBoolean()) {
+                                        var q = new QuerySearchResult(in);
+                                        q.setShardIndex(shardIdx);
+                                        q.setSearchShardTarget(target);
+                                        onShardResult(q);
+                                    } else {
+                                        onShardFailure(shardIdx, target, shardIterators[shardIdx], in.readException());
+                                    }
+                                }
+                            } catch (IOException e) {
+                                assert false : new AssertionError("No real IO here, this is a serialization bug", e);
+                                handleException(new TransportException(e));
+                            }
+                        }
+
+                        @Override
+                        public void handleException(TransportException e) {
+                            Exception cause = (Exception) ExceptionsHelper.unwrapCause(e);
+                            if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) {
+                                // two possible special cases here where we do not want to fail the phase:
+                                // failure to send out the request -> handle things the same way a shard would fail with unbatched
+                                // execution as this could be a transient failure and partial results we may have are still valid
+                                // cancellation of the whole batched request on the remote -> maybe we timed out or so, partial
+                                // results may still be valid
+                                onNodeQueryFailure(e, request, routing);
+                            } else {
+                                // Remote failure that wasn't due to networking or cancellation means that the data node was unable
+                                // to reduce its local results. Failure to reduce always fails the phase without exception so we
+                                // fail the phase here.
+                                if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) {
+                                    queryPhaseResultConsumer.failure.compareAndSet(null, cause);
+                                }
+                                onPhaseFailure(getName(), "", cause);
+                            }
+                        }
+                    }
+                );
         });
     }
@@ -553,7 +495,7 @@ static void registerNodeSearchAction(
     ) {
         var transportService = searchTransportService.transportService();
         var threadPool = transportService.getThreadPool();
-        final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH));
+        final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH), transportService);
         // Even though not all searches run on the search pool, we use the search pool size as the upper limit of shards to execute in
         // parallel to keep the implementation simple instead of working out the exact pool(s) a query will use up-front.
         final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax();
@@ -587,7 +529,7 @@ static void registerNodeSearchAction(
                 }
             }
         );
-        TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new);
+        TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, BytesTransportResponse::new);
     }

     private static void releaseLocalContext(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result) {
@@ -716,7 +658,7 @@ public void onFailure(Exception e) {
         }
     }

-    private record Dependencies(SearchService searchService, Executor executor) {}
+    private record Dependencies(SearchService searchService, Executor executor, TransportService transportService) {}

     private static final class QueryPerNodeState {
@@ -762,58 +704,73 @@ void onShardDone() {
                 return;
             }
             var channelListener = new ChannelActionListener<>(channel);
+            var out = dependencies.transportService.newNetworkBytesStream();
             try (queryPhaseResultConsumer) {
                 var failure = queryPhaseResultConsumer.failure.get();
                 if (failure != null) {
                     handleMergeFailure(failure, channelListener);
                     return;
                 }
-                final QueryPhaseResultConsumer.MergeResult mergeResult;
-                try {
-                    mergeResult = Objects.requireNonNullElse(
-                        queryPhaseResultConsumer.consumePartialMergeResultDataNode(),
-                        EMPTY_PARTIAL_MERGE_RESULT
-                    );
-                } catch (Exception e) {
-                    handleMergeFailure(e, channelListener);
-                    return;
-                }
-                // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments,
-                // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other
-                // indices without a roundtrip to the coordinating node
-                final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size());
-                for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
-                    final int localIndex = scoreDoc.shardIndex;
-                    scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
-                    relevantShardIndices.set(localIndex);
-                }
-                final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
-                for (int i = 0; i < results.length; i++) {
-                    var result = queryPhaseResultConsumer.results.get(i);
-                    if (result == null) {
-                        results[i] = failures.get(i);
-                    } else {
-                        // free context id and remove it from the result right away in case we don't need it anymore
-                        if (result instanceof QuerySearchResult q
-                            && q.getContextId() != null
-                            && relevantShardIndices.get(q.getShardIndex()) == false
-                            && q.hasSuggestHits() == false
-                            && q.getRankShardResult() == null
-                            && searchRequest.searchRequest.scroll() == null
-                            && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) {
-                            if (dependencies.searchService.freeReaderContext(q.getContextId())) {
-                                q.clearContextId();
-                            }
-                        }
-                        results[i] = result;
-                    }
-                    assert results[i] != null;
-                }
-                ActionListener.respondAndRelease(
-                    channelListener,
-                    new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats)
-                );
+                try (
+                    var compressedOut = new OutputStreamStreamOutput(
+                        CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.noCloseStream(out))
+                    )
+                ) {
+                    final QueryPhaseResultConsumer.MergeResult mergeResult;
+                    try {
+                        mergeResult = Objects.requireNonNullElse(
+                            queryPhaseResultConsumer.consumePartialMergeResultDataNode(),
+                            EMPTY_PARTIAL_MERGE_RESULT
+                        );
+                    } catch (Exception e) {
+                        handleMergeFailure(e, channelListener);
+                        return;
+                    }
+                    // translate shard indices to those on the coordinator so that it can interpret the merge result without
+                    // adjustments, also collect the set of indices that may be part of a subsequent fetch operation here so
+                    // that we can release all other indices without a roundtrip to the coordinating node
+                    final int shardCount = searchRequest.shards.size();
+                    final BitSet relevantShardIndices = new BitSet(shardCount);
+                    for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
+                        final int localIndex = scoreDoc.shardIndex;
+                        scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
+                        relevantShardIndices.set(localIndex);
+                    }
+                    mergeResult.writeTo(compressedOut);
+                    queryPhaseResultConsumer.topDocsStats.writeTo(compressedOut);
+                    for (int i = 0; i < shardCount; i++) {
+                        var result = queryPhaseResultConsumer.results.get(i);
+                        if (result == null) {
+                            compressedOut.writeBoolean(false);
+                            compressedOut.writeException(failures.get(i));
+                        } else {
+                            // free context id and remove it from the result right away in case we don't need it anymore
+                            if (result instanceof QuerySearchResult q
+                                && q.getContextId() != null
+                                && relevantShardIndices.get(q.getShardIndex()) == false
+                                && q.hasSuggestHits() == false
+                                && q.getRankShardResult() == null
+                                && searchRequest.searchRequest.scroll() == null
+                                && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) {
+                                if (dependencies.searchService.freeReaderContext(q.getContextId())) {
+                                    q.clearContextId();
+                                }
+                            }
+                            compressedOut.writeBoolean(true);
+                            result.writeTo(compressedOut);
+                        }
+                    }
+                }
+                var response = new BytesTransportResponse(new ReleasableBytesReference(out.bytes(), out), channel.getVersion());
+                out = null;
+                ActionListener.respondAndRelease(channelListener, response);
+            } catch (IOException e) {
+                assert false : e;
+                channelListener.onFailure(e);
+            } finally {
+                Releasables.close(out);
             }
         }
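Editor's note: the `onShardDone` writer above and the coordinator-side `handleResponse` agree on a simple wire layout inside one compressed stream — the partial `MergeResult`, then `TopDocsStats`, then one boolean-prefixed entry per shard (result or exception). The subtle part is ownership of the recycled network buffer. A condensed sketch of just that dance, under the assumption that `newNetworkBytesStream()` returns a releasable `RecyclerBytesStreamOutput`; the payload write is a stand-in:

```java
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Streams;
import org.elasticsearch.transport.BytesTransportResponse;

import java.io.IOException;

class ResponseOwnershipSketch {
    // Condensed from onShardDone: the recycled stream backs the response bytes, so it
    // must only be released on the failure path, never after the response takes over.
    static BytesTransportResponse serialize(RecyclerBytesStreamOutput out, TransportVersion version) throws IOException {
        try {
            try (
                var compressedOut = new OutputStreamStreamOutput(
                    CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.noCloseStream(out))
                )
            ) {
                compressedOut.writeBoolean(true); // stand-in for merge result, stats and per-shard entries
            }
            var response = new BytesTransportResponse(new ReleasableBytesReference(out.bytes(), out), version);
            out = null; // ownership transferred: the finally block below becomes a no-op
            return response;
        } finally {
            Releasables.close(out);
        }
    }
}
```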
diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java
index 9b32342897217..b59e79c0a10c3 100644
--- a/server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java
+++ b/server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java
@@ -10,6 +10,7 @@

 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;

 import java.io.EOFException;
 import java.io.IOException;
@@ -289,4 +290,16 @@ public boolean markSupported() {

     @Override
     public void close() throws IOException {}
+
+    @Override
+    public boolean supportReadAllToReleasableBytesReference() {
+        return true;
+    }
+
+    @Override
+    public ReleasableBytesReference readAllToReleasableBytesReference() {
+        final byte[] res = new byte[buffer.remaining()];
+        buffer.get(res);
+        return ReleasableBytesReference.wrap(new BytesArray(res));
+    }
 }
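Editor's note: this override gives `BytesTransportResponse` (added later in this diff) a cheap way to take the rest of an inbound buffer in one bulk `get`; presumably the base `StreamInput` declares these methods with a slower default that is not shown here. A small usage sketch with a hypothetical helper:

```java
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.ByteBufferStreamInput;

import java.io.IOException;
import java.nio.ByteBuffer;

class ReadAllSketch {
    // Slurp whatever remains in a buffer-backed stream into a ReleasableBytesReference,
    // the shape the BytesTransportResponse(StreamInput) constructor relies on.
    static ReleasableBytesReference remaining(ByteBuffer buffer) throws IOException {
        var in = new ByteBufferStreamInput(buffer);
        assert in.supportReadAllToReleasableBytesReference();
        return in.readAllToReleasableBytesReference();
    }
}
```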
diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java
new file mode 100644
index 0000000000000..f7356aa5b2658
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.transport;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;
+import org.elasticsearch.common.io.stream.StreamOutput;
+
+import java.io.IOException;
+
+public interface BytesTransportMessage {
+
+    TransportVersion version();
+
+    ReleasableBytesReference bytes();
+
+    /**
+     * Writes the data in a "thin" manner, without the actual bytes, assumes
+     * the actual bytes will be appended right after this content.
+     */
+    void writeThin(StreamOutput out) throws IOException;
+}
diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java
index a16411cb25f1d..475d1e4f2d2bd 100644
--- a/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java
+++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java
@@ -10,7 +10,6 @@
 package org.elasticsearch.transport;

 import org.elasticsearch.TransportVersion;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -21,7 +20,7 @@
  * A specialized, bytes only request, that can potentially be optimized on the network
  * layer, specifically for the same large buffer send to several nodes.
  */
-public class BytesTransportRequest extends AbstractTransportRequest {
+public class BytesTransportRequest extends AbstractTransportRequest implements BytesTransportMessage {

     final ReleasableBytesReference bytes;
     private final TransportVersion version;
@@ -37,18 +36,17 @@ public BytesTransportRequest(ReleasableBytesReference bytes, TransportVersion ve
         this.version = version;
     }

+    @Override
     public TransportVersion version() {
         return this.version;
     }

-    public BytesReference bytes() {
+    @Override
+    public ReleasableBytesReference bytes() {
         return this.bytes;
     }

-    /**
-     * Writes the data in a "thin" manner, without the actual bytes, assumes
-     * the actual bytes will be appended right after this content.
-     */
+    @Override
     public void writeThin(StreamOutput out) throws IOException {
         super.writeTo(out);
         out.writeVInt(bytes.length());
diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java
new file mode 100644
index 0000000000000..876bccc9071fb
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.transport;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+
+import java.io.IOException;
+
+/**
+ * A specialized, bytes only response, that can potentially be optimized on the network layer.
+ */
+public class BytesTransportResponse extends TransportResponse implements BytesTransportMessage {
+
+    private final ReleasableBytesReference bytes;
+    private final TransportVersion version;
+
+    public BytesTransportResponse(StreamInput in) throws IOException {
+        bytes = in.readAllToReleasableBytesReference();
+        version = in.getTransportVersion();
+    }
+
+    public BytesTransportResponse(ReleasableBytesReference bytes, TransportVersion version) {
+        this.bytes = bytes;
+        this.version = version;
+    }
+
+    @Override
+    public TransportVersion version() {
+        return this.version;
+    }
+
+    @Override
+    public ReleasableBytesReference bytes() {
+        return this.bytes;
+    }
+
+    @Override
+    public void writeThin(StreamOutput out) throws IOException {}
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        bytes.writeTo(out);
+    }
+
+    @Override
+    public void incRef() {
+        bytes.incRef();
+    }
+
+    @Override
+    public boolean tryIncRef() {
+        return bytes.tryIncRef();
+    }
+
+    @Override
+    public boolean decRef() {
+        return bytes.decRef();
+    }
+
+    @Override
+    public boolean hasReferences() {
+        return bytes.hasReferences();
+    }
+}
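Editor's note: with request and response now behind one `BytesTransportMessage` interface, a receiver can capture a payload without eager deserialization. A usage sketch of the coordinator-side handler shape from `SearchQueryThenFetchAsyncAction` above — `read` just grabs the bytes plus the wire's transport version, and decoding is deferred to `handleResponse`:

```java
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.transport.BytesTransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;

import java.io.IOException;
import java.util.concurrent.Executor;

// Mirrors the handler registered in SearchQueryThenFetchAsyncAction; handleResponse
// and handleException are left to the concrete subclass.
abstract class RawBytesResponseHandler implements TransportResponseHandler<BytesTransportResponse> {
    @Override
    public BytesTransportResponse read(StreamInput in) throws IOException {
        return new BytesTransportResponse(in); // captures remaining bytes and the wire version
    }

    @Override
    public Executor executor() {
        return EsExecutors.DIRECT_EXECUTOR_SERVICE; // deserialization is deferred, so no fork needed
    }
}
```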
diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java
index a83d1019e7c64..c5041c4e75d17 100644
--- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java
+++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java
@@ -227,7 +227,7 @@ private void sendMessage(
         Releasable onAfter
     ) throws IOException {
         assert action != null;
-        final var compressionScheme = writeable instanceof BytesTransportRequest ? null : possibleCompressionScheme;
+        final var compressionScheme = writeable instanceof BytesTransportMessage ? null : possibleCompressionScheme;
         final BytesReference message;
         boolean serializeSuccess = false;
         final RecyclerBytesStreamOutput byteStreamOutput = new RecyclerBytesStreamOutput(recycler);
@@ -334,11 +334,11 @@ private static BytesReference serializeMessageBody(
         final ReleasableBytesReference zeroCopyBuffer;
         try {
             stream.setTransportVersion(version);
-            if (writeable instanceof BytesTransportRequest bRequest) {
+            if (writeable instanceof BytesTransportMessage bRequest) {
                 assert stream == byteStreamOutput;
                 assert compressionScheme == null;
                 bRequest.writeThin(stream);
-                zeroCopyBuffer = bRequest.bytes;
+                zeroCopyBuffer = bRequest.bytes();
             } else if (writeable instanceof RemoteTransportException remoteTransportException) {
                 stream.writeException(remoteTransportException);
                 zeroCopyBuffer = ReleasableBytesReference.empty();
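Editor's note: this `OutboundHandler` change is the payoff of the shared interface — any `BytesTransportMessage` bypasses transport-level compression, since its payload is appended verbatim as the zero-copy buffer and, in the batched-search case above, was already compressed when serialized; compressing again would cost CPU for no size win. The rule, condensed into a sketch (assuming the scheme type hidden behind `var` here is `Compression.Scheme`):

```java
import org.elasticsearch.transport.BytesTransportMessage;
import org.elasticsearch.transport.Compression;

class CompressionRuleSketch {
    // Raw-bytes messages opt out of outbound compression; everything else keeps the
    // scheme negotiated for the channel.
    static Compression.Scheme effectiveScheme(Object writeable, Compression.Scheme negotiated) {
        return writeable instanceof BytesTransportMessage ? null : negotiated;
    }
}
```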