diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index e552d9c9606c8..39e1c30f658d8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -23,7 +23,9 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -50,6 +52,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.AbstractTransportRequest; +import org.elasticsearch.transport.BytesTransportResponse; import org.elasticsearch.transport.LeakTracker; import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.Transport; @@ -58,6 +61,7 @@ import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.util.ArrayList; @@ -215,22 +219,6 @@ public static final class NodeQueryResponse extends TransportResponse { this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in); } - NodeQueryResponse( - QueryPhaseResultConsumer.MergeResult mergeResult, - Object[] results, - SearchPhaseController.TopDocsStats topDocsStats - ) { - this.results = results; - for (Object result : results) { - if (result instanceof QuerySearchResult r) { - r.incRef(); - } - } - this.mergeResult = mergeResult; - this.topDocsStats = topDocsStats; - assert Arrays.stream(results).noneMatch(Objects::isNull) : Arrays.toString(results); - } - // public for tests public Object[] getResults() { return results; @@ -238,18 +226,15 @@ public Object[] getResults() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeArray((o, v) -> { - if (v instanceof Exception e) { - o.writeBoolean(false); - o.writeException(e); + out.writeVInt(results.length); + for (Object result : results) { + if (result instanceof Exception e) { + writePerShardException(out, e); } else { - o.writeBoolean(true); - assert v instanceof QuerySearchResult : v; - ((QuerySearchResult) v).writeTo(o); + writePerShardResult(out, (QuerySearchResult) result); } - }, results); - mergeResult.writeTo(out); - topDocsStats.writeTo(out); + } + writeMergeResult(out, mergeResult, topDocsStats); } @Override @@ -280,6 +265,25 @@ public boolean decRef() { } return false; } + + private static void writeMergeResult( + StreamOutput out, + QueryPhaseResultConsumer.MergeResult mergeResult, + SearchPhaseController.TopDocsStats topDocsStats + ) throws IOException { + mergeResult.writeTo(out); + topDocsStats.writeTo(out); + } + + private static void writePerShardException(StreamOutput o, Exception e) throws IOException { + o.writeBoolean(false); + o.writeException(e); + } + + private static void writePerShardResult(StreamOutput out, SearchPhaseResult result) throws IOException { + out.writeBoolean(true); + result.writeTo(out); + } } /** @@ -552,7 +556,7 @@ static void registerNodeSearchAction( ) { var transportService = searchTransportService.transportService(); var threadPool = transportService.getThreadPool(); - final Dependencies dependencies = new Dependencies(searchService, threadPool.executor(ThreadPool.Names.SEARCH)); + final Dependencies dependencies = new Dependencies(searchService, transportService, threadPool.executor(ThreadPool.Names.SEARCH)); // Even though not all searches run on the search pool, we use the search pool size as the upper limit of shards to execute in // parallel to keep the implementation simple instead of working out the exact pool(s) a query will use up-front. final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax(); @@ -715,7 +719,7 @@ public void onFailure(Exception e) { } } - private record Dependencies(SearchService searchService, Executor executor) {} + private record Dependencies(SearchService searchService, TransportService transportService, Executor executor) {} private static final class QueryPerNodeState { @@ -760,6 +764,8 @@ void onShardDone() { if (countDown.countDown() == false) { return; } + RecyclerBytesStreamOutput out = null; + boolean success = false; var channelListener = new ChannelActionListener<>(channel); try (queryPhaseResultConsumer) { var failure = queryPhaseResultConsumer.failure.get(); @@ -788,33 +794,46 @@ void onShardDone() { relevantShardIndices.set(localIndex); } } - final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()]; - for (int i = 0; i < results.length; i++) { - var result = queryPhaseResultConsumer.results.get(i); - if (result == null) { - results[i] = failures.get(i); - } else { - // free context id and remove it from the result right away in case we don't need it anymore - if (result instanceof QuerySearchResult q - && q.getContextId() != null - && relevantShardIndices.get(q.getShardIndex()) == false - && q.hasSuggestHits() == false - && q.getRankShardResult() == null - && searchRequest.searchRequest.scroll() == null - && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { - if (dependencies.searchService.freeReaderContext(q.getContextId())) { - q.clearContextId(); - } + final int resultCount = queryPhaseResultConsumer.getNumShards(); + out = dependencies.transportService.newNetworkBytesStream(); + out.setTransportVersion(channel.getVersion()); + try { + out.writeVInt(resultCount); + for (int i = 0; i < resultCount; i++) { + var result = queryPhaseResultConsumer.results.get(i); + if (result == null) { + NodeQueryResponse.writePerShardException(out, failures.remove(i)); + } else { + // free context id and remove it from the result right away in case we don't need it anymore + maybeFreeContext(result, relevantShardIndices); + NodeQueryResponse.writePerShardResult(out, result); } - results[i] = result; } - assert results[i] != null; + NodeQueryResponse.writeMergeResult(out, mergeResult, queryPhaseResultConsumer.topDocsStats); + success = true; + } catch (IOException e) { + handleMergeFailure(e, channelListener); + return; } + } finally { + if (success == false && out != null) { + out.close(); + } + } + ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(new ReleasableBytesReference(out.bytes(), out))); + } - ActionListener.respondAndRelease( - channelListener, - new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats) - ); + private void maybeFreeContext(SearchPhaseResult result, BitSet relevantShardIndices) { + if (result instanceof QuerySearchResult q + && q.getContextId() != null + && relevantShardIndices.get(q.getShardIndex()) == false + && q.hasSuggestHits() == false + && q.getRankShardResult() == null + && searchRequest.searchRequest.scroll() == null + && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { + if (dependencies.searchService.freeReaderContext(q.getContextId())) { + q.clearContextId(); + } } } diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java new file mode 100644 index 0000000000000..c7c6dcb7d9ff1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportMessage.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.StreamOutput; + +import java.io.IOException; + +public interface BytesTransportMessage { + + ReleasableBytesReference bytes(); + + /** + * Writes the data in a "thin" manner, without the actual bytes, assumes + * the actual bytes will be appended right after this content. + */ + void writeThin(StreamOutput out) throws IOException; +} diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java index a16411cb25f1d..3284597ebe14a 100644 --- a/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java +++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java @@ -10,7 +10,6 @@ package org.elasticsearch.transport; import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -21,7 +20,7 @@ * A specialized, bytes only request, that can potentially be optimized on the network * layer, specifically for the same large buffer send to several nodes. */ -public class BytesTransportRequest extends AbstractTransportRequest { +public class BytesTransportRequest extends AbstractTransportRequest implements BytesTransportMessage { final ReleasableBytesReference bytes; private final TransportVersion version; @@ -41,14 +40,12 @@ public TransportVersion version() { return this.version; } - public BytesReference bytes() { + @Override + public ReleasableBytesReference bytes() { return this.bytes; } - /** - * Writes the data in a "thin" manner, without the actual bytes, assumes - * the actual bytes will be appended right after this content. - */ + @Override public void writeThin(StreamOutput out) throws IOException { super.writeTo(out); out.writeVInt(bytes.length()); diff --git a/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java b/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java new file mode 100644 index 0000000000000..571d0d4008e24 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/BytesTransportResponse.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * A specialized, bytes only response, that can potentially be optimized on the network layer. + */ +public class BytesTransportResponse extends TransportResponse implements BytesTransportMessage { + + private final ReleasableBytesReference bytes; + + public BytesTransportResponse(ReleasableBytesReference bytes) { + this.bytes = bytes; + } + + @Override + public ReleasableBytesReference bytes() { + return this.bytes; + } + + @Override + public void writeThin(StreamOutput out) throws IOException {} + + @Override + public void writeTo(StreamOutput out) throws IOException { + bytes.writeTo(out); + } + + @Override + public void incRef() { + bytes.incRef(); + } + + @Override + public boolean tryIncRef() { + return bytes.tryIncRef(); + } + + @Override + public boolean decRef() { + return bytes.decRef(); + } + + @Override + public boolean hasReferences() { + return bytes.hasReferences(); + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index a83d1019e7c64..c5041c4e75d17 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -227,7 +227,7 @@ private void sendMessage( Releasable onAfter ) throws IOException { assert action != null; - final var compressionScheme = writeable instanceof BytesTransportRequest ? null : possibleCompressionScheme; + final var compressionScheme = writeable instanceof BytesTransportMessage ? null : possibleCompressionScheme; final BytesReference message; boolean serializeSuccess = false; final RecyclerBytesStreamOutput byteStreamOutput = new RecyclerBytesStreamOutput(recycler); @@ -334,11 +334,11 @@ private static BytesReference serializeMessageBody( final ReleasableBytesReference zeroCopyBuffer; try { stream.setTransportVersion(version); - if (writeable instanceof BytesTransportRequest bRequest) { + if (writeable instanceof BytesTransportMessage bRequest) { assert stream == byteStreamOutput; assert compressionScheme == null; bRequest.writeThin(stream); - zeroCopyBuffer = bRequest.bytes; + zeroCopyBuffer = bRequest.bytes(); } else if (writeable instanceof RemoteTransportException remoteTransportException) { stream.writeException(remoteTransportException); zeroCopyBuffer = ReleasableBytesReference.empty();