Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -52,6 +53,7 @@
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportActionProxy;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
Expand Down Expand Up @@ -487,6 +489,7 @@ public static void registerRequestHandler(TransportService transportService, Sea
(request, channel, task) -> searchService.executeFetchPhase(
request,
(SearchShardTask) task,
maybeGetNetworkBuffer(transportService, channel),
new ChannelActionListener<>(channel)
)
);
Expand All @@ -503,7 +506,12 @@ public static void registerRequestHandler(TransportService transportService, Sea
TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new);

final TransportRequestHandler<ShardFetchRequest> shardFetchRequestHandler = (request, channel, task) -> searchService
.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
.executeFetchPhase(
request,
(SearchShardTask) task,
maybeGetNetworkBuffer(transportService, channel),
new ChannelActionListener<>(channel)
);
transportService.registerRequestHandler(
FETCH_ID_SCROLL_ACTION_NAME,
EsExecutors.DIRECT_EXECUTOR_SERVICE,
Expand Down Expand Up @@ -531,6 +539,12 @@ public static void registerRequestHandler(TransportService transportService, Sea
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NODE_NAME, true, CanMatchNodeResponse::new);
}

/**
 * Returns a fresh network buffer for pre-serializing the response, or {@code null} when the
 * response cannot bypass the transport layer's own serialization: direct (in-JVM) responses
 * never hit the wire, and compressed channels must go through the transport's compressing
 * serializer.
 */
private static RecyclerBytesStreamOutput maybeGetNetworkBuffer(TransportService transportService, TransportChannel channel) {
    final boolean directResponse = TransportService.DIRECT_RESPONSE_PROFILE.equals(channel.getProfileName());
    if (directResponse || channel.compressionScheme() != null) {
        return null;
    }
    return transportService.newNetworkBytesStream();
}

private static Executor buildFreeContextExecutor(TransportService transportService) {
final ThrottledTaskRunner throttledTaskRunner = new ThrottledTaskRunner(
"free_context",
Expand Down
42 changes: 33 additions & 9 deletions server/src/main/java/org/elasticsearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,9 @@ public boolean isPooled() {
@Override
public void writeTo(StreamOutput out) throws IOException {
assert hasReferences();
final boolean hasTotalHits = totalHits != null;
out.writeBoolean(hasTotalHits);
if (hasTotalHits) {
Lucene.writeTotalHits(out, totalHits);
}
out.writeFloat(maxScore);
writeHeader(out);
out.writeArray(hits);
out.writeOptional(Lucene::writeSortFieldArray, sortFields);
out.writeOptionalString(collapseField);
out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
writeFooter(out);
}

/**
Expand Down Expand Up @@ -260,6 +253,37 @@ private void deallocate() {
}
}

/**
 * Serializes this instance to {@code out} while releasing it, transferring ownership of the
 * individual hits to the output stream as they are written. After this call the instance must
 * not be used again.
 */
public void writeAndRelease(StreamOutput out) throws IOException {
// this must be the final reference: decRef() returning true means the instance is fully released
boolean released = refCounted.decRef();
assert released;
writeHeader(out);
var hits = this.hits;
out.writeVInt(hits.length);
for (int i = 0; i < hits.length; i++) {
var h = hits[i];
// null out each slot before writing so the array no longer owns the hit
hits[i] = null;
assert h != null;
h.writeTo(out);
// release the hit right after serializing it; the bytes in `out` are now the only copy
h.decRef();
}
writeFooter(out);
}

// Trailing section of the serialized form: optional sort fields, optional collapse field and
// optional collapse values, in that order. NOTE(review): presumably mirrors the read order in the
// deserializing constructor — confirm before reordering.
private void writeFooter(StreamOutput out) throws IOException {
out.writeOptional(Lucene::writeSortFieldArray, sortFields);
out.writeOptionalString(collapseField);
out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
}

/**
 * Leading section of the serialized form: a presence flag plus the total hits when non-null,
 * followed by the max score.
 */
private void writeHeader(StreamOutput out) throws IOException {
    if (totalHits == null) {
        out.writeBoolean(false);
    } else {
        out.writeBoolean(true);
        Lucene.writeTotalHits(out, totalHits);
    }
    out.writeFloat(maxScore);
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
Expand Down
39 changes: 32 additions & 7 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.settings.Setting;
Expand Down Expand Up @@ -95,7 +97,6 @@
import org.elasticsearch.search.dfs.DfsPhase;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.fetch.FetchPhase;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.QueryFetchSearchResult;
import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult;
import org.elasticsearch.search.fetch.ShardFetchRequest;
Expand Down Expand Up @@ -136,7 +137,9 @@
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPool.Names;
import org.elasticsearch.transport.BytesTransportResponse;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.Transports;

import java.io.IOException;
Expand Down Expand Up @@ -1107,7 +1110,8 @@ private Executor getExecutor(IndexShard indexShard) {
public void executeFetchPhase(
InternalScrollSearchRequest request,
SearchShardTask task,
ActionListener<ScrollQueryFetchSearchResult> listener
RecyclerBytesStreamOutput networkBuffer,
ActionListener<TransportResponse> listener
) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed;
Expand Down Expand Up @@ -1139,8 +1143,14 @@ public void executeFetchPhase(
opsListener.onFailedQueryPhase(searchContext);
}
}
QueryFetchSearchResult fetchSearchResult = executeFetchPhase(readerContext, searchContext, afterQueryTime);
return new ScrollQueryFetchSearchResult(fetchSearchResult, searchContext.shardTarget());
var resp = executeFetchPhase(readerContext, searchContext, afterQueryTime);
if (networkBuffer == null) {
return new ScrollQueryFetchSearchResult(resp, searchContext.shardTarget());
}
searchContext.shardTarget().writeTo(networkBuffer);
resp.writeTo(networkBuffer);
resp.decRef();
return new BytesTransportResponse(new ReleasableBytesReference(networkBuffer.bytes(), networkBuffer));
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
logger.trace("Fetch phase failed", e);
Expand All @@ -1150,7 +1160,12 @@ public void executeFetchPhase(
}, wrapFailureListener(listener, readerContext, markAsUsed));
}

public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener<FetchSearchResult> listener) {
public void executeFetchPhase(
ShardFetchRequest request,
CancellableTask task,
RecyclerBytesStreamOutput networkBuffer,
ActionListener<TransportResponse> listener
) {
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
Expand Down Expand Up @@ -1179,8 +1194,18 @@ public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, A
}
var fetchResult = searchContext.fetchResult();
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
fetchResult.incRef();
return fetchResult;
if (networkBuffer == null) {
fetchResult.incRef();
return fetchResult;
}
try (networkBuffer) {
// no need to worry about releasing this instance safely before we write the first byte to it
// => the try-with-resources here is all we need to not leak any buffers
fetchResult.contextId.writeTo(networkBuffer);
fetchResult.consumeHits(networkBuffer);
networkBuffer.writeOptionalWriteable(fetchResult.profileResult());
return new BytesTransportResponse(networkBuffer.moveToBytesReference());
}
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
// we handle the failure in the failure listener below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ public SearchHits hits() {
return hits;
}

/**
 * Writes the hits to {@code out} and gives up ownership of them: the internal reference is
 * cleared first, so this result no longer holds the hits after the call.
 */
public void consumeHits(StreamOutput out) throws IOException {
    final var detachedHits = this.hits;
    this.hits = null;
    detachedHits.writeAndRelease(out);
}

public FetchSearchResult initCounter() {
counter = 0;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public void sendResponse(Exception exception) {
}
}

// Delegate the compression suggestion to the wrapped channel.
@Override
public Compression.Scheme compressionScheme() {
return channel.compressionScheme();
}

@Override
public TransportVersion getVersion() {
return channel.getVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ public void sendResponse(Exception exception) {
}
}

// Expose the scheme this channel was configured with; null means responses are not compressed.
@Override
public Compression.Scheme compressionScheme() {
return compressionScheme;
}

@Override
public TransportVersion getVersion() {
return version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ public interface TransportChannel {

void sendResponse(Exception exception);

/**
* Returns a suggestion about the desired compression scheme to use for sending the response when using {@link BytesTransportResponse}
* to bypass transport layer serialization and compression.
*
* @return the suggested compression scheme to use for responses or {@code null} when not using compression
*/
Compression.Scheme compressionScheme();

/**
* Returns the version of the data to communicate in this channel.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,11 @@ public String toString() {
}
}

// NOTE(review): appears to be a local/direct channel — responses never hit the wire, so there is
// no compression suggestion. Confirm against the enclosing channel implementation.
@Override
public Compression.Scheme compressionScheme() {
return null;
}

protected RemoteTransportException wrapInRemote(Exception e) {
return e instanceof RemoteTransportException remoteTransportException
? remoteTransportException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.junit.Before;
Expand Down Expand Up @@ -412,8 +413,8 @@ public void testSearchWhileIndexDeleted() throws InterruptedException {
intCursors,
null/* not a scroll */
);
PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener);
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), null, listener);
listener.get();
if (useScroll) {
// have to free context since this test does not remove the index from IndicesService.
Expand Down Expand Up @@ -601,9 +602,10 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
// execute fetch phase and perform any validations once we retrieve the response
// the difference in how we do assertions here is needed because once the transport service sends back the response
// it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null
PlainActionFuture<FetchSearchResult> fetchListener = new PlainActionFuture<>() {
PlainActionFuture<TransportResponse> fetchListener = new PlainActionFuture<>() {
@Override
public void onResponse(FetchSearchResult fetchSearchResult) {
public void onResponse(TransportResponse response) {
FetchSearchResult fetchSearchResult = (FetchSearchResult) response;
assertNotNull(fetchSearchResult);
assertNotNull(fetchSearchResult.hits());

Expand All @@ -624,7 +626,7 @@ public void onFailure(Exception e) {
throw new AssertionError("No failure should have been raised", e);
}
};
service.executeFetchPhase(fetchRequest, searchTask, fetchListener);
service.executeFetchPhase(fetchRequest, searchTask, null, fetchListener);
fetchListener.get();
} catch (Exception ex) {
if (queryResult != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ public void sendResponse(Exception exception) {
in.sendResponse(exception);
}

// Pass-through: defer the compression suggestion to the delegate channel.
@Override
public Compression.Scheme compressionScheme() {
return in.compressionScheme();
}

@Override
public TransportVersion getVersion() {
return in.getVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public void sendResponse(Exception exception) {
channel.sendResponse(exception);

}

// Pass-through: defer the compression suggestion to the underlying channel.
@Override
public Compression.Scheme compressionScheme() {
return channel.compressionScheme();
}
}, task);
} else {
return actualHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ public String toString() {
}
});
}

// No compression suggestion for this channel — NOTE(review): presumably a test/local channel
// whose responses bypass the wire; confirm against the enclosing class.
@Override
public Compression.Scheme compressionScheme() {
return null;
}
};

final TransportRequest copiedRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
listener.onFailure(exception);
}

// This listener-backed channel never serializes to the wire, so no compression is suggested.
@Override
public Compression.Scheme compressionScheme() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.Compression;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
Expand Down Expand Up @@ -721,6 +722,11 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
in.sendResponse(exception);
}

// Pass-through: defer the compression suggestion to the wrapped channel.
@Override
public Compression.Scheme compressionScheme() {
return in.compressionScheme();
}
}

private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.transport.Compression;
import org.elasticsearch.transport.NoSuchRemoteClusterException;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportResponse;
Expand Down Expand Up @@ -467,6 +468,11 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
channel.sendResponse(exception);
}

// Pass-through: defer the compression suggestion to the underlying channel.
@Override
public Compression.Scheme compressionScheme() {
return channel.compressionScheme();
}
}, task)
);
}
Expand Down
Loading