diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCancellationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCancellationIT.java new file mode 100644 index 0000000000000..89ee03e29aa87 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCancellationIT.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.TransportSearchAction; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.test.AbstractSearchCancellationTestCase; +import org.elasticsearch.test.ESIntegTestCase; + +import java.util.Collections; + +import static org.elasticsearch.test.AbstractSearchCancellationTestCase.ScriptedBlockPlugin.SEARCH_BLOCK_SCRIPT_NAME; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0) +public class ChunkedFetchPhaseCancellationIT extends AbstractSearchCancellationTestCase { + + @Override + protected boolean enableConcurrentSearch() { 
+        return false;
+    }
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder()
+            .put(super.nodeSettings(nodeOrdinal, otherSettings))
+            .put("indices.breaker.request.type", "memory")
+            .put("indices.breaker.request.limit", "100mb")
+            .put(SearchService.FETCH_PHASE_CHUNKED_ENABLED.getKey(), true)
+            .build();
+    }
+
+    public void testTaskCancellationReleasesCoordinatorBreakerBytes() throws Exception {
+        internalCluster().startNode();
+        String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY);
+
+        createIndex("test", 2, 0);
+        indexTestData();
+        ensureGreen("test");
+
+        var plugins = initBlockFactory();
+        long breakerBefore = getRequestBreakerUsed(coordinatorNode);
+
+        ActionFuture<SearchResponse> searchResponse = internalCluster().client(coordinatorNode)
+            .prepareSearch("test")
+            .addScriptField("test_field", new Script(ScriptType.INLINE, "mockscript", SEARCH_BLOCK_SCRIPT_NAME, Collections.emptyMap()))
+            .setAllowPartialSearchResults(true)
+            .execute();
+
+        awaitForBlock(plugins);
+        cancelSearch(TransportSearchAction.TYPE.name());
+        disableBlocks(plugins);
+        ensureSearchWasCancelled(searchResponse);
+
+        assertBusy(
+            () -> assertThat(
+                "Coordinator breaker bytes should be released after cancellation",
+                getRequestBreakerUsed(coordinatorNode),
+                lessThanOrEqualTo(breakerBefore)
+            )
+        );
+    }
+
+    private long getRequestBreakerUsed(String node) {
+        CircuitBreakerService breakerService = internalCluster().getInstance(CircuitBreakerService.class, node);
+        CircuitBreaker breaker = breakerService.getBreaker(CircuitBreaker.REQUEST);
+        return breaker.getUsed();
+    }
+}
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerIT.java
new file mode 100644
index 0000000000000..739df3b4e621c
--- /dev/null
+++ 
b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerIT.java @@ -0,0 +1,730 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch; + +import org.apache.logging.log4j.util.Strings; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.ClosePointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchType; +import org.elasticsearch.action.search.TransportClosePointInTimeAction; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.test.ESIntegTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import 
java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; +import static org.elasticsearch.index.query.QueryBuilders.termQuery; +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.notNullValue; + +/** + * Integration tests for chunked fetch phase circuit breaker tracking. The tests verify that the coordinator node properly + * tracks and releases circuit breaker memory when using chunked fetch across multiple shards and nodes. 
+ */ +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0) +public class ChunkedFetchPhaseCircuitBreakerIT extends ESIntegTestCase { + + private static final String INDEX_NAME = "chunked_multi_shard_idx"; + private static final String SORT_FIELD = "sort_field"; + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put("indices.breaker.request.type", "memory") + .put("indices.breaker.request.limit", "200mb") + .put(SearchService.FETCH_PHASE_CHUNKED_ENABLED.getKey(), true) + .build(); + } + + public void testChunkedFetchMultipleShardsSingleNode() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 3).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 150, 1_500); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(100) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(100)); + verifyHitsOrder(response); + } + ); + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released after chunked fetch completes", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchMultipleShardsMultipleNodes() throws Exception { + internalCluster().startNode(); + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + int numberOfShards = randomIntBetween(6, 16); + createIndexForTest( + 
INDEX_NAME, + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .build() + ); + + int numberOfDocuments = randomIntBetween(250, 600); + populateIndex(INDEX_NAME, numberOfDocuments, 1_500); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(200) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(200)); + verifyHitsOrder(response); + } + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after many-shard chunked fetch, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchConcurrentSearches() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 110, 500); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + int numSearches = 5; + ExecutorService executor = Executors.newFixedThreadPool(numSearches); + try { + List> futures = IntStream.range(0, numSearches).mapToObj(i -> CompletableFuture.runAsync(() -> { + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(30) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> assertThat(response.getHits().getHits().length, equalTo(30)) + ); + }, executor)).toList(); + + 
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).get(30, TimeUnit.SECONDS); + assertThat("All concurrent searches should succeed", futures.size(), equalTo(numSearches)); + } finally { + executor.shutdown(); + assertTrue("Executor should terminate", executor.awaitTermination(10, TimeUnit.SECONDS)); + } + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after concurrent searches, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithReplicas() throws Exception { + internalCluster().startNode(); + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 3).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1).build() + ); + + populateIndex(INDEX_NAME, 150, 1_000); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + // Search will naturally hit both primaries and replicas due to load balancing + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(100) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(100)); + verifyHitsOrder(response); + } + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after chunked fetch with replicas", + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithFiltering() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + 
createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 300, 800); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(termQuery("keyword", "value1")) + .setSize(50) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, greaterThan(0)); + // Verify all results match filter + for (int i = 0; i < response.getHits().getHits().length; i++) { + assertThat(Objects.requireNonNull(response.getHits().getHits()[i].getSourceAsMap()).get("keyword"), equalTo("value1")); + } + verifyHitsOrder(response); + } + ); + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released after chunked fetch completes", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchNoMemoryLeakSequential() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 200, 800); + ensureGreen(INDEX_NAME); + + long initialBreaker = getRequestBreakerUsed(coordinatorNode); + + for (int i = 0; i < 50; i++) { + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(40) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(40)); + } + ); + } + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit 
breaker should not leak memory across sequential chunked fetches, current: " + + currentBreaker + + ", initial: " + + initialBreaker, + currentBreaker, + lessThanOrEqualTo(initialBreaker) + ); + }); + } + + public void testChunkedFetchWithAggregations() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 3).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + populateIndex(INDEX_NAME, 250, 800); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(100) + .addAggregation(terms("keywords").field("keyword").size(10)) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(100)); + verifyHitsOrder(response); + + // Verify aggregation results + Terms keywordAgg = response.getAggregations().get("keywords"); + assertThat(keywordAgg, notNullValue()); + assertThat(keywordAgg.getBuckets().size(), equalTo(10)); + } + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after chunked fetch with aggregations", + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithSearchAfter() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 150, 800); + ensureGreen(INDEX_NAME); + + long breakerBefore = 
getRequestBreakerUsed(coordinatorNode); + + // First page + SearchResponse response1 = internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(30) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + + try { + assertThat(response1.getHits().getHits().length, equalTo(30)); + Object[] lastSort = response1.getHits().getHits()[29].getSortValues(); + + // Second page with search_after using same coordinator + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(30) + .addSort(SORT_FIELD, SortOrder.ASC) + .searchAfter(lastSort), + response2 -> { + assertThat(response2.getHits().getHits().length, equalTo(30)); + + // Verify second page starts after first page + long firstValuePage2 = (Long) response2.getHits().getHits()[0].getSortValues()[0]; + long lastValuePage1 = (Long) lastSort[0]; + assertThat(firstValuePage2, greaterThan(lastValuePage1)); + } + ); + } finally { + response1.decRef(); + } + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after paginated chunked fetches, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithDfsQueryThenFetch() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 100, 1_500); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + 
.setSearchType(SearchType.DFS_QUERY_THEN_FETCH) + .setQuery(matchAllQuery()) + .setSize(50) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(50)); + verifyHitsOrder(response); + } + ); + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released after DFS chunked fetch", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithPointInTimeReleasesBreaker() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 3).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 180, 800); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + var pitResponse = internalCluster().client(coordinatorNode) + .execute(TransportOpenPointInTimeAction.TYPE, new OpenPointInTimeRequest(INDEX_NAME).keepAlive(TimeValue.timeValueMinutes(1))) + .actionGet(); + + try { + assertNoFailuresAndResponse( + internalCluster().client(coordinatorNode) + .prepareSearch() + .setPointInTime(new PointInTimeBuilder(pitResponse.getPointInTimeId())) + .setSize(60) + .addSort(SORT_FIELD, SortOrder.ASC), + response -> { + assertThat(response.getHits().getHits().length, equalTo(60)); + verifyHitsOrder(response); + } + ); + } finally { + internalCluster().client(coordinatorNode) + .execute(TransportClosePointInTimeAction.TYPE, new ClosePointInTimeRequest(pitResponse.getPointInTimeId())) + .actionGet(); + } + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released after chunked PIT search", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchNodeFailureDuringStreamingReleasesBreaker() throws 
Exception { + String dataNodeToFail = internalCluster().startNode(); + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + String failureIndex = "chunked_node_failure_idx"; + createIndexForTest( + failureIndex, + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put("index.routing.allocation.include._name", dataNodeToFail) + .build() + ); + + populateIndex(failureIndex, 250, 1_200); + ensureGreen(failureIndex); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + ActionFuture searchFuture = internalCluster().client(coordinatorNode) + .prepareSearch(failureIndex) + .setAllowPartialSearchResults(true) + .setQuery(matchAllQuery()) + .setSize(180) + .addSort(SORT_FIELD, SortOrder.ASC) + .execute(); + + internalCluster().stopNode(dataNodeToFail); + + SearchResponse response = null; + Exception failure = null; + try { + response = searchFuture.actionGet(30, TimeUnit.SECONDS); + } catch (Exception e) { + failure = e; + } + + if (response != null) { + try { + if (response.getFailedShards() > 0) { + assertThat( + "Expected failed shards when shard-hosting node is stopped during chunked fetch", + response.getFailedShards(), + greaterThan(0) + ); + } else { + assertThat( + "Expected a full successful response when node stop races after search completion", + response.getHits().getHits().length, + equalTo(180) + ); + } + } finally { + response.decRef(); + } + } else { + assertNotNull("Search should either fail or report shard failures after node stop", failure); + } + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released after node failure during chunked fetch", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchCircuitBreakerReleasedOnFailure() throws Exception { + internalCluster().startNode(); + String 
coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + createIndexForTest( + INDEX_NAME, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 4).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + populateIndex(INDEX_NAME, 100, 1_500); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + // Execute search that will fail + expectThrows( + Exception.class, + () -> internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(50) + .addSort("non_existent_field", SortOrder.ASC) + .get() + ); + + assertBusy(() -> { + assertThat( + "Coordinator circuit breaker should be released even after chunked fetch failure", + getRequestBreakerUsed(coordinatorNode), + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testChunkedFetchWithPartialShardFailures() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + + String successIndex = "chunked_success_idx"; + String failingIndex = "chunked_failing_idx"; + + createIndexForTest( + successIndex, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + assertAcked( + prepareCreate(failingIndex).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ).setMapping("text", "type=text") + ); + + populateIndex(successIndex, 100, 600); + populateSimpleIndex(failingIndex, 25); + ensureGreen(successIndex, failingIndex); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + SearchResponse response = internalCluster().client(coordinatorNode) + .prepareSearch(successIndex, failingIndex) + .setAllowPartialSearchResults(true) + .setQuery(matchAllQuery()) + .setSize(30) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + + try { + 
assertThat("Expected at least one successful shard", response.getSuccessfulShards(), greaterThan(0)); + assertThat("Expected at least one failed shard", response.getFailedShards(), greaterThan(0)); + assertThat("Expected hits from successful shards", response.getHits().getHits().length, greaterThan(0)); + } finally { + response.decRef(); + } + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after partial shard failures, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + private void populateIndex(String indexName, int nDocs, int textSize) throws IOException { + int batchSize = 50; + // Reuse large payload strings across documents to avoid excessive temporary allocations during indexing. + String largeText1 = Strings.repeat("large content field 1 ", textSize); + String largeText2 = Strings.repeat("large content field 2 ", textSize); + String largeText3 = Strings.repeat("large content field 3 ", textSize); + for (int batch = 0; batch < nDocs; batch += batchSize) { + int endDoc = Math.min(batch + batchSize, nDocs); + List builders = new ArrayList<>(); + + for (int i = batch; i < endDoc; i++) { + builders.add( + prepareIndex(indexName).setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field(SORT_FIELD, i) + .field("text", "document " + i) + .field("large_text_1", largeText1) + .field("large_text_2", largeText2) + .field("large_text_3", largeText3) + .field("keyword", "value" + (i % 10)) + .endObject() + ) + ); + } + indexRandom(batch == 0, builders); + } + refresh(indexName); + } + + private void createIndexForTest(String indexName, Settings indexSettings) { + assertAcked( + prepareCreate(indexName).setSettings(indexSettings) + .setMapping( + SORT_FIELD, + "type=long", + "text", + "type=text,store=true", + "large_text_1", + "type=text,store=false", + 
"large_text_2", + "type=text,store=false", + "large_text_3", + "type=text,store=false", + "keyword", + "type=keyword" + ) + ); + } + + private long getRequestBreakerUsed(String node) { + CircuitBreakerService breakerService = internalCluster().getInstance(CircuitBreakerService.class, node); + CircuitBreaker breaker = breakerService.getBreaker(CircuitBreaker.REQUEST); + return breaker.getUsed(); + } + + private void populateSimpleIndex(String indexName, int nDocs) throws IOException { + List builders = new ArrayList<>(); + for (int i = 0; i < nDocs; i++) { + builders.add( + prepareIndex(indexName).setId(Integer.toString(i)) + .setSource(jsonBuilder().startObject().field("text", "doc " + i).endObject()) + ); + } + indexRandom(true, builders); + refresh(indexName); + } + + private void verifyHitsOrder(SearchResponse response) { + for (int i = 0; i < response.getHits().getHits().length - 1; i++) { + long current = (Long) response.getHits().getHits()[i].getSortValues()[0]; + long next = (Long) response.getHits().getHits()[i + 1].getSortValues()[0]; + assertThat("Hits should be in ascending order", current, lessThan(next)); + } + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerTrippingIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerTrippingIT.java new file mode 100644 index 0000000000000..3b7d4f5434949 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fetch/ChunkedFetchPhaseCircuitBreakerTrippingIT.java @@ -0,0 +1,405 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch; + +import org.apache.logging.log4j.util.Strings; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.test.ESIntegTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +/** + * Integration tests for 
circuit breaker behavior when memory limits are exceeded
+ * during chunked fetch operations.
+ *
+ * Tests verify that the circuit breaker properly trips when the coordinator
+ * accumulates too much data, and that memory is correctly released even after
+ * breaker failures. Uses a low 5MB limit to reliably trigger breaker trips with
+ * large documents.
+ */
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0)
+public class ChunkedFetchPhaseCircuitBreakerTrippingIT extends ESIntegTestCase {
+
+    private static final String INDEX_NAME = "idx";
+    private static final String SORT_FIELD = "sort_field";
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
+        return Settings.builder()
+            .put(super.nodeSettings(nodeOrdinal, otherSettings))
+            .put("indices.breaker.request.type", "memory")
+            .put("indices.breaker.request.limit", "5mb") // Low limit to trigger breaker - 5MB
+            .put(SearchService.FETCH_PHASE_CHUNKED_ENABLED.getKey(), true)
+            .build();
+    }
+
+    public void testCircuitBreakerTripsOnCoordinator() throws Exception {
+        internalCluster().startNode();
+        String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY);
+
+        createIndex(INDEX_NAME);
+
+        List<IndexRequestBuilder> builders = new ArrayList<>();
+        for (int i = 0; i < 5; i++) {
+            builders.add(
+                prepareIndex(INDEX_NAME).setId(Integer.toString(i))
+                    .setSource(
+                        jsonBuilder().startObject()
+                            .field(SORT_FIELD, i)
+                            .field("text", "document " + i)
+                            .field("huge_content", Strings.repeat("x", 2_000_000)) // 2MB each
+                            .endObject()
+                    )
+            );
+        }
+        indexRandom(true, builders);
+        refresh(INDEX_NAME);
+        ensureGreen(INDEX_NAME);
+
+        long breakerBefore = getRequestBreakerUsed(coordinatorNode);
+
+        ElasticsearchException exception = null;
+        SearchResponse resp = null;
+        try {
+            resp = internalCluster().client(coordinatorNode)
+                .prepareSearch(INDEX_NAME)
+                .setQuery(matchAllQuery())
+                .setSize(5) // Request 5 huge docs = ~10MB > 
5MB limit + .setAllowPartialSearchResults(false) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + } catch (ElasticsearchException e) { + exception = e; + } finally { + if (resp != null) { + resp.decRef(); + } + } + + Throwable cause = exception.getCause(); + while (cause != null && (cause instanceof CircuitBreakingException) == false) { + cause = cause.getCause(); + } + assertThat("Should have CircuitBreakingException in cause chain", cause, instanceOf(CircuitBreakingException.class)); + + CircuitBreakingException breakerException = (CircuitBreakingException) cause; + assertThat(breakerException.getMessage(), containsString("[request] Data too large")); + + assertThat( + "Circuit breaking should map to 429 TOO_MANY_REQUESTS", + ExceptionsHelper.status(exception), + equalTo(RestStatus.TOO_MANY_REQUESTS) + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released even after tripping, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testCircuitBreakerTripsWithConcurrentSearches() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + createIndex(INDEX_NAME); + + List builders = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + builders.add( + prepareIndex(INDEX_NAME).setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field(SORT_FIELD, i) + .field("text", "document " + i) + .field("large_content", Strings.repeat("x", 1_500_000)) // 1.5MB each + .endObject() + ) + ); + } + indexRandom(true, builders); + refresh(INDEX_NAME); + ensureGreen(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + + int numSearches = 5; + ExecutorService executor = Executors.newFixedThreadPool(numSearches); + try { + List> futures = IntStream.range(0, 
numSearches).mapToObj(i -> CompletableFuture.runAsync(() -> { + var client = internalCluster().client(coordinatorNode); + var resp = client.prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(4) + .setAllowPartialSearchResults(false) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + resp.decRef(); + }, executor)).toList(); + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).exceptionally(ex -> null).get(30, TimeUnit.SECONDS); + + List exceptions = new ArrayList<>(); + for (CompletableFuture future : futures) { + try { + future.get(); + } catch (ExecutionException e) { + exceptions.add((Exception) e.getCause()); + } + } + assertThat("Expected at least one circuit breaker exception", exceptions.size(), greaterThan(0)); + + boolean foundBreakerException = false; + for (Exception e : exceptions) { + if (containsCircuitBreakerException(e)) { + foundBreakerException = true; + break; + } + + assertThat( + "Circuit breaking should map to 429 TOO_MANY_REQUESTS", + ExceptionsHelper.status(e), + equalTo(RestStatus.TOO_MANY_REQUESTS) + ); + } + assertThat("Should have found a CircuitBreakingException", foundBreakerException, equalTo(true)); + } finally { + executor.shutdown(); + assertTrue("Executor should terminate", executor.awaitTermination(10, TimeUnit.SECONDS)); + } + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should recover after concurrent breaker trips, current: " + + currentBreaker + + ", before: " + + breakerBefore, + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + public void testCircuitBreakerTripsOnSingleLargeDocument() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + createIndex(INDEX_NAME); + + prepareIndex(INDEX_NAME).setId("huge") + .setSource( + jsonBuilder().startObject() + .field(SORT_FIELD, 0) + .field("text", "huge 
document") + .field("huge_field", Strings.repeat("x", 6_000_000)) // 6MB + .endObject() + ) + .get(); + populateLargeDocuments(INDEX_NAME, 10, 1_000); + refresh(INDEX_NAME); + + long breakerBefore = getRequestBreakerUsed(coordinatorNode); + ElasticsearchException exception = null; + SearchResponse resp = null; + try { + resp = internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(5) + .setAllowPartialSearchResults(false) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + } catch (ElasticsearchException e) { + exception = e; + } finally { + if (resp != null) { + resp.decRef(); + } + } + + boolean foundBreakerException = containsCircuitBreakerException(exception); + assertThat("Circuit breaker should have tripped on single large document", foundBreakerException, equalTo(true)); + + assertThat( + "Circuit breaking should map to 429 TOO_MANY_REQUESTS", + ExceptionsHelper.status(exception), + equalTo(RestStatus.TOO_MANY_REQUESTS) + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Coordinator circuit breaker should be released after single large doc trip", + currentBreaker, + lessThanOrEqualTo(breakerBefore) + ); + }); + } + + /** + * Test that multiple sequential breaker trips don't cause memory leaks. + * Repeatedly tripping the breaker should not accumulate memory. 
+ */ + public void testRepeatedCircuitBreakerTripsNoLeak() throws Exception { + internalCluster().startNode(); + String coordinatorNode = internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + createIndex(INDEX_NAME); + + List builders = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + builders.add( + prepareIndex(INDEX_NAME).setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field(SORT_FIELD, i) + .field("text", "document " + i) + .field("large_content", Strings.repeat("x", 1_500_000)) // 1.5MB each + .endObject() + ) + ); + } + indexRandom(true, builders); + refresh(INDEX_NAME); + ensureGreen(INDEX_NAME); + + long initialBreaker = getRequestBreakerUsed(coordinatorNode); + + ElasticsearchException exception = null; + for (int i = 0; i < 10; i++) { + SearchResponse resp = null; + try { + resp = internalCluster().client(coordinatorNode) + .prepareSearch(INDEX_NAME) + .setQuery(matchAllQuery()) + .setSize(5) // 5 docs × 1.2MB = 6MB > 5MB limit + .setAllowPartialSearchResults(false) + .addSort(SORT_FIELD, SortOrder.ASC) + .get(); + } catch (ElasticsearchException e) { + exception = e; + } finally { + if (resp != null) { + resp.decRef(); + } + } + } + + boolean foundBreakerException = containsCircuitBreakerException(exception); + assertThat("Circuit breaker should have tripped on single large document", foundBreakerException, equalTo(true)); + + assertThat( + "Circuit breaking should map to 429 TOO_MANY_REQUESTS", + ExceptionsHelper.status(exception), + equalTo(RestStatus.TOO_MANY_REQUESTS) + ); + + assertBusy(() -> { + long currentBreaker = getRequestBreakerUsed(coordinatorNode); + assertThat( + "Circuit breaker should not leak after repeated trips, current: " + currentBreaker + ", initial: " + initialBreaker, + currentBreaker, + lessThanOrEqualTo(initialBreaker) + ); + }); + } + + private void populateLargeDocuments(String indexName, int nDocs, int contentSize) throws IOException { + int batchSize = 10; + for (int batch = 0; batch 
< nDocs; batch += batchSize) { + int endDoc = Math.min(batch + batchSize, nDocs); + List builders = new ArrayList<>(); + + for (int i = batch; i < endDoc; i++) { + builders.add( + prepareIndex(indexName).setId(Integer.toString(i)) + .setSource( + jsonBuilder().startObject() + .field(SORT_FIELD, i) + .field("text", "document " + i) + .field("large_content", Strings.repeat("x", contentSize)) + .endObject() + ) + ); + } + indexRandom(batch == 0, builders); + } + refresh(indexName); + } + + private void createIndex(String indexName) { + assertAcked( + prepareCreate(indexName).setSettings( + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .setMapping( + SORT_FIELD, + "type=long", + "text", + "type=text,store=false", + "large_content", + "type=text,store=false", + "huge_field", + "type=text,store=false" + ) + ); + } + + private long getRequestBreakerUsed(String nodeName) { + CircuitBreakerService breakerService = internalCluster().getInstance(CircuitBreakerService.class, nodeName); + CircuitBreaker breaker = breakerService.getBreaker(CircuitBreaker.REQUEST); + return breaker.getUsed(); + } + + private boolean containsCircuitBreakerException(Throwable t) { + if (t == null) { + return false; + } + if (t instanceof CircuitBreakingException) { + return true; + } + if (t.getMessage() != null && t.getMessage().contains("CircuitBreakingException")) { + return true; + } + return containsCircuitBreakerException(t.getCause()); + } +} diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index 3810b1c3e2dab..b87b28b91f8a9 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -409,6 +409,9 @@ import org.elasticsearch.rest.action.synonyms.RestPutSynonymRuleAction; import 
org.elasticsearch.rest.action.synonyms.RestPutSynonymsAction; import org.elasticsearch.search.crossproject.CrossProjectModeDecider; +import org.elasticsearch.search.fetch.chunk.ActiveFetchPhaseTasks; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseResponseChunkAction; import org.elasticsearch.snapshots.TransportUpdateSnapshotStatusAction; import org.elasticsearch.tasks.Task; import org.elasticsearch.telemetry.TelemetryProvider; @@ -745,6 +748,8 @@ public void reg actions.register(TransportMultiSearchAction.TYPE, TransportMultiSearchAction.class); actions.register(TransportExplainAction.TYPE, TransportExplainAction.class); actions.register(TransportClearScrollAction.TYPE, TransportClearScrollAction.class); + actions.register(TransportFetchPhaseCoordinationAction.TYPE, TransportFetchPhaseCoordinationAction.class); + actions.register(RecoveryAction.INSTANCE, TransportRecoveryAction.class); actions.register(TransportNodesReloadSecureSettingsAction.TYPE, TransportNodesReloadSecureSettingsAction.class); actions.register(AutoCreateAction.INSTANCE, AutoCreateAction.TransportAction.class); @@ -1056,6 +1061,8 @@ protected void configure() { bind(new TypeLiteral>() {}).toInstance(mappingRequestValidators); bind(new TypeLiteral>() {}).toInstance(indicesAliasesRequestRequestValidators); bind(AutoCreateIndex.class).toInstance(autoCreateIndex); + bind(ActiveFetchPhaseTasks.class).asEagerSingleton(); + bind(TransportFetchPhaseResponseChunkAction.class).asEagerSingleton(); // register ActionType -> transportAction Map used by NodeClient @SuppressWarnings("rawtypes") diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 7d2f804812e57..0a7569f6fb71e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ 
b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -237,6 +237,7 @@ public void onFailure(Exception e) { } } }; + final Transport.Connection connection; try { connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); @@ -257,7 +258,8 @@ public void onFailure(Exception e) { shardPhaseResult.getRescoreDocIds(), aggregatedDfs ), - context.getTask(), + context, + shardTarget, listener ); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 9774ba54d6b90..ac671fee3534d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -25,23 +25,30 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.dfs.DfsSearchResult; import 
org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.QueryFetchSearchResult; import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; import org.elasticsearch.search.fetch.ShardFetchSearchRequest; +import org.elasticsearch.search.fetch.chunk.FetchPhaseResponseChunk; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseResponseChunkAction; import org.elasticsearch.search.internal.InternalScrollSearchRequest; import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; @@ -53,6 +60,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.AbstractTransportRequest; +import org.elasticsearch.transport.BytesTransportRequest; import org.elasticsearch.transport.BytesTransportResponse; import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.transport.TaskTransportChannel; @@ -72,6 +80,9 @@ import java.util.Objects; import java.util.concurrent.Executor; import java.util.function.BiFunction; +import java.util.function.Supplier; + +import static org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE; /** * An encapsulation of {@link SearchService} operations exposed through @@ -115,6 +126,7 @@ public class SearchTransportService { Transport.Connection, ActionListener, ActionListener> responseWrapper; + private SearchService searchService; private final Map clientConnections = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); public SearchTransportService( @@ -130,6 +142,10 @@ public SearchTransportService( this.responseWrapper = responseWrapper; } + public void setSearchService(SearchService searchService) { + this.searchService = searchService; + } + public 
TransportService transportService() { return transportService; } @@ -194,6 +210,12 @@ public void sendExecuteQuery( SearchTask task, final ActionListener listener ) { + + // Set coordinator node so data node can detect chunked fetch scenarios + if (request.getCoordinatingNode() == null) { + request.setCoordinatingNode(transportService.getLocalNode()); + } + // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request // this used to be the QUERY_AND_FETCH which doesn't exist anymore. final boolean fetchDocuments = request.numberOfShards() == 1 @@ -270,13 +292,88 @@ public void sendExecuteScrollFetch( ); } + /** + * Sends a fetch request to retrieve documents from a data node. + * + *

This method decides between two fetch strategies:
+ *
+ * <ul>
+ *   <li>Chunked fetch: Results are streamed back in chunks.</li>
+ *   <li>Traditional fetch: All results returned in a single response.</li>
+ * </ul>
+ *

For chunked fetch, the request is routed through {@link TransportFetchPhaseCoordinationAction} + * on the local (coordinator) node, which registers a response stream before forwarding to the data node. + * The data node then streams chunks back via {@link TransportFetchPhaseResponseChunkAction}. + * + * @param connection the transport connection to the data node + * @param shardFetchRequest the fetch request containing doc IDs to retrieve + * @param context the search context for this async action + * @param shardTarget identifies the shard being fetched from + * @param listener callback for the fetch result + */ public void sendExecuteFetch( Transport.Connection connection, - final ShardFetchSearchRequest request, - SearchTask task, - final ActionListener listener + ShardFetchSearchRequest shardFetchRequest, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, + ActionListener listener ) { - sendExecuteFetch(connection, FETCH_ID_ACTION_NAME, request, task, listener); + SearchTask task = context.getTask(); + + final TransportVersion dataNodeVersion = connection.getTransportVersion(); + boolean dataNodeSupports = dataNodeVersion.supports(CHUNKED_FETCH_PHASE); + boolean isCCSQuery = shardTarget.getClusterAlias() != null; + boolean isScrollOrReindex = context.getRequest().scroll() != null + || (shardFetchRequest.getShardSearchRequest() != null && shardFetchRequest.getShardSearchRequest().scroll() != null); + + if (logger.isDebugEnabled()) { + logger.debug( + "FetchSearchPhase decision for shard {}: chunkEnabled={}, " + + "dataNodeSupports={}, dataNodeVersionId={}, CHUNKED_FETCH_PHASE_id={}, " + + "targetNode={}, isCCSQuery={}, isScrollOrReindex={}", + shardTarget.getShardId(), + searchService.fetchPhaseChunked(), + dataNodeSupports, + dataNodeVersion.id(), + CHUNKED_FETCH_PHASE.id(), + connection.getNode(), + isCCSQuery, + isScrollOrReindex + ); + } + + // Determine if chunked fetch can be used for this request, checking + // 1. 
Feature flag enabled + // 2. Data node supports CHUNKED_FETCH_PHASE transport version + // 3. Not a cross-cluster search (CCS) + // 4. Not a scroll or reindex operation + if (searchService.fetchPhaseChunked() && dataNodeSupports && isCCSQuery == false && isScrollOrReindex == false) { + // Route through local TransportFetchPhaseCoordinationAction + shardFetchRequest.setCoordinatingNode(context.getSearchTransport().transportService().getLocalNode()); + shardFetchRequest.setCoordinatingTaskId(task.getId()); + + // Capture ThreadContext headers (security credentials etc.) to propagate + // through the local coordination action. ThreadContext is thread-local and would be + // lost when the coordination action executes on a different thread/executor. + // This is required for authentication/authorization where applied. + ThreadContext threadContext = transportService.getThreadPool().getThreadContext(); + Map headers = new HashMap<>(threadContext.getHeaders()); + + transportService.sendChildRequest( + transportService.getConnection(transportService.getLocalNode()), + TransportFetchPhaseCoordinationAction.TYPE.name(), + new TransportFetchPhaseCoordinationAction.Request(shardFetchRequest, connection.getNode(), headers), + task, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>( + listener.map(TransportFetchPhaseCoordinationAction.Response::getResult), + TransportFetchPhaseCoordinationAction.Response::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ) + ); + } else { + sendExecuteFetch(connection, FETCH_ID_ACTION_NAME, shardFetchRequest, task, listener); + } } public void sendExecuteFetchScroll( @@ -546,12 +643,113 @@ public static void registerRequestHandler( namedWriteableRegistry ); - final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> searchService - .executeFetchPhase( - request, - (SearchShardTask) task, - channelListener(transportService, channel, searchService.getCircuitBreaker()) - ); + /** + * Handler for fetch 
requests on the data node side.
+ *
+ * <p>
When chunked fetch is used, creates a {@link FetchPhaseResponseChunk.Writer} that + * sends chunks back to the coordinator via {@link TransportFetchPhaseResponseChunkAction}. + * The writer preserves the ThreadContext to maintain security headers across async chunk sends. + */ + final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> { + boolean fetchPhaseChunkedEnabled = searchService.fetchPhaseChunked(); + boolean hasCoordinator = request instanceof ShardFetchSearchRequest fetchSearchReq + && fetchSearchReq.getCoordinatingNode() != null; + + TransportVersion channelVersion = channel.getVersion(); + boolean versionSupported = channelVersion.supports(CHUNKED_FETCH_PHASE); + + // Check if we can connect to the coordinator (CCS detection) + boolean canConnectToCoordinator = false; + boolean coordinatorSupportsChunkedFetch = false; + if (hasCoordinator) { + ShardFetchSearchRequest fetchSearchReq = (ShardFetchSearchRequest) request; + DiscoveryNode coordinatorNode = fetchSearchReq.getCoordinatingNode(); + canConnectToCoordinator = transportService.nodeConnected(coordinatorNode); + + if (canConnectToCoordinator) { + try { + Transport.Connection coordConnection = transportService.getConnection(coordinatorNode); + coordinatorSupportsChunkedFetch = coordConnection.getTransportVersion().supports(CHUNKED_FETCH_PHASE); + } catch (Exception e) { + coordinatorSupportsChunkedFetch = false; + } + } + } + + if (logger.isDebugEnabled()) { + logger.debug( + "CHUNKED_FETCH decision: enabled={}, versionSupported={}, hasCoordinator={}, " + + "canConnectToCoordinator={}, channelVersion={}", + fetchPhaseChunkedEnabled, + versionSupported, + hasCoordinator, + canConnectToCoordinator, + channelVersion + ); + } + + FetchPhaseResponseChunk.Writer chunkWriter = null; + + // Decides whether to use chunked or traditional fetch based on: + // 1. Feature flag enabled on this node (fetchPhaseChunkedEnabled) + // 2. 
Channel transport version supports chunked fetch (versionSupported) + // 3. Request includes coordinator node info (hasCoordinator) - set by coordinator when using chunked path + // 4. Can establish connection back to coordinator (canConnectToCoordinator) - fails for CCS scenarios + // 5. Coordinator's connection supports chunked fetch version (coordinatorSupportsChunkedFetch) + // + // Double-checking here (already checked on coordinator side) ensures compatibility when + // coordinator and data node have different feature flag states or versions. + if (fetchPhaseChunkedEnabled && versionSupported && coordinatorSupportsChunkedFetch) { + ShardFetchSearchRequest fetchSearchReq = (ShardFetchSearchRequest) request; + + // Capture the current ThreadContext to preserve authentication headers + final Supplier contextSupplier = transportService.getThreadPool() + .getThreadContext() + .newRestorableContext(true); + + // Create chunk writer that provides both sending and buffer allocation. Each chunk is sent to the coordinator's + // TransportFetchPhaseResponseChunkAction endpoint. The coordinator accumulates chunks in a FetchPhaseResponseStream and + // sends ACKs. 
+ chunkWriter = new FetchPhaseResponseChunk.Writer() { + @Override + public void writeResponseChunk(FetchPhaseResponseChunk responseChunk, ActionListener listener) { + ReleasableBytesReference bytesToSend = null; + // Restore the ThreadContext before sending the chunk + try (ThreadContext.StoredContext ignored = contextSupplier.get()) { + Transport.Connection connection = transportService.getConnection(fetchSearchReq.getCoordinatingNode()); + bytesToSend = responseChunk.toReleasableBytesReference(fetchSearchReq.getCoordinatingTaskId()); + BytesTransportRequest request = new BytesTransportRequest(bytesToSend, connection.getTransportVersion()); + + final ReleasableBytesReference bytesRef = bytesToSend; + bytesToSend = null; + + transportService.sendChildRequest( + connection, + TransportFetchPhaseResponseChunkAction.ZERO_COPY_ACTION_NAME, + request, + task, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>( + ActionListener.releaseBefore(bytesRef, listener.map(r -> null)), + in -> ActionResponse.Empty.INSTANCE, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ) + ); + } catch (Exception e) { + Releasables.closeWhileHandlingException(bytesToSend); + listener.onFailure(e); + } + } + + @Override + public RecyclerBytesStreamOutput newNetworkBytesStream() { + return transportService.newNetworkBytesStream(searchService.getCircuitBreaker()); + } + }; + } + searchService.executeFetchPhase(request, (SearchShardTask) task, chunkWriter, new ChannelActionListener<>(channel)); + }; + transportService.registerRequestHandler( FETCH_ID_SCROLL_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index 54e2b8ce71f83..d3f3e4d1c2c54 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -567,6 
+567,8 @@ public void apply(Settings value, Settings current, Settings previous) { SearchModule.SCRIPTED_METRICS_AGG_ALLOWED_STORED_SCRIPTS, SearchService.SEARCH_WORKER_THREADS_ENABLED, SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED, + SearchService.FETCH_PHASE_CHUNKED_ENABLED, + SearchService.FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS, SearchService.MEMORY_ACCOUNTING_BUFFER_SIZE, ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING, ThreadPool.LATE_TIME_INTERVAL_WARN_THRESHOLD_SETTING, diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 76a513a6905d9..c225e8990ef96 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1298,6 +1298,7 @@ public void sendRequest( telemetryProvider.getTracer(), onlinePrewarmingService ); + searchTransportService.setSearchService(searchService); final SearchTaskWatchdog searchTaskWatchdog = new SearchTaskWatchdog( settingsModule.getClusterSettings(), diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 2bf8ab3313719..535ddb5faccc4 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -102,6 +102,7 @@ import org.elasticsearch.search.fetch.QueryFetchSearchResult; import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; +import org.elasticsearch.search.fetch.chunk.FetchPhaseResponseChunk; import org.elasticsearch.search.fetch.subphase.FetchDocValuesContext; import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext.ScriptField; @@ -257,6 +258,23 @@ public class SearchService extends AbstractLifecycleComponent implements 
IndexEv Property.Dynamic ); + private static final boolean CHUNKED_FETCH_PHASE_FEATURE_FLAG = new FeatureFlag("chunked_fetch_phase_enabled").isEnabled(); + + public static final Setting FETCH_PHASE_CHUNKED_ENABLED = Setting.boolSetting( + "search.fetch_phase_chunked_enabled", + CHUNKED_FETCH_PHASE_FEATURE_FLAG, + Property.NodeScope, + Property.Dynamic + ); + + public static final Setting FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS = Setting.intSetting( + "search.fetch_phase_chunked_max_in_flight_chunks", + 3, // Conservative default: keeps a few chunk sends pipelined without allowing unbounded in-flight chunk memory. + 1, + Property.Dynamic, + Property.NodeScope + ); + public static final Setting MAX_OPEN_SCROLL_CONTEXT = Setting.intSetting( "search.max_open_scroll_context", 500, @@ -350,6 +368,8 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv private final int prewarmingMaxPoolFactorThreshold; private volatile Executor searchExecutor; private volatile boolean enableQueryPhaseParallelCollection; + private volatile boolean enableFetchPhaseChunked; + private volatile int fetchPhaseMaxInFlightChunks; private volatile long defaultKeepAlive; @@ -452,8 +472,13 @@ public SearchService( enableQueryPhaseParallelCollection = QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.get(settings); batchQueryPhase = BATCHED_QUERY_PHASE.get(settings); + enableFetchPhaseChunked = FETCH_PHASE_CHUNKED_ENABLED.get(settings); + fetchPhaseMaxInFlightChunks = FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS.get(settings); clusterService.getClusterSettings() .addSettingsUpdateConsumer(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED, this::setEnableQueryPhaseParallelCollection); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FETCH_PHASE_CHUNKED_ENABLED, this::setEnableFetchPhaseChunked); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS, this::setFetchPhaseMaxInFlightChunks); clusterService.getClusterSettings() 
.addSettingsUpdateConsumer(BATCHED_QUERY_PHASE, bulkExecuteQueryPhase -> this.batchQueryPhase = bulkExecuteQueryPhase); memoryAccountingBufferSize = MEMORY_ACCOUNTING_BUFFER_SIZE.get(settings).getBytes(); @@ -494,6 +519,18 @@ private void setEnableQueryPhaseParallelCollection(boolean enableQueryPhaseParal this.enableQueryPhaseParallelCollection = enableQueryPhaseParallelCollection; } + private void setEnableFetchPhaseChunked(boolean enableFetchPhaseChunked) { + this.enableFetchPhaseChunked = enableFetchPhaseChunked; + } + + public boolean fetchPhaseChunked() { + return enableFetchPhaseChunked; + } + + private void setFetchPhaseMaxInFlightChunks(int fetchPhaseMaxInFlightChunks) { + this.fetchPhaseMaxInFlightChunks = fetchPhaseMaxInFlightChunks; + } + private static void validateKeepAlives(TimeValue defaultKeepAlive, TimeValue maxKeepAlive) { if (defaultKeepAlive.millis() > maxKeepAlive.millis()) { throw new IllegalArgumentException( @@ -912,6 +949,7 @@ private static void runAsync( */ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task) throws Exception { final ReaderContext readerContext = createOrGetReaderContext(request); + try ( Releasable scope = tracer.withScope(task); Releasable ignored = readerContext.markAsUsed(getKeepAlive(request)); @@ -1028,6 +1066,145 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon return QueryFetchSearchResult.of(context.queryResult(), context.fetchResult()); } + /* + * Fetch phase lifecycle overview: + * + * 1. Fetch build phase: + * - Executes fetch sub-phases and builds hits + * - Signals success/failure via buildListener + * - Records stats and releases shard search context + * + * 2. 
Final completion phase: + * - For streaming responses, waits for all chunk ACKs + * - Completes the request listener + */ + public void executeFetchPhase( + ShardFetchRequest request, + CancellableTask task, + FetchPhaseResponseChunk.Writer writer, + ActionListener listener + ) { + final ActionListener releaseListener = releaseCircuitBreakerOnResponse(listener, result -> result); + final ReaderContext readerContext = findReaderContext(request.contextId(), request); + final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); + final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); + + // FetchPhase.execute() completes asynchronously: immediately for non-streaming, after all chunk ACKs for streaming + rewriteAndFetchShardRequest( + readerContext.indexShard(), + shardSearchRequest, + ActionListener.wrap( + rewritten -> doFetchPhase(request, readerContext, rewritten, task, markAsUsed, writer, releaseListener), + e -> { + Releasables.close(markAsUsed); + releaseListener.onFailure(e); + } + ) + ); + } + + /** + * Submits the fetch phase work to the search thread pool. Invoked after the shard request has been rewritten. 
+ */ + private void doFetchPhase( + ShardFetchRequest request, + ReaderContext readerContext, + ShardSearchRequest rewritten, + CancellableTask task, + Releasable markAsUsed, + FetchPhaseResponseChunk.Writer writer, + ActionListener listener + ) { + getExecutor(readerContext.indexShard()).execute(new AbstractRunnable() { + private volatile SearchContext searchContext; + + private final Releasable closeOnce = Releasables.releaseOnce(Releasables.wrap(() -> { + if (readerContext.singleSession()) freeReaderContext(request.contextId()); + }, () -> Releasables.close(searchContext), markAsUsed)); + + @Override + protected void doRun() throws Exception { + final long startTime; + final SearchOperationListener opsListener; + + this.searchContext = createContext(readerContext, rewritten, task, ResultsType.FETCH, false); + startTime = System.nanoTime(); + opsListener = searchContext.indexShard().getSearchOperationListener(); + opsListener.onPreFetchPhase(searchContext); + + final FetchSearchResult fetchResult = searchContext.fetchResult(); + fetchResult.incRef(); + + try { + prepareFetchContext(request, readerContext, searchContext); + + fetchPhase.execute( + searchContext, + request.docIds(), + request.getRankDocks(), + null, + writer, + fetchPhaseMaxInFlightChunks, + newFetchBuildListener(opsListener, searchContext, startTime, closeOnce), + newFetchCompletionListener(listener, fetchResult) + ); + } catch (Exception e) { + try { + opsListener.onFailedFetchPhase(searchContext); + } finally { + Releasables.close(closeOnce, fetchResult::decRef); + } + throw e; + } + } + + @Override + public void onFailure(Exception e) { + assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); + Releasables.close(closeOnce); + listener.onFailure(e); + } + }); + } + + private static void prepareFetchContext(ShardFetchRequest request, ReaderContext readerContext, SearchContext searchContext) { + if (request.lastEmittedDoc() != null) { + 
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc(); + } + searchContext.assignRescoreDocIds(readerContext.getRescoreDocIds(request.getRescoreDocIds())); + searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(request.getAggregatedDfs())); + } + + /** + * Creates a listener that records fetch phase timing/failure stats and releases the SearchContext and shard resources + * once the fetch build completes (hits assembled or failed). + */ + private static ActionListener newFetchBuildListener( + SearchOperationListener opsListener, + SearchContext searchContext, + long startTime, + Releasable closeOnce + ) { + return ActionListener.runAfter( + ActionListener.wrap( + ignored -> opsListener.onFetchPhase(searchContext, System.nanoTime() - startTime), + e -> opsListener.onFailedFetchPhase(searchContext) + ), + closeOnce::close + ); + } + + /** + * Creates a listener that forwards the {@link FetchSearchResult} to the caller and manages the result's ref count. + * For streaming, this fires only after all response chunks have been ACKed. 
+ */ + private static ActionListener newFetchCompletionListener( + ActionListener listener, + FetchSearchResult fetchResult + ) { + return ActionListener.releaseAfter(listener.map(ignored -> fetchResult), fetchResult::decRef); + } + public void executeQueryPhase( InternalScrollSearchRequest request, SearchShardTask task, @@ -1220,46 +1397,6 @@ public void executeFetchPhase( ); } - public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener listener) { - final ReaderContext readerContext = findReaderContext(request.contextId(), request); - final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); - final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); - rewriteAndFetchShardRequest(readerContext.indexShard(), shardSearchRequest, listener.delegateFailure((l, rewritten) -> { - runAsync(getExecutor(readerContext.indexShard()), () -> { - try (SearchContext searchContext = createContext(readerContext, rewritten, task, ResultsType.FETCH, false)) { - if (request.lastEmittedDoc() != null) { - searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc(); - } - searchContext.assignRescoreDocIds(readerContext.getRescoreDocIds(request.getRescoreDocIds())); - searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(request.getAggregatedDfs())); - final long startTime = System.nanoTime(); - var opsListener = searchContext.indexShard().getSearchOperationListener(); - opsListener.onPreFetchPhase(searchContext); - try { - fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks()); - if (readerContext.singleSession()) { - freeReaderContext(request.contextId()); - } - opsListener.onFetchPhase(searchContext, System.nanoTime() - startTime); - opsListener = null; - } finally { - if (opsListener != null) { - opsListener.onFailedFetchPhase(searchContext); - } - } - var fetchResult = searchContext.fetchResult(); - // 
inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block - fetchResult.incRef(); - return fetchResult; - } catch (Exception e) { - assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); - // we handle the failure in the failure listener below - throw e; - } - }, wrapFailureListener(releaseCircuitBreakerOnResponse(listener, result -> result), readerContext, markAsUsed)); - })); - } - protected void checkCancelled(CancellableTask task) { // check cancellation as early as possible, as it avoids opening up a Lucene reader on FrozenEngine try { @@ -1278,6 +1415,7 @@ private ReaderContext findReaderContext(ShardSearchContextId id, TransportReques if (reader == null) { throw new SearchContextMissingException(id); } + try { reader.validate(request); } catch (Exception exc) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index 121e3b00e2e90..e31d8e72107da 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -13,8 +13,14 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader; import org.elasticsearch.index.fieldvisitor.StoredFieldLoader; import org.elasticsearch.index.mapper.IdLoader; @@ -25,8 
+31,10 @@ import org.elasticsearch.search.SearchContextSourcePrinter; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.fetch.FetchSubPhase.HitContext; +import org.elasticsearch.search.fetch.chunk.FetchPhaseResponseChunk; import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.InnerHitsContext; @@ -48,6 +56,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.IntConsumer; import java.util.function.Supplier; @@ -56,9 +66,11 @@ /** * Fetch phase of a search request, used to fetch the actual top matching documents to be returned to the client, identified - * after reducing all of the matches returned by the query phase + * after reducing all the matches returned by the query phase + * Supports both traditional mode (all results in memory) and streaming mode (results sent in chunks). */ public final class FetchPhase { + private static final Logger LOGGER = LogManager.getLogger(FetchPhase.class); private final FetchSubPhase[] fetchSubPhases; @@ -68,18 +80,80 @@ public FetchPhase(List fetchSubPhases) { this.fetchSubPhases[fetchSubPhases.size()] = new InnerHitsPhase(this); } + /** + * Executes the fetch phase without memory checking or streaming. 
+ * + * @param context the search context + * @param docIdsToLoad document IDs to fetch + * @param rankDocs ranking information + */ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs) { - execute(context, docIdsToLoad, rankDocs, null); + // Synchronous wrapper for backward compatibility, + PlainActionFuture future = new PlainActionFuture<>(); + execute(context, docIdsToLoad, rankDocs, null, null, null, null, future); + try { + future.actionGet(); + } catch (UncategorizedExecutionException e) { + // PlainActionFuture wraps non-ElasticsearchException failures in UncategorizedExecutionException. + // Translate to FetchPhaseExecutionException to preserve the expected exception type and cause. + throw new FetchPhaseExecutionException(context.shardTarget(), "Fetch phase failed", e.getCause()); + } } /** + * Executes the fetch phase with an optional caller-supplied memory tracking callback and no streaming * - * @param context - * @param docIdsToLoad - * @param rankDocs - * @param memoryChecker if not provided, the fetch phase will use the circuit breaker to check memory usage + * @param context the search context + * @param docIdsToLoad document IDs to fetch + * @param rankDocs ranking information + * @param memoryChecker optional callback for memory tracking, may be null */ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo rankDocs, @Nullable IntConsumer memoryChecker) { + // Synchronous wrapper for backward compatibility, + PlainActionFuture future = new PlainActionFuture<>(); + execute(context, docIdsToLoad, rankDocs, memoryChecker, null, null, null, future); + try { + future.actionGet(); + } catch (UncategorizedExecutionException e) { + // PlainActionFuture wraps non-ElasticsearchException failures in UncategorizedExecutionException. + // Translate to FetchPhaseExecutionException to preserve the expected exception type and cause. 
+ throw new FetchPhaseExecutionException(context.shardTarget(), "Fetch phase failed", e.getCause()); + } + } + + /** + * Executes the fetch phase with an optional caller-supplied memory tracking callback and optional streaming. + * + *

When {@code writer} is {@code null} (non-streaming), all hits are accumulated in memory and returned at once. + * When {@code writer} is provided (streaming), hits are emitted in chunks to reduce peak memory usage. In streaming mode, + * the final completion may be delayed by transport-level acknowledgements, but the fetch build completion is signaled as + * soon as the fetch work has finished. + * + * @param context the search context + * @param docIdsToLoad document IDs to fetch + * @param rankDocs ranking information + * @param memoryChecker optional callback for memory tracking, may be {@code null} + * @param writer optional chunk writer for streaming mode, may be {@code null} + * @param buildListener optional listener invoked when all {@link SearchHit} objects have been constructed + * (and, in streaming mode, serialized into chunks and dispatched to the writer). + * In non-streaming mode this fires immediately after the hits are built, just like {@code listener}. + * In streaming mode this fires before chunk ACKs arrive, allowing the caller to release + * shard resources (e.g. close the SearchContext) without waiting for network acknowledgements. + * @param listener final completion listener. In streaming mode this is invoked only after all chunks are ACKed; in + * non-streaming mode it is invoked immediately after hits are built. 
+ * + * @throws TaskCancelledException if the task is cancelled + */ + public void execute( + SearchContext context, + int[] docIdsToLoad, + RankDocShardInfo rankDocs, + @Nullable IntConsumer memoryChecker, + @Nullable FetchPhaseResponseChunk.Writer writer, + @Nullable Integer maxInFlightChunks, + @Nullable ActionListener buildListener, + ActionListener listener + ) { if (LOGGER.isTraceEnabled()) { LOGGER.trace("{}", new SearchContextSourcePrinter(context)); } @@ -88,41 +162,60 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo throw new TaskCancelledException("cancelled"); } + final ActionListener resolvedBuildListener = buildListener != null ? buildListener : ActionListener.noop(); + if (docIdsToLoad == null || docIdsToLoad.length == 0) { // no individual hits to process, so we shortcut context.fetchResult() .shardResult(SearchHits.empty(context.queryResult().getTotalHits(), context.queryResult().getMaxScore()), null); + resolvedBuildListener.onResponse(null); + listener.onResponse(null); return; } - Profiler profiler = context.getProfilers() == null + final Profiler profiler = context.getProfilers() == null || (context.request().source() != null && context.request().source().rankBuilder() != null) ? Profiler.NOOP : Profilers.startProfilingFetchPhase(); - SearchHits hits = null; - long searchHitsBytesSize = 0L; - try { - SearchHitsWithSizeBytes result = buildSearchHits(context, docIdsToLoad, profiler, rankDocs, memoryChecker); - hits = result.hits; - searchHitsBytesSize = result.searchHitsBytesSize; - } finally { + + var docsIterator = createDocsIterator(context, profiler, rankDocs, writer != null ? bytes -> {} : memoryChecker); + + // Common completion handler for both sync and streaming modes + // finalizes profiling, stores the shard result, and signals the outer listener. 
+ ActionListener hitsListener = listener.map(hitsAndBytes -> { + SearchHits hitsToRelease = hitsAndBytes.hits; try { - // Always finish profiling ProfileResult profileResult = profiler.finish(); - // Only set the shardResults if building search hits was successful - if (hits != null) { - context.fetchResult().shardResult(hits, profileResult); - context.fetchResult().setSearchHitsSizeBytes(searchHitsBytesSize); - hits = null; - } else { - assert searchHitsBytesSize == 0L - : "searchHitsBytesSize must be 0 when hits are null but was [" + searchHitsBytesSize + "]"; + context.fetchResult().shardResult(hitsAndBytes.hits, profileResult); + + if (writer == null) { + context.fetchResult().setSearchHitsSizeBytes(hitsAndBytes.searchHitsBytesSize); } + + hitsToRelease = null; + return null; } finally { - if (hits != null) { - hits.decRef(); + if (hitsToRelease != null) { + hitsToRelease.decRef(); } } + }); + + if (writer == null) { + buildSearchHits(context, docIdsToLoad, docsIterator, resolvedBuildListener, hitsListener); + } else { + int resolvedMaxInFlightChunks = maxInFlightChunks != null + ? maxInFlightChunks + : SearchService.FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS.get(context.getSearchExecutionContext().getIndexSettings().getSettings()); + buildSearchHitsStreaming( + context, + docIdsToLoad, + docsIterator, + writer, + resolvedMaxInFlightChunks, + resolvedBuildListener, + hitsListener + ); } } @@ -136,12 +229,15 @@ public Source getSource(LeafReaderContext ctx, int doc) { } } - private SearchHitsWithSizeBytes buildSearchHits( + /** + * Creates the docs iterator that handles per-document fetching and sub-phase processing. + * Shared between sync and streaming modes; the memoryChecker parameter controls per-hit memory accounting. 
+ */ + private StreamingFetchPhaseDocsIterator createDocsIterator( SearchContext context, - int[] docIdsToLoad, Profiler profiler, RankDocShardInfo rankDocs, - IntConsumer memoryChecker + @Nullable IntConsumer memoryChecker ) { var lookup = context.getSearchExecutionContext().getMappingLookup(); @@ -198,7 +294,7 @@ private SearchHitsWithSizeBytes buildSearchHits( final int[] locallyAccumulatedBytes = new int[1]; NestedDocuments nestedDocuments = context.getSearchExecutionContext().getNestedDocuments(); - FetchPhaseDocsIterator docsIterator = new FetchPhaseDocsIterator() { + return new StreamingFetchPhaseDocsIterator() { LeafReaderContext ctx; LeafNestedDocuments leafNestedDocuments; @@ -253,6 +349,7 @@ protected SearchHit nextDoc(int doc) throws IOException { leafIdLoader, rankDocs == null ? null : rankDocs.get(doc) ); + boolean success = false; try { sourceProvider.source = hit.source(); @@ -276,35 +373,152 @@ protected SearchHit nextDoc(int doc) throws IOException { } } }; + } - try { - SearchHit[] hits = docsIterator.iterate( + /** + * Synchronous fetch: iterates all documents, collects hits in memory, and returns them at once. 
+ */ + private void buildSearchHits( + SearchContext context, + int[] docIdsToLoad, + FetchPhaseDocsIterator docsIterator, + ActionListener buildListener, + ActionListener listener + ) { + ActionListener wrappedListener = new ActionListener<>() { + @Override + public void onResponse(SearchHitsWithSizeBytes result) { + buildListener.onResponse(null); + listener.onResponse(result); + } + + @Override + public void onFailure(Exception e) { + long leakedBytes = docsIterator.getRequestBreakerBytes(); + if (leakedBytes > 0) { + context.circuitBreaker().addWithoutBreaking(-leakedBytes); + } + buildListener.onFailure(e); + listener.onFailure(e); + } + }; + + ActionListener.runWithResource( + wrappedListener, + () -> docsIterator.iterate( context.shardTarget(), context.searcher().getIndexReader(), docIdsToLoad, context.request().allowPartialSearchResults(), context.queryResult() - ); + ), + (l, result) -> { + if (context.isCancelled()) { + for (SearchHit hit : result.hits) { + if (hit != null) { + hit.decRef(); + } + } + throw new TaskCancelledException("cancelled"); + } + TotalHits totalHits = context.getTotalHits(); + SearchHits searchHits = new SearchHits(result.hits, totalHits, context.getMaxScore()); + l.onResponse(new SearchHitsWithSizeBytes(searchHits, docsIterator.getRequestBreakerBytes())); + } + ); + } + + /** + * Streaming fetch: iterates documents and streams them in chunks to reduce peak memory usage. + * Each chunk is sent via the writer and ACKed by the coordinator; backpressure is applied + * through page-level circuit breaker tracking in the network byte stream and in-flight chunk limits. 
+ */ + private void buildSearchHitsStreaming( + SearchContext context, + int[] docIdsToLoad, + StreamingFetchPhaseDocsIterator docsIterator, + FetchPhaseResponseChunk.Writer writer, + int maxInFlightChunks, + ActionListener buildListener, + ActionListener listener + ) { + final AtomicReference sendFailure = new AtomicReference<>(); + final AtomicReference lastChunkBytesRef = new AtomicReference<>(); + final AtomicLong lastChunkHitCountRef = new AtomicLong(0); + final AtomicLong lastChunkSequenceStartRef = new AtomicLong(-1); + + final int targetChunkBytes = StreamingFetchPhaseDocsIterator.DEFAULT_TARGET_CHUNK_BYTES; + + // RefCountingListener tracks chunk ACKs in streaming mode. + // Each chunk calls acquire() to get a listener, which is completed when the ACK arrives. + // When all acquired listeners complete, the completion callback below runs, + // returning the final SearchHits (last chunk) to the caller. + final RefCountingListener chunkCompletionRefs = new RefCountingListener(listener.delegateFailureAndWrap((l, ignored) -> { + ReleasableBytesReference lastChunkBytes = lastChunkBytesRef.getAndSet(null); + try { + long seqStart = lastChunkSequenceStartRef.get(); + if (seqStart >= 0) { + context.fetchResult().setLastChunkSequenceStart(seqStart); + } - if (context.isCancelled()) { - for (SearchHit hit : hits) { - // release all hits that would otherwise become owned and eventually released by SearchHits below - hit.decRef(); + long countLong = lastChunkHitCountRef.get(); + if (lastChunkBytes != null && countLong > 0) { + int hitCount = Math.toIntExact(countLong); + context.fetchResult().setLastChunkBytes(lastChunkBytes, hitCount); + lastChunkBytes = null; } - throw new TaskCancelledException("cancelled"); + + l.onResponse(new SearchHitsWithSizeBytes(SearchHits.empty(context.getTotalHits(), context.getMaxScore()), 0)); + } finally { + Releasables.closeWhileHandlingException(lastChunkBytes); } + })); + + final ActionListener mainBuildListener = 
chunkCompletionRefs.acquire(); + chunkCompletionRefs.close(); + + docsIterator.iterateAsync( + context.shardTarget(), + context.searcher().getIndexReader(), + docIdsToLoad, + writer, + targetChunkBytes, + chunkCompletionRefs, + maxInFlightChunks, + sendFailure, + context::isCancelled, + new ActionListener<>() { + @Override + public void onResponse(FetchPhaseDocsIterator.IterateResult result) { + try (result) { + if (context.isCancelled()) { + onFailure(new TaskCancelledException("cancelled")); + return; + } + + if (result.lastChunkBytes != null) { + lastChunkBytesRef.set(result.takeLastChunkBytes()); + lastChunkHitCountRef.set(result.lastChunkHitCount); + lastChunkSequenceStartRef.set(result.lastChunkSequenceStart); + } + } catch (Exception e) { + onFailure(e); + return; + } + buildListener.onResponse(null); + mainBuildListener.onResponse(null); + } - TotalHits totalHits = context.getTotalHits(); - SearchHits searchHits = new SearchHits(hits, totalHits, context.getMaxScore()); - return new SearchHitsWithSizeBytes(searchHits, docsIterator.getRequestBreakerBytes()); - } catch (Exception e) { - // On exception, release the breaker bytes immediately since the hits won't make it to the result - long bytes = docsIterator.getRequestBreakerBytes(); - if (bytes > 0L) { - context.circuitBreaker().addWithoutBreaking(-bytes); + @Override + public void onFailure(Exception e) { + ReleasableBytesReference lastChunkBytes = lastChunkBytesRef.getAndSet(null); + Releasables.closeWhileHandlingException(lastChunkBytes); + + buildListener.onFailure(e); + mainBuildListener.onFailure(e); + } } - throw e; - } + ); } List getProcessors(SearchShardTarget target, FetchContext context, Profiler profiler) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java index df29b4d1fad88..8f662f5b8e5d0 100644 --- 
a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java @@ -13,8 +13,9 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.ReaderUtil; import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.core.Releasables; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.search.query.QuerySearchResult; @@ -24,11 +25,14 @@ import java.util.Arrays; /** - * Given a set of doc ids and an index reader, sorts the docs by id, splits the sorted - * docs by leaf reader, and iterates through them calling abstract methods - * {@link #setNextReader(LeafReaderContext, int[])} for each new leaf reader and - * {@link #nextDoc(int)} for each document; then collects the resulting {@link SearchHit}s - * into an array and returns them in the order of the original doc ids. + * Iterates through a set of document IDs, fetching each document and collecting + * the resulting {@link SearchHit}s. + *

+ * Documents are sorted by doc ID for efficient sequential Lucene access, then results + * are mapped back to their original score-based order. All hits are collected in memory + * and returned at once. + * + * @see StreamingFetchPhaseDocsIterator */ abstract class FetchPhaseDocsIterator { @@ -47,23 +51,37 @@ public long getRequestBreakerBytes() { } /** - * Called when a new leaf reader is reached - * @param ctx the leaf reader for this set of doc ids - * @param docsInLeaf the reader-specific docids to be fetched in this leaf reader + * Called when a new leaf reader is reached. + * + * @param ctx the leaf reader for this set of doc ids + * @param docsInLeaf the reader-specific docids to be fetched in this leaf reader */ protected abstract void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) throws IOException; /** - * Called for each document within a leaf reader - * @param doc the global doc id + * Called for each document within a leaf reader. + * + * @param doc the global doc id * @return a {@link SearchHit} for the document */ protected abstract SearchHit nextDoc(int doc) throws IOException; /** - * Iterate over a set of docsIds within a particular shard and index reader + * Synchronous iteration for non-streaming mode. + * Documents are sorted by doc ID for efficient sequential Lucene access, + * then results are mapped back to their original (score-based) order. 
+ * + * @param shardTarget the shard being fetched from + * @param indexReader the index reader for accessing documents + * @param docIds document IDs to fetch (in score order) + * @param allowPartialResults if true, return partial results on timeout instead of failing + * @param querySearchResult query result for recording timeout state + * + * @return IterateResult containing fetched hits in original score order + * @throws SearchTimeoutException if timeout occurs and partial results not allowed + * @throws FetchPhaseExecutionException if fetch fails for a document */ - public final SearchHit[] iterate( + public final IterateResult iterate( SearchShardTarget shardTarget, IndexReader indexReader, int[] docIds, @@ -77,19 +95,26 @@ public final SearchHit[] iterate( } // make sure that we iterate in doc id order Arrays.sort(docs); - int currentDoc = docs[0].docId; + int currentDoc = docs.length > 0 ? docs[0].docId : -1; + try { + if (docs.length == 0) { + return new IterateResult(searchHits); + } + int leafOrd = ReaderUtil.subIndex(docs[0].docId, indexReader.leaves()); LeafReaderContext ctx = indexReader.leaves().get(leafOrd); int endReaderIdx = endReaderIdx(ctx, 0, docs); int[] docsInLeaf = docIdsInLeaf(0, endReaderIdx, docs, ctx.docBase); + try { setNextReader(ctx, docsInLeaf); } catch (ContextIndexSearcher.TimeExceededException e) { SearchTimeoutException.handleTimeout(allowPartialResults, shardTarget, querySearchResult); assert allowPartialResults; - return SearchHits.EMPTY; + return new IterateResult(new SearchHit[0]); } + for (int i = 0; i < docs.length; i++) { try { if (i >= endReaderIdx) { @@ -110,7 +135,7 @@ public final SearchHit[] iterate( assert allowPartialResults; SearchHit[] partialSearchHits = new SearchHit[i]; System.arraycopy(searchHits, 0, partialSearchHits, 0, i); - return partialSearchHits; + return new IterateResult(partialSearchHits); } } } catch (SearchTimeoutException e) { @@ -122,7 +147,7 @@ public final SearchHit[] iterate( 
purgeSearchHits(searchHits); throw new FetchPhaseExecutionException(shardTarget, "Error running fetch phase for doc [" + currentDoc + "]", e); } - return searchHits; + return new IterateResult(searchHits); } private static void purgeSearchHits(SearchHit[] searchHits) { @@ -169,4 +194,58 @@ public int compareTo(DocIdToIndex o) { return Integer.compare(docId, o.docId); } } + + /** + * Result of iteration. + * For non-streaming: contains hits array. + * For streaming: contains last chunk bytes to be sent after all ACKs. The bytes carry + * page-level circuit breaker tracking from the {@link org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput}; + * releasing the bytes automatically decrements the breaker. + */ + static class IterateResult implements AutoCloseable { + final SearchHit[] hits; // Non-streaming mode only + final ReleasableBytesReference lastChunkBytes; + final int lastChunkHitCount; + final long lastChunkSequenceStart; + private boolean closed = false; + private boolean bytesOwnershipTransferred = false; + + // Non-streaming constructor + IterateResult(SearchHit[] hits) { + this.hits = hits; + this.lastChunkBytes = null; + this.lastChunkHitCount = 0; + this.lastChunkSequenceStart = -1; + } + + // Streaming constructor + IterateResult(ReleasableBytesReference lastChunkBytes, int hitCount, long seqStart) { + this.hits = null; + this.lastChunkBytes = lastChunkBytes; + this.lastChunkHitCount = hitCount; + this.lastChunkSequenceStart = seqStart; + } + + /** + * Takes ownership of the last chunk bytes. + * After calling, close() will not release the bytes. The caller becomes responsible + * for eventually releasing the {@link ReleasableBytesReference} (which decrements the circuit breaker). 
+ * + * @return the last chunk bytes, or null if none + */ + ReleasableBytesReference takeLastChunkBytes() { + bytesOwnershipTransferred = true; + return lastChunkBytes; + } + + @Override + public void close() { + if (closed) return; + closed = true; + + if (bytesOwnershipTransferred == false) { + Releasables.closeWhileHandlingException(lastChunkBytes); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java index 6703bb5bd1920..e6b4161eeca1f 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java @@ -10,9 +10,12 @@ package org.elasticsearch.search.fetch; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; @@ -24,6 +27,8 @@ import java.io.IOException; +import static org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE; + public final class FetchSearchResult extends SearchPhaseResult { private SearchHits hits; @@ -35,6 +40,24 @@ public final class FetchSearchResult extends SearchPhaseResult { private ProfileResult profileResult; + /** + * Sequence number of the first hit in the last chunk (embedded in this result). + * Used by the coordinator to maintain correct ordering when processing the last chunk. + * Value of -1 indicates no last chunk or sequence tracking not applicable. 
+ */ + private long lastChunkSequenceStart = -1; + + /** + * Raw serialized bytes of the last chunk's hits. + */ + private BytesReference lastChunkBytes; + + /** + * Number of hits in the last chunk bytes. + * Used by the coordinator to know how many hits to deserialize from lastChunkBytes. + */ + private int lastChunkHitCount; + private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted()); public FetchSearchResult() {} @@ -48,6 +71,14 @@ public FetchSearchResult(StreamInput in) throws IOException { contextId = new ShardSearchContextId(in); hits = SearchHits.readFrom(in, true); profileResult = in.readOptionalWriteable(ProfileResult::new); + + if (in.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + lastChunkSequenceStart = in.readLong(); + lastChunkHitCount = in.readInt(); + if (lastChunkHitCount > 0) { + lastChunkBytes = in.readReleasableBytesReference(); + } + } } @Override @@ -56,6 +87,14 @@ public void writeTo(StreamOutput out) throws IOException { contextId.writeTo(out); hits.writeTo(out); out.writeOptionalWriteable(profileResult); + + if (out.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + out.writeLong(lastChunkSequenceStart); + out.writeInt(lastChunkHitCount); + if (lastChunkHitCount > 0 && lastChunkBytes != null) { + out.writeBytesReference(lastChunkBytes); + } + } } @Override @@ -139,10 +178,77 @@ private void deallocate() { hits.decRef(); hits = null; } + releaseLastChunkBytes(); } @Override public boolean hasReferences() { return refCounted.hasReferences(); } + + /** + * Sets the sequence start for the last chunk embedded in this result. + * Called on the data node after iterating fetch phase results. + * + * @param sequenceStart the sequence number of the first hit in the last chunk + */ + public void setLastChunkSequenceStart(long sequenceStart) { + this.lastChunkSequenceStart = sequenceStart; + } + + /** + * Gets the sequence start for the last chunk embedded in this result. 
+ * Used by the coordinator to properly order last chunk hits with other chunks. + * + * @return the sequence number of the first hit in the last chunk, or -1 if not set + */ + public long getLastChunkSequenceStart() { + return lastChunkSequenceStart; + } + + /** + * Sets the raw bytes of the last chunk. + * Called on the data node in chunked fetch mode to avoid deserializing + * large hit data that would cause OOM. + * + *

Takes ownership of the bytes reference - caller must not release it. + * + * @param bytes the serialized hit bytes + * @param hitCount the number of hits in the bytes + */ + public void setLastChunkBytes(BytesReference bytes, int hitCount) { + releaseLastChunkBytes(); // Release any existing bytes + this.lastChunkBytes = bytes; + this.lastChunkHitCount = hitCount; + } + + /** + * Gets the raw bytes of the last chunk. + * Used by the coordinator to deserialize and merge with other accumulated chunks. + * + * @return the serialized hit bytes, or null if not set + */ + public BytesReference getLastChunkBytes() { + return lastChunkBytes; + } + + /** + * Gets the number of hits in the last chunk bytes. + * + * @return the hit count, or 0 if no last chunk + */ + public int getLastChunkHitCount() { + return lastChunkHitCount; + } + + /** + * Releases the last chunk bytes if they are releasable. + */ + private void releaseLastChunkBytes() { + if (lastChunkBytes instanceof Releasable releasable) { + Releasables.closeWhileHandlingException(releasable); + } + lastChunkBytes = null; + lastChunkHitCount = 0; + } } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/ShardFetchSearchRequest.java b/server/src/main/java/org/elasticsearch/search/fetch/ShardFetchSearchRequest.java index 8c068fd6f8839..5b2dd7eede945 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/ShardFetchSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/ShardFetchSearchRequest.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.RescoreDocIds; @@ -24,6 +25,8 @@ import java.io.IOException; import java.util.List; +import static 
org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE; + /** * Shard level fetch request used with search. Holds indices taken from the original search request * and implements {@link org.elasticsearch.action.IndicesRequest}. @@ -35,6 +38,8 @@ public class ShardFetchSearchRequest extends ShardFetchRequest implements Indice private final RescoreDocIds rescoreDocIds; private final AggregatedDfs aggregatedDfs; private final RankDocShardInfo rankDocs; + private DiscoveryNode coordinatingNode; + private long coordinatingTaskId; public ShardFetchSearchRequest( OriginalIndices originalIndices, @@ -61,6 +66,11 @@ public ShardFetchSearchRequest(StreamInput in) throws IOException { rescoreDocIds = new RescoreDocIds(in); aggregatedDfs = in.readOptionalWriteable(AggregatedDfs::new); this.rankDocs = in.readOptionalWriteable(RankDocShardInfo::new); + + if (in.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + coordinatingNode = in.readOptionalWriteable(DiscoveryNode::new); + coordinatingTaskId = in.readLong(); + } } @Override @@ -71,6 +81,11 @@ public void writeTo(StreamOutput out) throws IOException { rescoreDocIds.writeTo(out); out.writeOptionalWriteable(aggregatedDfs); out.writeOptionalWriteable(rankDocs); + + if (out.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + out.writeOptionalWriteable(coordinatingNode); + out.writeLong(coordinatingTaskId); + } } @Override @@ -109,6 +124,22 @@ public RankDocShardInfo getRankDocks() { return this.rankDocs; } + public DiscoveryNode getCoordinatingNode() { + return coordinatingNode; + } + + public long getCoordinatingTaskId() { + return coordinatingTaskId; + } + + public void setCoordinatingNode(DiscoveryNode coordinatingNode) { + this.coordinatingNode = coordinatingNode; + } + + public void setCoordinatingTaskId(long coordinatingTaskId) { + this.coordinatingTaskId = coordinatingTaskId; + } + @Override public String getDescription() { StringBuilder sb = new 
StringBuilder(super.getDescription()); diff --git a/server/src/main/java/org/elasticsearch/search/fetch/StreamingFetchPhaseDocsIterator.java b/server/src/main/java/org/elasticsearch/search/fetch/StreamingFetchPhaseDocsIterator.java new file mode 100644 index 0000000000000..fe725cd65551a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/StreamingFetchPhaseDocsIterator.java @@ -0,0 +1,432 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.chunk.FetchPhaseResponseChunk; +import org.elasticsearch.tasks.TaskCancelledException; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +/** + * Extends {@link FetchPhaseDocsIterator} with asynchronous chunked iteration + * via 
{@link #iterateAsync}. The synchronous {@link #iterate} method from the + * parent class remains available for non-streaming use. + *

+ * <p>
+ * Uses {@link ThrottledTaskRunner} with {@link EsExecutors#DIRECT_EXECUTOR_SERVICE} to
+ * manage chunk sends:
+ * <ul>
+ * <li>Fetches documents and creates chunks</li>
+ * <li>Send tasks are enqueued directly to ThrottledTaskRunner</li>
+ * <li>Tasks run inline when under maxInFlightChunks capacity</li>
+ * <li>When at capacity, tasks queue internally until ACKs arrive</li>
+ * <li>ACK callbacks signal task completion, triggering queued tasks</li>
+ * </ul>
+ * Threading: All Lucene operations execute on the calling thread to satisfy + * Lucene's thread-affinity requirements. Send tasks run inline (DIRECT_EXECUTOR) when + * under capacity; ACK handling occurs asynchronously on network threads. + *

+ * Memory Management: The circuit breaker tracks recycler page allocations via the + * {@link RecyclerBytesStreamOutput} passed from the chunk writer. If the breaker trips + * during serialization, the producer fails immediately with a + * {@link org.elasticsearch.common.breaker.CircuitBreakingException}, preventing unbounded + * memory growth. Pages are released (and the breaker decremented) when the + * {@link ReleasableBytesReference} from {@link RecyclerBytesStreamOutput#moveToBytesReference()} + * is closed — either on ACK for intermediate chunks or when the last chunk is consumed. + *

+ * Backpressure: {@link ThrottledTaskRunner} limits concurrent in-flight sends to + * {@code maxInFlightChunks}. The circuit breaker provides the memory limit. + *

+ * Cancellation: The producer checks the cancellation flag periodically. + */ +abstract class StreamingFetchPhaseDocsIterator extends FetchPhaseDocsIterator { + + /** + * Default target chunk size in bytes (256KB). + * Chunks may slightly exceed this as we complete the current hit before checking. + */ + static final int DEFAULT_TARGET_CHUNK_BYTES = 256 * 1024; + + /** + * Asynchronous iteration using {@link ThrottledTaskRunner} for streaming mode. + * + * @param shardTarget the shard being fetched from + * @param indexReader the index reader + * @param docIds document IDs to fetch (in score order) + * @param chunkWriter writer for sending chunks (also provides buffer allocation with CB tracking) + * @param targetChunkBytes target size in bytes for each chunk + * @param chunkCompletionRefs ref-counting listener for tracking chunk ACKs + * @param maxInFlightChunks maximum concurrent unacknowledged chunks + * @param sendFailure atomic reference to capture send failures + * @param isCancelled supplier for cancellation checking + * @param listener receives the result with the last chunk bytes + */ + void iterateAsync( + SearchShardTarget shardTarget, + IndexReader indexReader, + int[] docIds, + FetchPhaseResponseChunk.Writer chunkWriter, + int targetChunkBytes, + RefCountingListener chunkCompletionRefs, + int maxInFlightChunks, + AtomicReference sendFailure, + Supplier isCancelled, + ActionListener listener + ) { + if (docIds == null || docIds.length == 0) { + listener.onResponse(new IterateResult(new SearchHit[0])); + return; + } + + final AtomicReference lastChunkHolder = new AtomicReference<>(); + final AtomicReference producerError = new AtomicReference<>(); + + // ThrottledTaskRunner manages send concurrency + final ThrottledTaskRunner sendRunner = new ThrottledTaskRunner("fetch", maxInFlightChunks, EsExecutors.DIRECT_EXECUTOR_SERVICE); + + // RefCountingListener fires completion callback when all refs are released. 
+ final RefCountingListener completionRefs = new RefCountingListener(ActionListener.wrap(ignored -> { + + final Throwable pError = producerError.get(); + if (pError != null) { + cleanupLastChunk(lastChunkHolder); + listener.onFailure(pError instanceof Exception ? (Exception) pError : new RuntimeException(pError)); + return; + } + + final Throwable sError = sendFailure.get(); + if (sError != null) { + cleanupLastChunk(lastChunkHolder); + listener.onFailure(sError instanceof Exception ? (Exception) sError : new RuntimeException(sError)); + return; + } + + if (isCancelled.get()) { + cleanupLastChunk(lastChunkHolder); + listener.onFailure(new TaskCancelledException("cancelled")); + return; + } + + final PendingChunk lastChunk = lastChunkHolder.getAndSet(null); + if (lastChunk == null) { + listener.onResponse(new IterateResult(new SearchHit[0])); + return; + } + + try { + listener.onResponse(new IterateResult(lastChunk.bytes, lastChunk.hitCount, lastChunk.sequenceStart)); + } catch (Exception e) { + lastChunk.close(); + throw e; + } + }, e -> { + cleanupLastChunk(lastChunkHolder); + listener.onFailure(e); + })); + + try { + produceChunks( + shardTarget.getShardId(), + indexReader, + docIds, + chunkWriter, + targetChunkBytes, + sendRunner, + completionRefs, + lastChunkHolder, + sendFailure, + chunkCompletionRefs, + isCancelled + ); + } catch (Exception e) { + producerError.set(e); + } finally { + completionRefs.close(); + } + } + + /** + * Produces chunks and enqueues send tasks to ThrottledTaskRunner. + *

+ * <p>
+ * For each chunk:
+ * <ol>
+ * <li>Fetch documents and serialize to bytes (page allocations tracked by the CB in the stream)</li>
+ * <li>For intermediate chunks: acquire ref and enqueue send task to ThrottledTaskRunner</li>
+ * <li>For last chunk: store in lastChunkHolder (returned via listener after all ACKs)</li>
+ * </ol>
  6. + *
+ */ + private void produceChunks( + ShardId shardId, + IndexReader indexReader, + int[] docIds, + FetchPhaseResponseChunk.Writer chunkWriter, + int targetChunkBytes, + ThrottledTaskRunner sendRunner, + RefCountingListener completionRefs, + AtomicReference lastChunkHolder, + AtomicReference sendFailure, + RefCountingListener chunkCompletionRefs, + Supplier isCancelled + ) throws Exception { + int totalDocs = docIds.length; + RecyclerBytesStreamOutput chunkBuffer = null; + + try { + chunkBuffer = chunkWriter.newNetworkBytesStream(); + int chunkStartIndex = 0; + int hitsInChunk = 0; + + for (int scoreIndex = 0; scoreIndex < totalDocs; scoreIndex++) { + if (scoreIndex % 32 == 0) { + if (isCancelled.get()) { + throw new TaskCancelledException("cancelled"); + } + Throwable failure = sendFailure.get(); + if (failure != null) { + throw failure instanceof Exception ? (Exception) failure : new RuntimeException(failure); + } + } + + int docId = docIds[scoreIndex]; + + int leafOrd = ReaderUtil.subIndex(docId, indexReader.leaves()); + LeafReaderContext ctx = indexReader.leaves().get(leafOrd); + int leafDocId = docId - ctx.docBase; + setNextReader(ctx, new int[] { leafDocId }); + + SearchHit hit = nextDoc(docId); + try { + hit.writeTo(chunkBuffer); + } finally { + hit.decRef(); + } + hitsInChunk++; + + boolean isLast = (scoreIndex == totalDocs - 1); + boolean bufferFull = chunkBuffer.size() >= targetChunkBytes; + + if (bufferFull || isLast) { + final ReleasableBytesReference chunkBytes = chunkBuffer.moveToBytesReference(); + chunkBuffer = null; + + try { + PendingChunk chunk = new PendingChunk(chunkBytes, hitsInChunk, chunkStartIndex, isLast); + + if (isLast) { + lastChunkHolder.set(chunk); + } else { + ActionListener completionRef = null; + try { + completionRef = completionRefs.acquire(); + sendRunner.enqueueTask( + new SendChunkTask( + chunk, + completionRef, + chunkWriter, + shardId, + totalDocs, + sendFailure, + chunkCompletionRefs, + isCancelled + ) + ); + completionRef = 
null; + } finally { + if (completionRef != null) { + completionRef.onResponse(null); + chunk.close(); + } + } + } + + if (isLast == false) { + chunkBuffer = chunkWriter.newNetworkBytesStream(); + chunkStartIndex = scoreIndex + 1; + hitsInChunk = 0; + } + } catch (Exception e) { + Releasables.closeWhileHandlingException(chunkBytes); + throw e; + } + } + } + } finally { + if (chunkBuffer != null) { + Releasables.closeWhileHandlingException(chunkBuffer); + } + } + } + + /** + * Task that sends a single chunk. Implements {@link ActionListener} to receive + * the throttle releasable from {@link ThrottledTaskRunner}. + */ + private static final class SendChunkTask implements ActionListener { + private final PendingChunk chunk; + private final ActionListener completionRef; + private final FetchPhaseResponseChunk.Writer writer; + private final ShardId shardId; + private final int totalDocs; + private final AtomicReference sendFailure; + private final RefCountingListener chunkCompletionRefs; + private final Supplier isCancelled; + + private SendChunkTask( + PendingChunk chunk, + ActionListener completionRef, + FetchPhaseResponseChunk.Writer writer, + ShardId shardId, + int totalDocs, + AtomicReference sendFailure, + RefCountingListener chunkCompletionRefs, + Supplier isCancelled + ) { + this.chunk = chunk; + this.completionRef = completionRef; + this.writer = writer; + this.shardId = shardId; + this.totalDocs = totalDocs; + this.sendFailure = sendFailure; + this.chunkCompletionRefs = chunkCompletionRefs; + this.isCancelled = isCancelled; + } + + @Override + public void onResponse(Releasable throttleReleasable) { + sendChunk(chunk, throttleReleasable, completionRef, writer, shardId, totalDocs, sendFailure, chunkCompletionRefs, isCancelled); + } + + @Override + public void onFailure(Exception e) { + chunk.close(); + sendFailure.compareAndSet(null, e); + completionRef.onFailure(e); + } + } + + /** + * Sends a single chunk. Called by ThrottledTaskRunner. + *

+ * The send is asynchronous - this method initiates the network write and returns immediately. + * The ACK callback handles cleanup and signals task completion to ThrottledTaskRunner. + * Page-level CB tracking is released when the {@link ReleasableBytesReference} is closed. + */ + private static void sendChunk( + PendingChunk chunk, + Releasable throttleReleasable, + ActionListener completionRef, + FetchPhaseResponseChunk.Writer writer, + ShardId shardId, + int totalDocs, + AtomicReference sendFailure, + RefCountingListener chunkCompletionRefs, + Supplier isCancelled + ) { + if (isCancelled.get()) { + chunk.close(); + completionRef.onResponse(null); + throttleReleasable.close(); + return; + } + + // Check for prior failure before sending + final Throwable failure = sendFailure.get(); + if (failure != null) { + chunk.close(); + completionRef.onResponse(null); + throttleReleasable.close(); + return; + } + + FetchPhaseResponseChunk responseChunk = null; + ActionListener ackRef = null; + try { + responseChunk = new FetchPhaseResponseChunk(shardId, chunk.bytes, chunk.hitCount, totalDocs, chunk.sequenceStart); + + final FetchPhaseResponseChunk chunkToClose = responseChunk; + + ackRef = chunkCompletionRefs.acquire(); + final ActionListener finalAckRef = ackRef; + + writer.writeResponseChunk(responseChunk, ActionListener.wrap(v -> { + chunkToClose.close(); + finalAckRef.onResponse(null); + completionRef.onResponse(null); + throttleReleasable.close(); + }, e -> { + chunkToClose.close(); + sendFailure.compareAndSet(null, e); + finalAckRef.onFailure(e); + completionRef.onFailure(e); + throttleReleasable.close(); + })); + + responseChunk = null; + } catch (Exception e) { + if (responseChunk != null) { + responseChunk.close(); + } else { + chunk.close(); + } + sendFailure.compareAndSet(null, e); + if (ackRef != null) { + ackRef.onFailure(e); + } + completionRef.onFailure(e); + throttleReleasable.close(); + } + } + + private static void cleanupLastChunk(AtomicReference 
lastChunkHolder) { + PendingChunk lastChunk = lastChunkHolder.getAndSet(null); + if (lastChunk != null) { + lastChunk.close(); + } + } + + /** + * Represents a chunk ready to be sent. The underlying {@link ReleasableBytesReference} carries + * the page-level circuit breaker release callback from {@link RecyclerBytesStreamOutput#moveToBytesReference()}. + */ + static class PendingChunk implements AutoCloseable { + final ReleasableBytesReference bytes; + final int hitCount; + final int sequenceStart; + final boolean isLast; + + PendingChunk(ReleasableBytesReference bytes, int hitCount, int sequenceStart, boolean isLast) { + this.bytes = bytes; + this.hitCount = hitCount; + this.sequenceStart = sequenceStart; + this.isLast = isLast; + } + + @Override + public void close() { + if (bytes != null) { + Releasables.closeWhileHandlingException(bytes); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasks.java b/server/src/main/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasks.java new file mode 100644 index 0000000000000..8d24f26228965 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasks.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.shard.ShardId; + +import java.util.concurrent.ConcurrentMap; + +/** + * Manages the registry of active fetch response streams on the coordinator node. + */ + +public final class ActiveFetchPhaseTasks { + + record ResponseStreamKey(long coordinatingTaskId, ShardId shardId) {} + + private final ConcurrentMap tasks = ConcurrentCollections.newConcurrentMap(); + + /** + * Registers a response stream for a specific coordinating task and shard. + * + * This method is called by {@link TransportFetchPhaseCoordinationAction} when starting + * a chunked fetch. The returned {@link Releasable} must be closed when the fetch + * completes to remove the stream from the registry. + * + * @param coordinatingTaskId the ID of the coordinating search task + * @param shardId the shard ID being fetched + * @param responseStream the stream to register (must have at least one reference count) + * @return a releasable that removes the registration when closed + * @throws IllegalStateException if a stream for this task+shard combination is already registered + */ + Releasable registerResponseBuilder(long coordinatingTaskId, ShardId shardId, FetchPhaseResponseStream responseStream) { + assert responseStream.hasReferences(); + + ResponseStreamKey key = new ResponseStreamKey(coordinatingTaskId, shardId); + + final var previous = tasks.putIfAbsent(key, responseStream); + if (previous != null) { + throw new IllegalStateException("already executing fetch task [" + coordinatingTaskId + "]"); + } + + return Releasables.assertOnce(() -> { + final var removed = tasks.remove(key, responseStream); + if (removed == false) { + throw new IllegalStateException("already completed fetch task [" + coordinatingTaskId 
+ "]"); + } + }); + } + + /** + * Acquires the response stream for the given coordinating task and shard, incrementing its reference count. + * + * This method is called by {@link TransportFetchPhaseResponseChunkAction} for each arriving chunk. + * The caller must call {@link FetchPhaseResponseStream#decRef()} when done processing the chunk. + * + * @param coordinatingTaskId the ID of the coordinating search task + * @param shardId the shard ID + * @return the response stream with an incremented reference count + * @throws ResourceNotFoundException if the task is not registered or has already completed + */ + public FetchPhaseResponseStream acquireResponseStream(long coordinatingTaskId, ShardId shardId) { + final var outerRequest = tasks.get(new ResponseStreamKey(coordinatingTaskId, shardId)); + if (outerRequest == null || outerRequest.tryIncRef() == false) { + throw new ResourceNotFoundException("fetch task [" + coordinatingTaskId + "] not found"); + } + return outerRequest; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunk.java b/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunk.java new file mode 100644 index 0000000000000..ad0f8c9c7ca24 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunk.java @@ -0,0 +1,206 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; + +import java.io.IOException; + +/** + * A single chunk of fetch results streamed from a data node to the coordinator. + * Contains sequence information to maintain correct ordering when chunks arrive out of order. + * + *

Supports zero-copy transport by separating header metadata from serialized hits. + * The header is created after hits are serialized (since we don't know hit count until + * the buffer is full), then combined using {@link CompositeBytesReference} to avoid copying. + */ +public class FetchPhaseResponseChunk implements Writeable, Releasable { + + /** + * Initial capacity hint for chunk metadata serialization. + *

+ * The metadata contains a few fields plus a reference to the already serialized + * hit payload. The payload size dominates and the stream can grow if needed, so this is + * intentionally a small preallocation to avoid over-reserving per chunk. + */ + private static final int INITIAL_CHUNK_SERIALIZATION_CAPACITY = 128; + + private final ShardId shardId; + private final int hitCount; + private final int expectedTotalDocs; + private final long sequenceStart; + + private BytesReference serializedHits; + private SearchHit[] deserializedHits; + private NamedWriteableRegistry namedWriteableRegistry; + + /** + * Creates a chunk with pre-serialized hits. + * Takes ownership of serializedHits - caller must not release it. + * + * @param shardId source shard + * @param serializedHits pre-serialized hit bytes + * @param hitCount number of hits in the serialized bytes + * @param expectedTotalDocs total number of documents requested for this shard fetch operation + * across all chunks (derived from requested doc IDs, not an observed + * count of docs received so far) + * @param sequenceStart sequence number of first hit for ordering + */ + public FetchPhaseResponseChunk( + ShardId shardId, + BytesReference serializedHits, + int hitCount, + int expectedTotalDocs, + long sequenceStart + ) { + this.shardId = shardId; + this.serializedHits = serializedHits; + this.hitCount = hitCount; + this.expectedTotalDocs = expectedTotalDocs; + this.sequenceStart = sequenceStart; + } + + /** + * Deserializes from stream (receiving side). 
+ */ + public FetchPhaseResponseChunk(StreamInput in) throws IOException { + this.shardId = new ShardId(in); + this.hitCount = in.readVInt(); + this.expectedTotalDocs = in.readVInt(); + this.sequenceStart = in.readVLong(); + this.serializedHits = in.readBytesReference(); + this.namedWriteableRegistry = in.namedWriteableRegistry(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + shardId.writeTo(out); + out.writeVInt(hitCount); + out.writeVInt(expectedTotalDocs); + out.writeVLong(sequenceStart); + out.writeBytesReference(serializedHits); + } + + public ReleasableBytesReference toReleasableBytesReference(long coordinatingTaskId) throws IOException { + final ReleasableBytesReference result; + try (BytesStreamOutput header = new BytesStreamOutput(INITIAL_CHUNK_SERIALIZATION_CAPACITY)) { + header.writeVLong(coordinatingTaskId); + shardId.writeTo(header); + header.writeVInt(hitCount); + header.writeVInt(expectedTotalDocs); + header.writeVLong(sequenceStart); + header.writeVInt(serializedHits.length()); + + BytesReference composite = CompositeBytesReference.of(header.copyBytes(), serializedHits); + if (serializedHits instanceof ReleasableBytesReference releasableHits) { + result = new ReleasableBytesReference(composite, releasableHits::decRef); + } else { + result = ReleasableBytesReference.wrap(composite); + } + this.serializedHits = null; + } + return result; + } + + public long getBytesLength() { + return serializedHits == null ? 0 : serializedHits.length(); + } + + public SearchHit[] getHits() throws IOException { + if (deserializedHits == null && serializedHits != null && hitCount > 0) { + deserializedHits = new SearchHit[hitCount]; + try (StreamInput in = createStreamInput()) { + for (int i = 0; i < hitCount; i++) { + deserializedHits[i] = SearchHit.readFrom(in, false); + } + } + } + return deserializedHits != null ? 
deserializedHits : new SearchHit[0]; + } + + private StreamInput createStreamInput() throws IOException { + StreamInput in = serializedHits.streamInput(); + if (namedWriteableRegistry != null) { + in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); + } + return in; + } + + public ShardId shardId() { + return shardId; + } + + public int hitCount() { + return hitCount; + } + + public int expectedTotalDocs() { + return expectedTotalDocs; + } + + public long sequenceStart() { + return sequenceStart; + } + + @Override + public void close() { + if (serializedHits instanceof Releasable) { + Releasables.closeWhileHandlingException((Releasable) serializedHits); + } + serializedHits = null; + + if (deserializedHits != null) { + for (SearchHit hit : deserializedHits) { + if (hit != null) { + hit.decRef(); + } + } + deserializedHits = null; + } + } + + /** + * Interface for sending chunk responses from the data node to the coordinator. + *

+ * Implementations handle network transport using {@link org.elasticsearch.transport.BytesTransportRequest} + * for zero-copy transmission, and provide buffer allocation using Netty's pooled allocator. + */ + public interface Writer { + + /** + * Sends a chunk to the coordinator using zero-copy transport. + * + * @param responseChunk the chunk to send + * @param listener called when the chunk is acknowledged or fails + */ + void writeResponseChunk(FetchPhaseResponseChunk responseChunk, ActionListener listener); + + /** + * Creates a new byte stream for serializing hits. Uses a network buffer pool for efficient allocation. + * + * @return a new RecyclerBytesStreamOutput from the network buffer pool + */ + RecyclerBytesStreamOutput newNetworkBytesStream(); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStream.java b/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStream.java new file mode 100644 index 0000000000000..e054e6c30e3df --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStream.java @@ -0,0 +1,236 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.profile.ProfileResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Accumulates {@link SearchHit} chunks sent from a data node during a chunked fetch operation. + * Runs on the coordinator node and maintains an in-memory buffer of hits received + * from a single shard on a data node. The data node sends hits in small chunks to + * avoid large network messages and memory pressure. + * + * Uses sequence numbers to maintain correct ordering when chunks arrive out of order. + **/ +class FetchPhaseResponseStream extends AbstractRefCounted { + + private static final Logger logger = LogManager.getLogger(FetchPhaseResponseStream.class); + + private final int shardIndex; + private final int expectedTotalDocs; + + // Accumulate hits with sequence numbers for ordering + private final Queue queue = new ConcurrentLinkedQueue<>(); + private volatile boolean ownershipTransferred = false; + + // Circuit breaker accounting + private final CircuitBreaker circuitBreaker; + private final AtomicLong totalBreakerBytes = new AtomicLong(0); + + /** + * Creates a new response stream for accumulating hits from a single shard. 
+ * + * @param shardIndex the shard ID this stream is collecting hits for + * @param expectedTotalDocs total number of documents requested for this shard fetch operation + * across all chunks (target/requested count, not guaranteed delivered count) + * @param circuitBreaker circuit breaker to check memory usage during accumulation (typically REQUEST breaker) + */ + FetchPhaseResponseStream(int shardIndex, int expectedTotalDocs, CircuitBreaker circuitBreaker) { + this.shardIndex = shardIndex; + this.expectedTotalDocs = expectedTotalDocs; + this.circuitBreaker = circuitBreaker; + } + + /** + * Adds a chunk of hits to the accumulated result. + * + * This method increments the reference count of each {@link SearchHit} + * via {@link SearchHit#incRef()} to take ownership. The hits will be released in {@link #closeInternal()}. + * + * @param chunk the chunk containing hits to accumulate + * @param releasable a releasable to close after processing (typically releases the acquired stream reference) + */ + void writeChunk(FetchPhaseResponseChunk chunk, Releasable releasable) { + boolean success = false; + try { + // Track memory usage + long bytesSize = chunk.getBytesLength(); + circuitBreaker.addEstimateBytesAndMaybeBreak(bytesSize, "fetch_chunk_accumulation"); + totalBreakerBytes.addAndGet(bytesSize); + + SearchHit[] chunkHits = chunk.getHits(); + long sequenceStart = chunk.sequenceStart(); + + for (int i = 0; i < chunkHits.length; i++) { + SearchHit hit = chunkHits[i]; + hit.incRef(); + + // Calculate sequence: chunk start + index within chunk + long hitSequence = sequenceStart + i; + queue.add(new SequencedHit(hit, hitSequence)); + } + + if (logger.isDebugEnabled()) { + logger.debug( + "Received chunk [{}] docs for shard [{}]: [{}/{}] hits accumulated, [{}] breaker bytes, used breaker bytes [{}]", + chunkHits.length, + shardIndex, + queue.size(), + expectedTotalDocs, + totalBreakerBytes.get(), + circuitBreaker.getUsed() + ); + } + success = true; + } catch (IOException 
e) { + throw new RuntimeException("Failed to deserialize hits from chunk", e); + } finally { + if (success) { + releasable.close(); + } + } + } + + /** + * Builds the final {@link FetchSearchResult} from all accumulated hits. + * Sorts hits by sequence number to restore correct order. + * + * @param ctxId the shard search context ID + * @param shardTarget the shard target information + * @param profileResult the profile result from the data node (may be null) + * @return a complete {@link FetchSearchResult} containing all accumulated hits in correct order + */ + FetchSearchResult buildFinalResult(ShardSearchContextId ctxId, SearchShardTarget shardTarget, @Nullable ProfileResult profileResult) { + if (logger.isDebugEnabled()) { + logger.debug("Building final result for shard [{}] with [{}] hits", shardIndex, queue.size()); + } + + // Convert queue to list and sort by sequence number to restore correct order + List sequencedHits = new ArrayList<>(queue); + sequencedHits.sort(Comparator.comparingLong(sh -> sh.sequence)); + + // Extract hits in correct order and calculate maxScore + List orderedHits = new ArrayList<>(sequencedHits.size()); + float maxScore = Float.NEGATIVE_INFINITY; + + for (SequencedHit sequencedHit : sequencedHits) { + SearchHit hit = sequencedHit.hit; + orderedHits.add(hit); + + if (Float.isNaN(hit.getScore()) == false) { + maxScore = Math.max(maxScore, hit.getScore()); + } + } + + if (maxScore == Float.NEGATIVE_INFINITY) { + maxScore = Float.NaN; + } + + ownershipTransferred = true; + + SearchHits searchHits = new SearchHits( + orderedHits.toArray(SearchHit[]::new), + new TotalHits(orderedHits.size(), TotalHits.Relation.EQUAL_TO), + maxScore + ); + + FetchSearchResult result = new FetchSearchResult(ctxId, shardTarget); + result.shardResult(searchHits, profileResult); + return result; + } + + /** + * Adds a single hit with explicit sequence number to the accumulated result. 
+ * Used for processing the last chunk embedded in FetchSearchResult where sequence is known. + * + * @param hit the hit to add + * @param sequence the sequence number for this hit + */ + void addHitWithSequence(SearchHit hit, long sequence) { + queue.add(new SequencedHit(hit, sequence)); + } + + /** + * Tracks circuit breaker bytes without checking. Used when coordinator processes the embedded last chunk. + */ + void trackBreakerBytes(int bytes) { + totalBreakerBytes.addAndGet(bytes); + } + + /** + * Releases accumulated hits and circuit breaker bytes when hits are released from memory. + */ + @Override + protected void closeInternal() { + if (logger.isDebugEnabled()) { + logger.debug( + "Closing response stream for shard [{}], releasing [{}] hits, [{}] breaker bytes", + shardIndex, + queue.size(), + totalBreakerBytes.get() + ); + } + + if (ownershipTransferred == false) { + for (SequencedHit sequencedHit : queue) { + sequencedHit.hit.decRef(); + } + } + queue.clear(); + + // Release circuit breaker bytes added during accumulation when hits are released from memory + if (totalBreakerBytes.get() > 0) { + circuitBreaker.addWithoutBreaking(-totalBreakerBytes.get()); + if (logger.isDebugEnabled()) { + logger.debug( + "Released [{}] breaker bytes for shard [{}], used breaker bytes [{}]", + totalBreakerBytes.get(), + shardIndex, + circuitBreaker.getUsed() + ); + } + totalBreakerBytes.set(0); + } + } + + /** + * Wrapper class that pairs a SearchHit with its sequence number. + * This ensures we can restore the correct order even if chunks arrive out of order. 
+ */ + private static class SequencedHit { + final SearchHit hit; + final long sequence; + + SequencedHit(SearchHit hit, long sequence) { + this.hit = hit; + this.sequence = sequence; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationAction.java b/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationAction.java new file mode 100644 index 0000000000000..0193c741165ac --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationAction.java @@ -0,0 +1,272 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; + +import java.io.IOException; +import java.util.Map; + +import static org.elasticsearch.action.search.SearchTransportService.FETCH_ID_ACTION_NAME; + +public class TransportFetchPhaseCoordinationAction extends HandledTransportAction< + 
TransportFetchPhaseCoordinationAction.Request, + TransportFetchPhaseCoordinationAction.Response> { + + /* + * Transport action that coordinates chunked fetch operations from the coordinator node. + * Handles receiving chunks, accumulating them in order, and building the final result. + *

+ * This action orchestrates the chunked fetch flow by: + *

+ * <ol>
+ * <li>Registering a {@link FetchPhaseResponseStream} for accumulating chunks</li>
+ * <li>Setting coordinator information on the fetch request</li>
+ * <li>Sending the request to the data node via the standard fetch transport action</li>
+ * <li>Building the final result from accumulated chunks when the data node completes</li>
+ * </ol>
+ *

+ * +-------------------+ +-------------+ +-----------+ + * | FetchSearchPhase | | Coordinator | | Data Node | + * +-------------------+ +-------------+ +-----------+ + * | | | + * |- execute(request, dataNode)-------->| | --[Initialization Phase] + * | |---[ShardFetchRequest]------------------->| + * | | | --[[Chunked Streaming Phase] + * | |<---[HITS chunk 1]------------------------| + * | |----[ACK (Empty)]------------------------>| + * | | .... | + * | |<---[HITS chunk N]------------------------| + * | |----[ACK (Empty)]------------------------>| + * | | | --[Completion Phase] + * | |<--FetchSearchResult----------------------| + * | | (final response) | + * | | | + * | |--[Build final result] | + * | | (from accumulated chunks) | + * |<-- FetchSearchResult (complete) ----| | + */ + private static final Logger LOGGER = LogManager.getLogger(TransportFetchPhaseCoordinationAction.class); + + public static final ActionType TYPE = new ActionType<>("internal:data/read/search/fetch/coordination"); + + public static final TransportVersion CHUNKED_FETCH_PHASE = TransportVersion.fromName("chunked_fetch_phase"); + + private final TransportService transportService; + private final ActiveFetchPhaseTasks activeFetchPhaseTasks; + private final CircuitBreakerService circuitBreakerService; + + /** + * Required for deserializing SearchHits from chunk bytes that may contain NamedWriteable + * fields (e.g., LookupField from lookup runtime fields). See {@link NamedWriteableAwareStreamInput}. 
+ */ + private final NamedWriteableRegistry namedWriteableRegistry; + + public static class Request extends ActionRequest { + private final ShardFetchSearchRequest shardFetchRequest; + private final DiscoveryNode dataNode; + private final Map headers; + + public Request(ShardFetchSearchRequest shardFetchRequest, DiscoveryNode dataNode, Map headers) { + this.shardFetchRequest = shardFetchRequest; + this.dataNode = dataNode; + this.headers = headers; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.shardFetchRequest = new ShardFetchSearchRequest(in); + this.dataNode = new DiscoveryNode(in); + this.headers = in.readMap(StreamInput::readString); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + shardFetchRequest.writeTo(out); + dataNode.writeTo(out); + out.writeMap(headers, StreamOutput::writeString); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public ShardFetchSearchRequest getShardFetchRequest() { + return shardFetchRequest; + } + + public DiscoveryNode getDataNode() { + return dataNode; + } + + public Map getHeaders() { + return headers; + } + } + + public static class Response extends ActionResponse { + private final FetchSearchResult result; + + public Response(FetchSearchResult result) { + this.result = result; + } + + public Response(StreamInput in) throws IOException { + this.result = new FetchSearchResult(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + result.writeTo(out); + } + + public FetchSearchResult getResult() { + return result; + } + } + + @Inject + public TransportFetchPhaseCoordinationAction( + TransportService transportService, + ActionFilters actionFilters, + ActiveFetchPhaseTasks activeFetchPhaseTasks, + CircuitBreakerService circuitBreakerService, + NamedWriteableRegistry namedWriteableRegistry + ) { + super(TYPE.name(), transportService, actionFilters, Request::new, 
EsExecutors.DIRECT_EXECUTOR_SERVICE); + this.transportService = transportService; + this.activeFetchPhaseTasks = activeFetchPhaseTasks; + this.circuitBreakerService = circuitBreakerService; + this.namedWriteableRegistry = namedWriteableRegistry; + } + + // Creates and registers a response stream for the coordinating task + @Override + public void doExecute(Task task, Request request, ActionListener listener) { + final long coordinatingTaskId = task.getId(); + + // Set coordinator information on the request + final ShardFetchSearchRequest fetchReq = request.getShardFetchRequest(); + fetchReq.setCoordinatingNode(transportService.getLocalNode()); + fetchReq.setCoordinatingTaskId(coordinatingTaskId); + + // Create and register response stream + assert fetchReq.getShardSearchRequest() != null; + ShardId shardId = fetchReq.getShardSearchRequest().shardId(); + int expectedTotalDocs = fetchReq.docIds().length; + + CircuitBreaker circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST); + FetchPhaseResponseStream responseStream = new FetchPhaseResponseStream(shardId.getId(), expectedTotalDocs, circuitBreaker); + Releasable registration = activeFetchPhaseTasks.registerResponseBuilder(coordinatingTaskId, shardId, responseStream); + + // Listener that builds final result from accumulated chunks + ActionListener childListener = ActionListener.runAfter(ActionListener.wrap(dataNodeResult -> { + BytesReference lastChunkBytes = dataNodeResult.getLastChunkBytes(); + int hitCount = dataNodeResult.getLastChunkHitCount(); + long lastChunkSequenceStart = dataNodeResult.getLastChunkSequenceStart(); + + // Process the embedded last chunk if present + if (lastChunkBytes != null && hitCount > 0) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Received final chunk [{}] for shard [{}]", + hitCount, + request.shardFetchRequest.getShardSearchRequest().shardId() + ); + } + + // Track memory usage + int bytesSize = lastChunkBytes.length(); + 
circuitBreaker.addEstimateBytesAndMaybeBreak(bytesSize, "fetch_chunk_accumulation"); + responseStream.trackBreakerBytes(bytesSize); + + try (StreamInput in = new NamedWriteableAwareStreamInput(lastChunkBytes.streamInput(), namedWriteableRegistry)) { + for (int i = 0; i < hitCount; i++) { + SearchHit hit = SearchHit.readFrom(in, false); + + // Add with explicit sequence number + long hitSequence = lastChunkSequenceStart + i; + responseStream.addHitWithSequence(hit, hitSequence); + } + } + } + + // Build final result from all accumulated hits + FetchSearchResult finalResult = responseStream.buildFinalResult( + dataNodeResult.getContextId(), + dataNodeResult.getSearchShardTarget(), + dataNodeResult.profileResult() + ); + + ActionListener.respondAndRelease(listener.map(Response::new), finalResult); + }, listener::onFailure), () -> { + registration.close(); + responseStream.decRef(); + }); + + final ThreadContext threadContext = transportService.getThreadPool().getThreadContext(); + try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { + for (var e : request.getHeaders().entrySet()) { + final String key = e.getKey(); + final String value = e.getValue(); + final String existing = threadContext.getHeader(key); + if (existing == null) { + threadContext.putHeader(key, value); + } else { + assert existing.equals(value) : "header [" + key + "] already present with different value"; + } + } + + final TaskId parent = task.getParentTaskId(); + if (parent != null && parent.isSet()) { + fetchReq.setParentTask(parent); + } + + transportService.sendRequest( + request.getDataNode(), + FETCH_ID_ACTION_NAME, + fetchReq, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(childListener, FetchSearchResult::new, EsExecutors.DIRECT_EXECUTOR_SERVICE) + ); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkAction.java 
b/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkAction.java new file mode 100644 index 0000000000000..21fada7b470cc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkAction.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.transport.BytesTransportRequest; +import org.elasticsearch.transport.TransportService; + +/** + * Receives fetch result chunks from data nodes via zero-copy transport. This component runs on the + * coordinator node and serves as the receiver endpoint for {@link FetchPhaseResponseChunk} + * messages sent by data nodes during chunked fetch operations. + * + *

Chunks arrive as {@link BytesTransportRequest} on the {@link #ZERO_COPY_ACTION_NAME} endpoint. + * Bytes flow directly from Netty buffers without an intermediate deserialization/re-serialization step. + */ +public class TransportFetchPhaseResponseChunkAction { + + /* + * [Data Node] [Coordinator] + * | | + * | FetchPhase.execute(writer) | + * | ↓ | + * | writer.writeResponseChunk(chunk) ------------>| TransportFetchPhaseResponseChunkAction + * | (via BytesTransportRequest, zero-copy) | ↓ + * | | activeFetchPhaseTasks.acquireResponseStream() + * | | ↓ + * | | responseStream.writeChunk() + * | | + * |<------------- [ACK (Empty)]-------------------| + */ + + /** + * Action name for zero-copy BytesTransportRequest path. + * Sender uses this action name when sending via BytesTransportRequest. + */ + public static final String ZERO_COPY_ACTION_NAME = "internal:data/read/search/fetch/chunk[bytes]"; + + private final ActiveFetchPhaseTasks activeFetchPhaseTasks; + + /** + * Required for deserializing SearchHits that contain NamedWriteable objects. + *

+ * SearchHit's DocumentFields can contain types like {@link org.elasticsearch.search.fetch.subphase.LookupField} + * which implement NamedWriteable. When reading serialized hits from raw bytes (from chunks), + * the basic StreamInput cannot deserialize these types. Wrapping with + * {@link NamedWriteableAwareStreamInput} provides the registry needed to resolve + * NamedWriteable types by their registered names. + */ + private final NamedWriteableRegistry namedWriteableRegistry; + + /** + * Creates a new chunk receiver and registers the zero-copy transport handler. + * + * @param transportService the transport service used to register the handler + * @param activeFetchPhaseTasks the registry of active fetch response streams + * @param namedWriteableRegistry registry for deserializing NamedWriteable types in chunks + */ + @Inject + public TransportFetchPhaseResponseChunkAction( + TransportService transportService, + ActiveFetchPhaseTasks activeFetchPhaseTasks, + NamedWriteableRegistry namedWriteableRegistry + ) { + this.activeFetchPhaseTasks = activeFetchPhaseTasks; + this.namedWriteableRegistry = namedWriteableRegistry; + registerZeroCopyHandler(transportService); + } + + /** + * Registers the handler for zero-copy chunk reception via BytesTransportRequest. + * The incoming bytes contain a routing header (coordinatingTaskId) followed by the chunk data. + * We parse the header to extract the task ID, then deserialize and process the chunk. 
+ */ + private void registerZeroCopyHandler(TransportService transportService) { + transportService.registerRequestHandler( + ZERO_COPY_ACTION_NAME, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + false, + true, + BytesTransportRequest::new, + (request, channel, task) -> { + ReleasableBytesReference bytesRef = request.bytes(); + long coordinatingTaskId; + FetchPhaseResponseChunk chunk; + + try (StreamInput in = new NamedWriteableAwareStreamInput(bytesRef.streamInput(), namedWriteableRegistry)) { + coordinatingTaskId = in.readVLong(); + chunk = new FetchPhaseResponseChunk(in); + } catch (Exception e) { + channel.sendResponse(e); + return; + } + + processChunk( + coordinatingTaskId, + chunk, + ActionListener.releaseAfter( + ActionListener.wrap(ignored -> channel.sendResponse(ActionResponse.Empty.INSTANCE), channel::sendResponse), + chunk + ) + ); + } + ); + } + + /** + * Running on the coordinator node. Processes an incoming chunk by routing it to the appropriate response stream. + * + *

+ * <p>This method:
+ * <ol>
+ * <li>Extracts the shard ID from the chunk</li>
+ * <li>Acquires the response stream from {@link ActiveFetchPhaseTasks}</li>
+ * <li>Delegates to {@link FetchPhaseResponseStream#writeChunk}</li>
+ * <li>Releases the response stream reference</li>
+ * <li>Sends an acknowledgment response to the data node</li>
+ * </ol>
+ * + * @param coordinatingTaskId the ID of the coordinating search task + * @param chunk the chunk to process + * @param listener callback for sending the acknowledgment + */ + private void processChunk(long coordinatingTaskId, FetchPhaseResponseChunk chunk, ActionListener listener) { + ActionListener.run(listener, l -> { + ShardId shardId = chunk.shardId(); + + final var responseStream = activeFetchPhaseTasks.acquireResponseStream(coordinatingTaskId, shardId); + try { + responseStream.writeChunk(chunk, () -> l.onResponse(ActionResponse.Empty.INSTANCE)); + } finally { + responseStream.decRef(); + } + }); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index cacc8a3ce6967..b3ad2aab19da2 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.metadata.AliasMetadata; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.SplitShardCountSummary; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.bytes.BytesArray; @@ -55,6 +56,7 @@ import java.util.Map; import static java.util.Collections.emptyMap; +import static org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE; import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_DISABLED; /** @@ -88,6 +90,8 @@ public class ShardSearchRequest extends AbstractTransportRequest implements Indi private final TransportVersion channelVersion; + private DiscoveryNode coordinatingNode; + /** * Should this request force {@link SourceLoader.Synthetic synthetic source}? 
* Use this to test if the mapping supports synthetic _source and to get a sense @@ -326,6 +330,10 @@ public ShardSearchRequest(StreamInput in) throws IOException { } originalIndices = OriginalIndices.readOriginalIndices(in); + + if (in.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + coordinatingNode = in.readOptionalWriteable(DiscoveryNode::new); + } } @Override @@ -333,6 +341,10 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); innerWriteTo(out, false); OriginalIndices.writeOriginalIndices(originalIndices, out); + + if (out.getTransportVersion().supports(CHUNKED_FETCH_PHASE)) { + out.writeOptionalWriteable(coordinatingNode); + } } protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { @@ -664,4 +676,13 @@ public TransportVersion getChannelVersion() { public boolean isForceSyntheticSource() { return forceSyntheticSource; } + + public void setCoordinatingNode(DiscoveryNode node) { + this.coordinatingNode = node; + } + + public DiscoveryNode getCoordinatingNode() { + return coordinatingNode; + } + } diff --git a/server/src/main/resources/transport/definitions/referable/chunked_fetch_phase.csv b/server/src/main/resources/transport/definitions/referable/chunked_fetch_phase.csv new file mode 100644 index 0000000000000..26ba95803cf8b --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/chunked_fetch_phase.csv @@ -0,0 +1 @@ +9325000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv index f1bc9f17b874e..f70d5acb61e8f 100644 --- a/server/src/main/resources/transport/upper_bounds/9.4.csv +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -1 +1 @@ -inference_api_chat_completion_reasoning_max_tokens_removed,9324000 +chunked_fetch_phase,9325000 diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseChunkedTests.java 
b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseChunkedTests.java new file mode 100644 index 0000000000000..10f6d93e23471 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseChunkedTests.java @@ -0,0 +1,941 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.action.search; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.VersionInformation; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; +import 
org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.RescoreDocIds; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; +import org.elasticsearch.search.fetch.chunk.ActiveFetchPhaseTasks; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.telemetry.tracing.Tracer; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.InternalAggregationTestCase; +import org.elasticsearch.test.transport.MockTransport; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.CloseableConnection; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; +import 
org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; + +import static org.elasticsearch.action.search.FetchSearchPhaseTests.addProfiling; +import static org.elasticsearch.action.search.FetchSearchPhaseTests.fetchProfile; +import static org.elasticsearch.action.search.FetchSearchPhaseTests.searchPhaseFactory; + +public class FetchSearchPhaseChunkedTests extends ESTestCase { + + /** + * Test that chunked fetch is used when all conditions are met: + * - fetchPhaseChunked is true + * - data node supports CHUNKED_FETCH_PHASE + * - not a CCS query (no cluster alias) + * - not a scroll or reindex query + */ + public void testChunkedFetchUsedWhenConditionsMet() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + // Add first shard result + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + // Add first shard result + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean chunkedFetchUsed = new AtomicBoolean(false); + + // Create the coordination action that will be called for chunked fetch + 
TransportFetchPhaseCoordinationAction fetchCoordinationAction = new TransportFetchPhaseCoordinationAction( + mockTransportService, + new ActionFilters(Collections.emptySet()), + new ActiveFetchPhaseTasks(), + newLimitedBreakerService(ByteSizeValue.ofMb(10)), + new NamedWriteableRegistry(Collections.emptyList()) + ) { + @Override + public void doExecute(Task task, Request request, ActionListener listener) { + chunkedFetchUsed.set(true); + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + // Return result based on context ID + SearchShardTarget target = request.getShardFetchRequest().contextId().equals(ctx1) + ? shardTarget1 + : shardTarget2; + int docId = request.getShardFetchRequest().contextId().equals(ctx1) ? 42 : 43; + + fetchResult.setSearchShardTarget(target); + SearchHits hits = SearchHits.unpooled( + new SearchHit[] { SearchHit.unpooled(docId) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(new Response(fetchResult)); + } finally { + fetchResult.decRef(); + } + } + }; + provideSearchTransportWithChunkedFetch(mockSearchPhaseContext, mockTransportService, threadPool, fetchCoordinationAction); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + + assertTrue("Chunked fetch should be used", chunkedFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value()); + 
assertTrue(searchResponse.getHits().getAt(0).docId() == 42 || searchResponse.getHits().getAt(0).docId() == 43); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + public void testChunkedFetchUsedForPointInTimeQuery() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + mockSearchPhaseContext.getRequest() + .source(new SearchSourceBuilder().pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("test-pit-id")))); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean chunkedFetchUsed = new AtomicBoolean(false); + TransportFetchPhaseCoordinationAction fetchCoordinationAction = new TransportFetchPhaseCoordinationAction( + mockTransportService, + new ActionFilters(Collections.emptySet()), + new ActiveFetchPhaseTasks(), + newLimitedBreakerService(ByteSizeValue.ofMb(10)), + new NamedWriteableRegistry(Collections.emptyList()) + ) { + @Override + public void doExecute(Task task, Request request, ActionListener listener) { + chunkedFetchUsed.set(true); + FetchSearchResult 
fetchResult = new FetchSearchResult(); + try { + SearchShardTarget target = request.getShardFetchRequest().contextId().equals(ctx1) + ? shardTarget1 + : shardTarget2; + int docId = request.getShardFetchRequest().contextId().equals(ctx1) ? 42 : 43; + fetchResult.setSearchShardTarget(target); + SearchHits hits = SearchHits.unpooled( + new SearchHit[] { SearchHit.unpooled(docId) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(new Response(fetchResult)); + } finally { + fetchResult.decRef(); + } + } + }; + provideSearchTransportWithChunkedFetch(mockSearchPhaseContext, mockTransportService, threadPool, fetchCoordinationAction); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + + // PIT response generation requires query phase results for building the PIT id. + AtomicArray queryResults = new AtomicArray<>(2); + queryResults.set(0, results.getAtomicArray().get(0)); + queryResults.set(1, results.getAtomicArray().get(1)); + + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray fetchResults + ) { + return searchPhaseFactoryBi(mockSearchPhaseContext, queryResults).apply(searchResponseSections, fetchResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue("Chunked fetch should be used for PIT queries", chunkedFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertNotNull("PIT id should be present in response", searchResponse.pointInTimeId()); + assertEquals(2, searchResponse.getHits().getTotalHits().value()); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } 
finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + public void testChunkedFetchHandlesPartialShardFailure() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean chunkedFetchUsed = new AtomicBoolean(false); + TransportFetchPhaseCoordinationAction fetchCoordinationAction = new TransportFetchPhaseCoordinationAction( + mockTransportService, + new ActionFilters(Collections.emptySet()), + new ActiveFetchPhaseTasks(), + newLimitedBreakerService(ByteSizeValue.ofMb(10)), + new NamedWriteableRegistry(Collections.emptyList()) + ) { + @Override + public void doExecute(Task task, Request request, ActionListener listener) { + chunkedFetchUsed.set(true); + if (request.getShardFetchRequest().contextId().equals(ctx2)) { + listener.onFailure(new RuntimeException("simulated chunked fetch failure")); + return; + } + + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + fetchResult.setSearchShardTarget(shardTarget1); + SearchHits hits = SearchHits.unpooled( + new SearchHit[] { SearchHit.unpooled(42) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + 
fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(new Response(fetchResult)); + } finally { + fetchResult.decRef(); + } + } + }; + provideSearchTransportWithChunkedFetch(mockSearchPhaseContext, mockTransportService, threadPool, fetchCoordinationAction); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue("Chunked fetch should be used", chunkedFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(1, searchResponse.getFailedShards()); + assertEquals(1, searchResponse.getSuccessfulShards()); + assertEquals(1, searchResponse.getShardFailures().length); + assertEquals("simulated chunked fetch failure", searchResponse.getShardFailures()[0].getCause().getMessage()); + assertEquals(1, searchResponse.getHits().getHits().length); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + public void testChunkedFetchTreatsTaskCancellationAsShardFailure() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = 
randomBoolean(); + + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + TransportFetchPhaseCoordinationAction fetchCoordinationAction = new TransportFetchPhaseCoordinationAction( + mockTransportService, + new ActionFilters(Collections.emptySet()), + new ActiveFetchPhaseTasks(), + newLimitedBreakerService(ByteSizeValue.ofMb(10)), + new NamedWriteableRegistry(Collections.emptyList()) + ) { + @Override + public void doExecute(Task task, Request request, ActionListener listener) { + if (request.getShardFetchRequest().contextId().equals(ctx2)) { + listener.onFailure(new TaskCancelledException("simulated cancellation")); + return; + } + + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + fetchResult.setSearchShardTarget(shardTarget1); + SearchHits hits = SearchHits.unpooled( + new SearchHit[] { SearchHit.unpooled(42) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(new Response(fetchResult)); + } finally { + fetchResult.decRef(); + } + } + }; + provideSearchTransportWithChunkedFetch(mockSearchPhaseContext, mockTransportService, threadPool, fetchCoordinationAction); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return 
searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(1, searchResponse.getFailedShards()); + assertEquals(1, searchResponse.getShardFailures().length); + assertTrue(searchResponse.getShardFailures()[0].getCause() instanceof TaskCancelledException); + assertEquals("simulated cancellation", searchResponse.getShardFailures()[0].getCause().getMessage()); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + /** + * Test that traditional fetch is used when fetchPhaseChunked is disabled + */ + public void testTraditionalFetchUsedWhenChunkedDisabled() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + // Add first shard result + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + // Add second shard result + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean traditionalFetchUsed = new 
AtomicBoolean(false); + + provideSearchTransport( + mockSearchPhaseContext, + mockTransportService, + traditionalFetchUsed, + ctx1, + shardTarget1, + shardTarget2, + profiled + ); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + + assertTrue("Traditional fetch should be used when chunked fetch is disabled", traditionalFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value()); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + /** + * Test that traditional fetch is used for scroll queries + */ + public void testTraditionalFetchUsedForScrollQuery() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + mockSearchPhaseContext.getRequest().scroll(TimeValue.timeValueMinutes(1)); + + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + // Add first shard result + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new 
ShardId("test", "na", 0), null); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + // Add second shard result + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean traditionalFetchUsed = new AtomicBoolean(false); + + provideSearchTransport( + mockSearchPhaseContext, + mockTransportService, + traditionalFetchUsed, + ctx1, + shardTarget1, + shardTarget2, + profiled + ); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + + // Store query results in an AtomicArray for scroll ID generation + AtomicArray queryResults = new AtomicArray<>(2); + queryResults.set(0, results.getAtomicArray().get(0)); + queryResults.set(1, results.getAtomicArray().get(1)); + + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray fetchResults + ) { + // Pass the query results for scroll ID generation + return searchPhaseFactoryBi(mockSearchPhaseContext, queryResults).apply(searchResponseSections, fetchResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + + assertTrue("Traditional fetch should be used for scroll queries", traditionalFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value()); + assertNotNull("Scroll ID should be present for scroll queries", searchResponse.getScrollId()); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, 
TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + private static BiFunction, SearchPhase> searchPhaseFactoryBi( + MockSearchPhaseContext mockSearchPhaseContext, + AtomicArray queryResults + ) { + return (searchResponseSections, fetchResults) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, queryResults); + } + }; + } + + /** + * Test that traditional fetch is used for CCS queries + */ + public void testTraditionalFetchUsedForCCSQuery() throws Exception { + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + ThreadPool threadPool = new TestThreadPool("test"); + try { + TransportService mockTransportService = createMockTransportService(threadPool); + + try (SearchPhaseResults results = createSearchPhaseResults(mockSearchPhaseContext)) { + boolean profiled = randomBoolean(); + + // Add first shard result - CCS query with cluster alias + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shardTarget1 = new SearchShardTarget("node1", new ShardId("test", "na", 0), "remote_cluster"); + addQuerySearchResult(ctx1, shardTarget1, profiled, 0, results); + + // Add second shard result - CCS query with cluster alias + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 124); + SearchShardTarget shardTarget2 = new SearchShardTarget("node2", new ShardId("test", "na", 1), "remote_cluster"); + addQuerySearchResult(ctx2, shardTarget2, profiled, 1, results); + + AtomicBoolean traditionalFetchUsed = new AtomicBoolean(false); + + provideSearchTransport( + mockSearchPhaseContext, + mockTransportService, + traditionalFetchUsed, + ctx1, + shardTarget1, + shardTarget2, + profiled + ); + + SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected 
SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + + assertTrue("Traditional fetch should be used for CCS queries", traditionalFetchUsed.get()); + + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value()); + } finally { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + } finally { + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + public void testTraditionalFetchUsedWhenDataNodeDoesNotSupportChunkedTransportVersion() throws Exception { + ThreadPool threadPool = new TestThreadPool("test"); + ClusterService clusterService = new ClusterService( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + threadPool, + null + ); + MockSearchPhaseContext mockSearchPhaseContext = null; + + MockTransportService transportService = MockTransportService.createNewService( + Settings.EMPTY, + VersionInformation.CURRENT, + TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE, + threadPool + ); + + try { + transportService.start(); + transportService.acceptIncomingRequests(); + + AtomicBoolean traditionalFetchUsed = new AtomicBoolean(false); + AtomicBoolean chunkedPathUsed = new AtomicBoolean(false); + + transportService.registerRequestHandler( + SearchTransportService.FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + traditionalFetchUsed.set(true); + FetchSearchResult result = createFetchSearchResult(); + channel.sendResponse(result); + } + ); + + 
transportService.registerRequestHandler( + TransportFetchPhaseCoordinationAction.TYPE.name(), + threadPool.executor(ThreadPool.Names.GENERIC), + TransportFetchPhaseCoordinationAction.Request::new, + (req, channel, task) -> { + chunkedPathUsed.set(true); + channel.sendResponse(new IllegalStateException("chunked coordination path should not be used")); + } + ); + + SearchTransportService searchTransportService = new SearchTransportService(transportService, null, null); + searchTransportService.setSearchService(new StubSearchService(true, clusterService, threadPool)); + + mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.searchTransport = searchTransportService; + + ShardId shardId = new ShardId("test", "na", 0); + SearchShardTarget shardTarget = new SearchShardTarget("node1", shardId, null); + ShardFetchSearchRequest shardFetchRequest = createShardFetchSearchRequest(shardId); + + Transport.Connection delegateConnection = transportService.getConnection(transportService.getLocalNode()); + TransportVersion unsupportedVersion = TransportVersion.fromId( + TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE.id() - 1 + ); + Transport.Connection oldVersionConnection = withTransportVersion(delegateConnection, unsupportedVersion); + + PlainActionFuture future = new PlainActionFuture<>(); + searchTransportService.sendExecuteFetch(oldVersionConnection, shardFetchRequest, mockSearchPhaseContext, shardTarget, future); + + FetchSearchResult result = future.actionGet(10, TimeUnit.SECONDS); + result.decRef(); + + assertTrue("Traditional fetch should be used for unsupported data node version", traditionalFetchUsed.get()); + assertFalse("Chunked coordination path should not be used", chunkedPathUsed.get()); + } finally { + if (mockSearchPhaseContext != null) { + mockSearchPhaseContext.results.close(); + var resp = mockSearchPhaseContext.searchResponse.get(); + if (resp != null) { + resp.decRef(); + } + } + transportService.close(); + 
clusterService.close(); + ThreadPool.terminate(threadPool, 10, TimeValue.timeValueSeconds(5).timeUnit()); + } + } + + private SearchPhaseResults createSearchPhaseResults(MockSearchPhaseContext mockSearchPhaseContext) { + SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder()); + + return controller.newSearchPhaseResults( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + () -> false, + SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), + 2, + exc -> {} + ); + } + + private void provideSearchTransportWithChunkedFetch( + MockSearchPhaseContext mockSearchPhaseContext, + TransportService transportService, + ThreadPool threadPool, + TransportFetchPhaseCoordinationAction fetchCoordinationAction + ) { + ClusterService clusterService = new ClusterService( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + threadPool, + null + ); + + transportService.start(); + transportService.acceptIncomingRequests(); + + SearchTransportService searchTransport = new SearchTransportService(transportService, null, null); + searchTransport.setSearchService(new StubSearchService(true, clusterService, threadPool)); + + mockSearchPhaseContext.searchTransport = searchTransport; + mockSearchPhaseContext.addReleasable(clusterService::close); + } + + /** + * Minimal stub SearchService that only implements fetchPhaseChunked() + */ + private static class StubSearchService extends SearchService { + private final boolean chunkedEnabled; + + StubSearchService(boolean chunkedEnabled, ClusterService clusterService, ThreadPool threadPool) { + super( + clusterService, + null, // indicesService + threadPool, + null, // scriptService + null, // bigArrays + new FetchPhase(Collections.emptyList()), + newLimitedBreakerService(ByteSizeValue.ofMb(10)), + EmptySystemIndices.INSTANCE.getExecutorSelector(), + Tracer.NOOP, + 
OnlinePrewarmingService.NOOP + ); + this.chunkedEnabled = chunkedEnabled; + } + + @Override + public boolean fetchPhaseChunked() { + return chunkedEnabled; + } + + @Override + protected void doStart() {} + + @Override + protected void doStop() {} + + @Override + protected void doClose() {} + } + + private void provideSearchTransport( + MockSearchPhaseContext mockSearchPhaseContext, + TransportService mockTransportService, + AtomicBoolean traditionalFetchUsed, + ShardSearchContextId ctx1, + SearchShardTarget shardTarget1, + SearchShardTarget shardTarget2, + boolean profiled + ) { + mockSearchPhaseContext.searchTransport = new SearchTransportService(mockTransportService, null, null) { + @Override + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, + ActionListener listener + ) { + traditionalFetchUsed.set(true); + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + SearchShardTarget target = request.contextId().equals(ctx1) ? shardTarget1 : shardTarget2; + int docId = request.contextId().equals(ctx1) ? 
42 : 43; + + fetchResult.setSearchShardTarget(target); + SearchHits hits = SearchHits.unpooled( + new SearchHit[] { SearchHit.unpooled(docId) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } + } + }; + } + + private ShardFetchSearchRequest createShardFetchSearchRequest(ShardId shardId) { + ShardSearchContextId contextId = new ShardSearchContextId("test", randomLong()); + OriginalIndices originalIndices = new OriginalIndices( + new String[] { "test-index" }, + IndicesOptions.strictExpandOpenAndForbidClosed() + ); + ShardSearchRequest shardSearchRequest = new ShardSearchRequest(shardId, System.currentTimeMillis(), AliasFilter.EMPTY); + List docIds = List.of(0, 1, 2, 3, 4); + return new ShardFetchSearchRequest(originalIndices, contextId, shardSearchRequest, docIds, null, null, RescoreDocIds.EMPTY, null); + } + + private Transport.Connection withTransportVersion(Transport.Connection delegate, TransportVersion version) { + return new CloseableConnection() { + @Override + public DiscoveryNode getNode() { + return delegate.getNode(); + } + + @Override + public TransportVersion getTransportVersion() { + return version; + } + + @Override + public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) + throws TransportException { + try { + delegate.sendRequest(requestId, action, request, options); + } catch (Exception e) { + throw new TransportException("failed to send request", e); + } + } + + @Override + public void close() { + delegate.close(); + super.close(); + } + + @Override + public void onRemoved() { + delegate.onRemoved(); + super.onRemoved(); + } + }; + } + + private void addQuerySearchResult( + ShardSearchContextId ctx, + SearchShardTarget shardTarget, + boolean profiled, + int shardIndex, + SearchPhaseResults results + ) { + QuerySearchResult queryResult = new 
QuerySearchResult(ctx, shardTarget, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42 + shardIndex, 1.0F) }), + 1.0F + ), + new DocValueFormat[0] + ); + queryResult.size(10); + queryResult.setShardIndex(shardIndex); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); + } + } + + private FetchSearchResult createFetchSearchResult() { + ShardSearchContextId contextId = new ShardSearchContextId("test", randomLong()); + FetchSearchResult result = new FetchSearchResult(contextId, new SearchShardTarget("node", new ShardId("test", "na", 0), null)); + result.shardResult(SearchHits.unpooled(new SearchHit[0], null, Float.NaN), null); + return result; + } + + private TransportService createMockTransportService(ThreadPool threadPool) { + DiscoveryNode localNode = new DiscoveryNode( + "local", + "local", + new TransportAddress(TransportAddress.META_ADDRESS, 9200), + Collections.emptyMap(), + Collections.emptySet(), + null + ); + + return new MockTransport().createTransportService( + Settings.EMPTY, + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> localNode, + null, + Collections.emptySet() + ); + } +} diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 32802b9024bb7..428184be190de 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -239,7 +239,8 @@ public void testFetchTwoDocument() throws Exception { public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, - SearchTask task, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, ActionListener listener ) { 
FetchSearchResult fetchResult = new FetchSearchResult(); @@ -350,7 +351,8 @@ public void testFailFetchOneDoc() throws Exception { public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, - SearchTask task, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, ActionListener listener ) { if (request.contextId().getId() == 321) { @@ -453,7 +455,8 @@ public void testFetchDocsConcurrently() throws Exception { public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, - SearchTask task, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, ActionListener listener ) { new Thread(() -> { @@ -591,7 +594,8 @@ public void testExceptionFailsPhase() throws Exception { public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, - SearchTask task, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, ActionListener listener ) { FetchSearchResult fetchResult = new FetchSearchResult(); @@ -706,7 +710,8 @@ public void testCleanupIrrelevantContexts() throws Exception { // contexts that public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, - SearchTask task, + AbstractSearchAsyncAction context, + SearchShardTarget shardTarget, ActionListener listener ) { FetchSearchResult fetchResult = new FetchSearchResult(); @@ -759,7 +764,7 @@ public void sendExecuteFetch( } - private static BiFunction, SearchPhase> searchPhaseFactory( + static BiFunction, SearchPhase> searchPhaseFactory( MockSearchPhaseContext mockSearchPhaseContext ) { return (searchResponse, scrollId) -> new SearchPhase("test") { @@ -770,13 +775,13 @@ protected void run() { }; } - private static void addProfiling(boolean profiled, QuerySearchResult queryResult) { + static void addProfiling(boolean profiled, QuerySearchResult queryResult) { if (profiled) { queryResult.profileResults(new SearchProfileQueryPhaseResult(List.of(), 
null)); } } - private static ProfileResult fetchProfile(boolean profiled) { + public static ProfileResult fetchProfile(boolean profiled) { return profiled ? new ProfileResult("fetch", "fetch", Map.of(), Map.of(), FETCH_PROFILE_TIME, List.of()) : null; } diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index 4bc5b0f2ee582..3ac551c527acc 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -10,12 +10,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; @@ -26,7 +29,10 @@ import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.transport.CloseableConnection; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestOptions; import org.junit.Assert; import java.util.ArrayList; @@ -61,7 +67,7 @@ public MockSearchPhaseContext(int numShards) { new NamedWriteableRegistry(List.of()), mock(SearchTransportService.class), new 
MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofBytes(Long.MAX_VALUE)), - (clusterAlias, nodeId) -> null, + (clusterAlias, nodeId) -> createMockConnection(nodeId), null, null, Runnable::run, @@ -82,6 +88,32 @@ public MockSearchPhaseContext(int numShards) { numSuccess = new AtomicInteger(numShards); } + private static Transport.Connection createMockConnection(String nodeId) { + return new CloseableConnection() { + @Override + public DiscoveryNode getNode() { + return new DiscoveryNode( + nodeId, // nodeName + nodeId, // nodeId + new TransportAddress(TransportAddress.META_ADDRESS, 9300), // address + Collections.emptyMap(), // attributes + Collections.emptySet(), // roles + null // versionInfo (null = use current) + ); + } + + @Override + public TransportVersion getTransportVersion() { + return TransportVersion.current(); + } + + @Override + public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) { + // Mock implementation - not needed for these tests + } + }; + } + public void assertNoFailure() { if (phaseFailure.get() != null) { throw new AssertionError(phaseFailure.get()); diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java index cc31052327a86..ecff19cd15950 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java @@ -416,7 +416,7 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { null/* not a scroll */ ); PlainActionFuture listener = new PlainActionFuture<>(); - service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener); + service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), null, listener); listener.get(); if (useScroll) { // have to free 
context since this test does not remove the index from IndicesService. @@ -627,7 +627,7 @@ public void onFailure(Exception e) { throw new AssertionError("No failure should have been raised", e); } }; - service.executeFetchPhase(fetchRequest, searchTask, fetchListener); + service.executeFetchPhase(fetchRequest, searchTask, null, fetchListener); fetchListener.get(); } catch (Exception ex) { if (queryResult != null) { @@ -781,7 +781,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) for (SearchHit hit : hits.getHits()) { assertEquals(hit.getRank(), 3 + index); assertTrue(hit.getScore() >= 0); - assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.getId()); index++; } } diff --git a/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/ShardSearchPhaseAPMMetricsTests.java b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/ShardSearchPhaseAPMMetricsTests.java index cf4e45d67d162..81ed33eaf4ec0 100644 --- a/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/ShardSearchPhaseAPMMetricsTests.java +++ b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/ShardSearchPhaseAPMMetricsTests.java @@ -107,8 +107,7 @@ public void testMetricsDfsQueryThenFetch() { assertAttributes(dfsMeasurements, false, false); final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(num_primaries, queryMeasurements.size()); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, false, false); } @@ -125,8 +124,7 @@ public void testMetricsDfsQueryThenFetchSystem() { final List queryMeasurements = 
getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(1, queryMeasurements.size()); assertAttributes(queryMeasurements, true, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, true, false); } @@ -138,8 +136,7 @@ public void testSearchTransportMetricsQueryThenFetch() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(num_primaries, queryMeasurements.size()); assertAttributes(queryMeasurements, false, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, false, false); } @@ -153,8 +150,7 @@ public void testSearchTransportMetricsQueryThenFetchSystem() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(1, queryMeasurements.size()); assertAttributes(queryMeasurements, true, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, true, false); } @@ -190,8 +186,7 @@ public void testSearchMultipleIndices() { assertEquals(1, systemTarget); } { - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(2, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(2); int userTarget = 0; int systemTarget = 0; for (Measurement measurement : 
fetchMeasurements) { @@ -214,7 +209,7 @@ public void testSearchMultipleIndices() { } } - public void testSearchTransportMetricsScroll() { + public void testSearchTransportMetricsScroll() throws Exception { assertScrollResponsesAndHitCount( client(), TimeValue.timeValueSeconds(60), @@ -226,10 +221,21 @@ public void testSearchTransportMetricsScroll() { assertAttributes(queryMeasurements, false, true); // No hits, no fetching done if (response.getHits().getHits().length > 0) { + try { + assertBusy(() -> { + final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement( + FETCH_SEARCH_PHASE_METRIC + ); + assertThat(fetchMeasurements.size(), Matchers.greaterThan(0)); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // Get fresh list for subsequent assertions final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement( FETCH_SEARCH_PHASE_METRIC ); - assertThat(fetchMeasurements.size(), Matchers.greaterThan(0)); int numFetchShards = Math.min(2, num_primaries); assertThat(fetchMeasurements.size(), Matchers.lessThanOrEqualTo(numFetchShards)); assertAttributes(fetchMeasurements, false, true); @@ -254,8 +260,7 @@ public void testSearchTransportMetricsScrollSystem() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(1, queryMeasurements.size()); assertAttributes(queryMeasurements, true, true); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, true, true); resetMeter(); } @@ -276,8 +281,7 @@ public void testCanMatchSearch() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(num_primaries, queryMeasurements.size()); assertAttributes(queryMeasurements, 
false, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertAttributes(fetchMeasurements, false, false); } @@ -307,8 +311,7 @@ public void testTimeRangeFilterOneResult() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(1, queryMeasurements.size()); assertTimeRangeAttributes(queryMeasurements, ".others", true, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertTimeRangeAttributes(fetchMeasurements, ".others", true, false); } @@ -321,8 +324,7 @@ public void testTimeRangeFilterRetrieverOneResult() { final List queryMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(QUERY_SEARCH_PHASE_METRIC); assertEquals(1, queryMeasurements.size()); assertTimeRangeAttributes(queryMeasurements, ".others", true, false); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(1, fetchMeasurements.size()); + final List fetchMeasurements = getFetchMeasurementsEventually(1); assertTimeRangeAttributes(fetchMeasurements, ".others", true, false); } @@ -341,8 +343,7 @@ public void testTimeRangeFilterCompoundRetrieverOneResult() { // compound retriever does its own search as an async action, whose metrics are recorded separately assertEquals(2, queryMeasurements.size()); assertTimeRangeAttributes(queryMeasurements, ".others", true, true); - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); - assertEquals(2, fetchMeasurements.size()); + final List fetchMeasurements = 
getFetchMeasurementsEventually(2); assertTimeRangeAttributes(fetchMeasurements, ".others", true, true); } @@ -400,9 +401,8 @@ public void testTimeRangeFilterAllResults() { // the time range filter field because no range query is executed at the shard level. assertEquals("older_than_14_days", attributes.get("time_range_filter_from")); } - final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); + final List fetchMeasurements = getFetchMeasurementsEventually(queryMeasurements.size()); // in this case, each shard queried has results to be fetched - assertEquals(queryMeasurements.size(), fetchMeasurements.size()); // no range info stored because we had no bounds after rewrite, basically a match_all for (Measurement measurement : fetchMeasurements) { Map attributes = measurement.attributes(); @@ -460,6 +460,18 @@ private void resetMeter() { getTestTelemetryPlugin().resetMeter(); } + private List getFetchMeasurementsEventually(int expectedSize) { + try { + assertBusy(() -> { + final List fetchMeasurements = getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); + assertEquals(expectedSize, fetchMeasurements.size()); + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + return getTestTelemetryPlugin().getLongHistogramMeasurement(FETCH_SEARCH_PHASE_METRIC); + } + private TestTelemetryPlugin getTestTelemetryPlugin() { return getInstanceFromNode(PluginsService.class).filterPlugins(TestTelemetryPlugin.class).toList().get(0); } diff --git a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java index c8d1b6721c64b..8742978ca3db6 100644 --- a/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java +++ b/server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java @@ -16,25 +16,50 @@ import 
org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.FetchPhaseDocsIterator.IterateResult; +import org.elasticsearch.search.fetch.chunk.FetchPhaseResponseChunk; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.BytesRefRecycler; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.notNullValue; +import 
static org.hamcrest.Matchers.nullValue; public class FetchPhaseDocsIteratorTests extends ESTestCase { public void testInOrderIteration() throws IOException { - int docCount = random().nextInt(300) + 100; Directory directory = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), directory); @@ -78,7 +103,7 @@ protected SearchHit nextDoc(int doc) { } }; - SearchHit[] hits = it.iterate(null, reader, docs, randomBoolean(), new QuerySearchResult()); + SearchHit[] hits = it.iterate(null, reader, docs, randomBoolean(), new QuerySearchResult()).hits; assertThat(hits.length, equalTo(docs.length)); for (int i = 0; i < hits.length; i++) { @@ -88,11 +113,9 @@ protected SearchHit nextDoc(int doc) { reader.close(); directory.close(); - } public void testExceptions() throws IOException { - int docCount = randomIntBetween(300, 400); Directory directory = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), directory); @@ -113,9 +136,7 @@ public void testExceptions() throws IOException { FetchPhaseDocsIterator it = new FetchPhaseDocsIterator() { @Override - protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) { - - } + protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) {} @Override protected SearchHit nextDoc(int doc) { @@ -137,6 +158,436 @@ protected SearchHit nextDoc(int doc) { directory.close(); } + public void testIterateAsyncNullOrEmptyDocIds() throws Exception { + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + StreamingFetchPhaseDocsIterator it = createStreamingIterator(); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new 
RefCountingListener(ActionListener.running(refsComplete::countDown)); + + it.iterateAsync( + createShardTarget(), + null, + randomBoolean() ? null : new int[0], + chunkWriter, + 1024, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + IterateResult result = future.get(10, TimeUnit.SECONDS); + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertThat(result.hits, notNullValue()); + assertThat(result.hits.length, equalTo(0)); + assertThat(result.lastChunkBytes, nullValue()); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + result.close(); + } + + public void testFetchPhaseMaxInFlightChunksSettingIsReadCorrectly() { + Settings customSettings = Settings.builder().put(SearchService.FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS.getKey(), 7).build(); + assertThat(SearchService.FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS.get(customSettings), equalTo(7)); + assertThat(SearchService.FETCH_PHASE_MAX_IN_FLIGHT_CHUNKS.get(Settings.EMPTY), equalTo(3)); + } + + public void testIterateAsyncSingleDocument() throws Exception { + LuceneDocs docs = createDocs(1); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + new int[] { 0 }, + chunkWriter, + 1024, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + IterateResult result = future.get(10, TimeUnit.SECONDS); + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + // Single doc becomes the last chunk + assertThat(result.hits, nullValue()); + 
assertThat(result.lastChunkBytes, notNullValue()); + assertThat(result.lastChunkHitCount, equalTo(1)); + + // No intermediate chunks sent + assertThat(chunkWriter.getSentChunks().size(), equalTo(0)); + + // Pages for the last chunk are reserved on the CB + assertThat(circuitBreaker.getUsed(), greaterThan(0L)); + + result.close(); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncAllDocsInSingleChunk() throws Exception { + LuceneDocs docs = createDocs(5); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 1024 * 1024, // Large chunk size + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + IterateResult result = future.get(10, TimeUnit.SECONDS); + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + // No intermediate chunks sent - all in last chunk + assertThat(chunkWriter.getSentChunks().size(), equalTo(0)); + assertThat(result.lastChunkBytes, notNullValue()); + assertThat(result.lastChunkHitCount, equalTo(5)); + + result.close(); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncMultipleChunks() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + 
AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 50, // Small chunk size to force multiple chunks + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + IterateResult result = future.get(10, TimeUnit.SECONDS); + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + // Verify chunks are in order by from index + List chunks = chunkWriter.getSentChunks(); + long expectedSequenceStart = 0L; + for (SentChunkInfo chunk : chunks) { + assertThat(chunk.sequenceStart, equalTo(expectedSequenceStart)); + expectedSequenceStart += chunk.hitCount; + } + assertThat(result.lastChunkSequenceStart, equalTo(expectedSequenceStart)); + + // Should have multiple chunks sent + last chunk held back + assertThat(chunkWriter.getSentChunks().size(), greaterThan(0)); + assertThat(result.lastChunkBytes, notNullValue()); + + // Total hits across all chunks should equal docCount + int totalHits = chunkWriter.getSentChunks().stream().mapToInt(c -> c.hitCount).sum() + result.lastChunkHitCount; + assertThat(totalHits, equalTo(100)); + + // Only last chunk's pages should remain reserved + assertThat(circuitBreaker.getUsed(), greaterThan(0L)); + + result.close(); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncCircuitBreakerTrips() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(100L)); + TestChunkWriter chunkWriter = new TestChunkWriter(true, circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + 
AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 50, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + chunkWriter.ackAll(); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + Throwable actual = e instanceof ExecutionException ? e.getCause() : e; + assertThat(actual, instanceOf(CircuitBreakingException.class)); + + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncCancellationBeforeFetchStart() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(true); // Already cancelled + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 50, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + assertTrue( + "Expected cancellation but got: " + e, + e.getCause() instanceof TaskCancelledException || e.getMessage().contains("cancelled") + ); + + refs.close(); + assertTrue(refsComplete.await(10, 
TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncCancellationDuringDocProduction() throws Exception { + LuceneDocs docs = createDocs(1000); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + // Iterator that cancels after processing some docs + AtomicInteger processedDocs = new AtomicInteger(0); + StreamingFetchPhaseDocsIterator it = new StreamingFetchPhaseDocsIterator() { + @Override + protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) {} + + @Override + protected SearchHit nextDoc(int doc) { + if (processedDocs.incrementAndGet() == 100) { + cancelled.set(true); + } + return new SearchHit(doc); + } + }; + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + it.iterateAsync(createShardTarget(), docs.reader, docs.docIds, chunkWriter, 50, refs, 4, sendFailure, cancelled::get, future); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + assertTrue( + "Expected TaskCancelledException but got: " + e, + e.getCause() instanceof TaskCancelledException || e.getMessage().contains("cancelled") + ); + + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncDocProducerException() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + 
TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + // Iterator that throws after processing some docs + StreamingFetchPhaseDocsIterator it = new StreamingFetchPhaseDocsIterator() { + private int count = 0; + + @Override + protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) {} + + @Override + protected SearchHit nextDoc(int doc) { + if (++count > 50) { + throw new RuntimeException("Simulated producer failure"); + } + return new SearchHit(doc); + } + }; + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + it.iterateAsync(createShardTarget(), docs.reader, docs.docIds, chunkWriter, 50, refs, 4, sendFailure, cancelled::get, future); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + assertThat(e.getCause().getMessage(), containsString("Simulated producer failure")); + + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncPreExistingSendFailure() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker); + AtomicReference sendFailure = new AtomicReference<>(new IOException("Pre-existing failure")); // Send Failure + AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + 
createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 50, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + assertThat(e.getCause(), instanceOf(IOException.class)); + assertThat(e.getCause().getMessage(), containsString("Pre-existing failure")); + + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + + public void testIterateAsyncSendFailure() throws Exception { + LuceneDocs docs = createDocs(100); + CircuitBreaker circuitBreaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + // Chunk writer that fails after first chunk + AtomicInteger chunkCount = new AtomicInteger(0); + TestChunkWriter chunkWriter = new TestChunkWriter(circuitBreaker) { + @Override + public void writeResponseChunk(FetchPhaseResponseChunk chunk, ActionListener listener) { + if (chunkCount.incrementAndGet() > 1) { + chunk.close(); + listener.onFailure(new IOException("Simulated send failure")); + } else { + super.writeResponseChunk(chunk, listener); + } + } + }; + AtomicReference sendFailure = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + PlainActionFuture future = new PlainActionFuture<>(); + CountDownLatch refsComplete = new CountDownLatch(1); + RefCountingListener refs = new RefCountingListener(ActionListener.running(refsComplete::countDown)); + + createStreamingIterator().iterateAsync( + createShardTarget(), + docs.reader, + docs.docIds, + chunkWriter, + 50, + refs, + 4, + sendFailure, + cancelled::get, + future + ); + + Exception e = expectThrows(Exception.class, () -> future.get(10, TimeUnit.SECONDS)); + assertThat(e.getCause(), instanceOf(IOException.class)); + assertThat(e.getCause().getMessage(), containsString("Simulated send failure")); 
+ + refs.close(); + assertTrue(refsComplete.await(10, TimeUnit.SECONDS)); + + assertBusy(() -> assertThat(circuitBreaker.getUsed(), equalTo(0L))); + + docs.reader.close(); + docs.directory.close(); + } + private static int[] randomDocIds(int maxDoc) { List integers = new ArrayList<>(); int v = 0; @@ -151,4 +602,94 @@ private static int[] randomDocIds(int maxDoc) { return integers.stream().mapToInt(i -> i).toArray(); } + private static SearchShardTarget createShardTarget() { + return new SearchShardTarget("node1", new ShardId(new Index("test", "uuid"), 0), null); + } + + private static StreamingFetchPhaseDocsIterator createStreamingIterator() { + return new StreamingFetchPhaseDocsIterator() { + @Override + protected void setNextReader(LeafReaderContext ctx, int[] docsInLeaf) {} + + @Override + protected SearchHit nextDoc(int doc) { + return new SearchHit(doc); + } + }; + } + + private LuceneDocs createDocs(int numDocs) throws IOException { + Directory directory = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter(random(), directory); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new StringField("field", "value" + i, Field.Store.NO)); + writer.addDocument(doc); + if (i % 30 == 0) { + writer.commit(); // Create multiple segments + } + } + writer.commit(); + IndexReader reader = writer.getReader(); + writer.close(); + + int[] docIds = new int[numDocs]; + for (int i = 0; i < numDocs; i++) { + docIds[i] = i; + } + + return new LuceneDocs(directory, reader, docIds); + } + + private record LuceneDocs(Directory directory, IndexReader reader, int[] docIds) {} + + /** + * Simple record to track sent chunk info + */ + private record SentChunkInfo(int hitCount, long sequenceStart, int expectedTotalDocs) {} + + private static class TestChunkWriter implements FetchPhaseResponseChunk.Writer { + + protected final List sentChunks = new CopyOnWriteArrayList<>(); + private final List> pendingAcks = new CopyOnWriteArrayList<>(); + 
private final boolean delayAcks; + private final CircuitBreaker circuitBreaker; + + private final PageCacheRecycler recycler = new PageCacheRecycler(Settings.EMPTY); + + TestChunkWriter(CircuitBreaker circuitBreaker) { + this(false, circuitBreaker); + } + + TestChunkWriter(boolean delayAcks, CircuitBreaker circuitBreaker) { + this.delayAcks = delayAcks; + this.circuitBreaker = circuitBreaker; + } + + @Override + public void writeResponseChunk(FetchPhaseResponseChunk chunk, ActionListener listener) { + sentChunks.add(new SentChunkInfo(chunk.hitCount(), chunk.sequenceStart(), chunk.expectedTotalDocs())); + if (delayAcks) { + pendingAcks.add(listener); + } else { + listener.onResponse(null); + } + } + + public void ackAll() { + for (ActionListener ack : pendingAcks) { + ack.onResponse(null); + } + pendingAcks.clear(); + } + + @Override + public RecyclerBytesStreamOutput newNetworkBytesStream() { + return new RecyclerBytesStreamOutput(new BytesRefRecycler(recycler), circuitBreaker); + } + + public List getSentChunks() { + return sentChunks; + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasksTests.java b/server/src/test/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasksTests.java new file mode 100644 index 0000000000000..1dd20c2fe0875 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/fetch/chunk/ActiveFetchPhaseTasksTests.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.test.ESTestCase; + +public class ActiveFetchPhaseTasksTests extends ESTestCase { + + private static final ShardId TEST_SHARD_ID = new ShardId(new Index("test-index", "test-uuid"), 0); + + public void testAcquireRegisteredStream() { + ActiveFetchPhaseTasks tasks = new ActiveFetchPhaseTasks(); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + Releasable registration = tasks.registerResponseBuilder(123L, TEST_SHARD_ID, stream); + + try { + FetchPhaseResponseStream acquired = tasks.acquireResponseStream(123L, TEST_SHARD_ID); + assertSame(stream, acquired); + assertTrue(acquired.hasReferences()); + acquired.decRef(); + registration.close(); + + expectThrows(ResourceNotFoundException.class, () -> tasks.acquireResponseStream(123L, TEST_SHARD_ID)); + } finally { + stream.decRef(); + } + } + + public void testDuplicateRegisterThrows() { + ActiveFetchPhaseTasks tasks = new ActiveFetchPhaseTasks(); + FetchPhaseResponseStream first = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + FetchPhaseResponseStream second = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + Releasable registration = tasks.registerResponseBuilder(123L, TEST_SHARD_ID, first); + + try { + Exception e = expectThrows(IllegalStateException.class, () -> tasks.registerResponseBuilder(123L, TEST_SHARD_ID, second)); + assertEquals("already executing fetch task [123]", e.getMessage()); + } finally { + registration.close(); + first.decRef(); + second.decRef(); + } + } + + public void testCloseRegistrationRemovesTaskAndAllowsReregister() { + ActiveFetchPhaseTasks tasks = new ActiveFetchPhaseTasks(); 
+ FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + Releasable registration = tasks.registerResponseBuilder(123L, TEST_SHARD_ID, stream); + + try { + registration.close(); + expectThrows(ResourceNotFoundException.class, () -> tasks.acquireResponseStream(123L, TEST_SHARD_ID)); + + FetchPhaseResponseStream replacement = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + Releasable secondRegistration = tasks.registerResponseBuilder(123L, TEST_SHARD_ID, replacement); + try { + FetchPhaseResponseStream acquired = tasks.acquireResponseStream(123L, TEST_SHARD_ID); + assertSame(replacement, acquired); + acquired.decRef(); + } finally { + secondRegistration.close(); + replacement.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testAcquireMissingTaskThrowsResourceNotFound() { + ActiveFetchPhaseTasks tasks = new ActiveFetchPhaseTasks(); + expectThrows(ResourceNotFoundException.class, () -> tasks.acquireResponseStream(999L, TEST_SHARD_ID)); + } + + public void testAcquireFailsWhenStreamAlreadyClosed() { + ActiveFetchPhaseTasks tasks = new ActiveFetchPhaseTasks(); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 10, new NoopCircuitBreaker("test")); + Releasable registration = tasks.registerResponseBuilder(123L, TEST_SHARD_ID, stream); + registration.close(); + + try { + expectThrows(ResourceNotFoundException.class, () -> tasks.acquireResponseStream(123L, TEST_SHARD_ID)); + } finally { + stream.decRef(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunkTests.java b/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunkTests.java new file mode 100644 index 0000000000000..085c6842ff1a0 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseChunkTests.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.sameInstance; + +public class FetchPhaseResponseChunkTests extends ESTestCase { + + private static final ShardId TEST_SHARD_ID = new ShardId(new Index("test-index", "test-uuid"), 0); + + public void testToReleasableBytesReferenceTransfersOwnership() throws IOException { + SearchHit hit = createHit(1); + try { + AtomicBoolean released = new AtomicBoolean(false); + ReleasableBytesReference serializedHits = new ReleasableBytesReference(serializeHits(hit), () -> released.set(true)); + + FetchPhaseResponseChunk chunk = new 
FetchPhaseResponseChunk(TEST_SHARD_ID, serializedHits, 1, 10, 0L); + try { + assertThat(chunk.getBytesLength(), greaterThan(0L)); + + ReleasableBytesReference wireBytes = chunk.toReleasableBytesReference(42L); + try { + assertThat(chunk.getBytesLength(), equalTo(0L)); + assertFalse(released.get()); + + try (StreamInput in = wireBytes.streamInput()) { + assertThat(in.readVLong(), equalTo(42L)); + FetchPhaseResponseChunk decoded = new FetchPhaseResponseChunk(in); + try { + assertThat(decoded.shardId(), equalTo(TEST_SHARD_ID)); + assertThat(decoded.hitCount(), equalTo(1)); + assertThat(getIdFromSource(decoded.getHits()[0]), equalTo(1)); + } finally { + decoded.close(); + } + } + } finally { + wireBytes.decRef(); + } + + assertTrue(released.get()); + } finally { + chunk.close(); + } + } finally { + hit.decRef(); + } + } + + public void testGetHitsCachesDeserializedHits() throws IOException { + SearchHit first = createHit(1); + SearchHit second = createHit(2); + try { + FetchPhaseResponseChunk chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(first, second), 2, 10, 0L); + try { + SearchHit[] firstRead = chunk.getHits(); + SearchHit[] secondRead = chunk.getHits(); + assertThat(secondRead, sameInstance(firstRead)); + assertThat(secondRead.length, equalTo(2)); + assertThat(getIdFromSource(secondRead[0]), equalTo(1)); + assertThat(getIdFromSource(secondRead[1]), equalTo(2)); + } finally { + chunk.close(); + } + } finally { + first.decRef(); + second.decRef(); + } + } + + public void testGetHitsReturnsEmptyWhenHitCountIsZero() throws IOException { + FetchPhaseResponseChunk chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, BytesArray.EMPTY, 0, 0, 0L); + try { + assertThat(chunk.getHits().length, equalTo(0)); + } finally { + chunk.close(); + } + } + + public void testCloseClearsChunkState() throws IOException { + SearchHit hit = createHit(7); + try { + FetchPhaseResponseChunk chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hit), 1, 1, 0L); + + 
SearchHit[] hits = chunk.getHits(); + assertTrue(hits[0].hasReferences()); + + chunk.close(); + assertThat(chunk.getBytesLength(), equalTo(0L)); + assertThat(chunk.getHits().length, equalTo(0)); + } finally { + hit.decRef(); + } + } + + public void testSerializationRoundTripAcrossCompatibleTransportVersion() throws IOException { + SearchHit hit = createHit(42); + try { + FetchPhaseResponseChunk chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hit), 1, 1, 0L); + try { + TransportVersion version = randomBoolean() ? TransportVersion.current() : TransportVersionUtils.randomCompatibleVersion(); + FetchPhaseResponseChunk roundTripped = copyWriteable( + chunk, + new NamedWriteableRegistry(Collections.emptyList()), + FetchPhaseResponseChunk::new, + version + ); + try { + assertThat(roundTripped.shardId(), equalTo(TEST_SHARD_ID)); + assertThat(roundTripped.hitCount(), equalTo(1)); + assertThat(roundTripped.expectedTotalDocs(), equalTo(1)); + assertThat(roundTripped.sequenceStart(), equalTo(0L)); + assertThat(getIdFromSource(roundTripped.getHits()[0]), equalTo(42)); + } finally { + roundTripped.close(); + } + } finally { + chunk.close(); + } + } finally { + hit.decRef(); + } + } + + private SearchHit createHit(int id) { + SearchHit hit = new SearchHit(id); + hit.sourceRef(new BytesArray("{\"id\":" + id + "}")); + return hit; + } + + private BytesReference serializeHits(SearchHit... 
hits) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + for (SearchHit hit : hits) { + hit.writeTo(out); + } + return out.bytes(); + } + } + + private int getIdFromSource(SearchHit hit) { + Number id = (Number) XContentHelper.convertToMap(hit.getSourceRef(), false, XContentType.JSON).v2().get("id"); + return id.intValue(); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStreamTests.java b/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStreamTests.java new file mode 100644 index 0000000000000..c609f6b7ed34e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/fetch/chunk/FetchPhaseResponseStreamTests.java @@ -0,0 +1,741 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +/** + * Unit tests for {@link FetchPhaseResponseStream}. 
+ */ +public class FetchPhaseResponseStreamTests extends ESTestCase { + + private static final int SHARD_INDEX = 0; + private static final ShardId TEST_SHARD_ID = new ShardId(new Index("test-index", "test-uuid"), 0); + + public void testEmptyStream() { + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 0, new NoopCircuitBreaker("test")); + try { + FetchSearchResult result = buildFinalResult(stream); + try { + assertThat(result.hits().getHits().length, equalTo(0)); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testSingleHit() throws IOException { + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 1, new NoopCircuitBreaker("test")); + + try { + writeChunk(stream, createChunk(0, 1, 0)); + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(1)); + assertThat(getIdFromSource(hits[0]), equalTo(0)); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testChunksArriveInOrder() throws IOException { + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 15, new NoopCircuitBreaker("test")); + + try { + // Send 3 chunks in order: sequence 0-4, 5-9, 10-14 + writeChunk(stream, createChunk(0, 5, 0)); // hits 0-4, sequence starts at 0 + writeChunk(stream, createChunk(5, 5, 5)); // hits 5-9, sequence starts at 5 + writeChunk(stream, createChunk(10, 5, 10)); // hits 10-14, sequence starts at 10 + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(15)); + + for (int i = 0; i < 15; i++) { + assertThat("Hit at position " + i + " should have correct id in source", getIdFromSource(hits[i]), equalTo(i)); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testChunksArriveRandomOrder() throws 
IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + int numChunks = 10; + int hitsPerChunk = 5; + int totalHits = numChunks * hitsPerChunk; + + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, totalHits, breaker); + + try { + // Create chunks and shuffle them + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + int startId = i * hitsPerChunk; + long sequenceStart = i * hitsPerChunk; + chunks.add(createChunk(startId, hitsPerChunk, sequenceStart)); + } + Collections.shuffle(chunks, random()); + + // Write in shuffled order + for (FetchPhaseResponseChunk chunk : chunks) { + writeChunk(stream, chunk); + } + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(totalHits)); + + for (int i = 0; i < totalHits; i++) { + assertThat("Hit at position " + i + " should have correct id in source", getIdFromSource(hits[i]), equalTo(i)); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testAddHitWithSequence() { + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, new NoopCircuitBreaker("test")); + + try { + stream.addHitWithSequence(createHit(3), 3); + stream.addHitWithSequence(createHit(1), 1); + stream.addHitWithSequence(createHit(4), 4); + stream.addHitWithSequence(createHit(0), 0); + stream.addHitWithSequence(createHit(2), 2); + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(5)); + + for (int i = 0; i < 5; i++) { + assertThat(getIdFromSource(hits[i]), equalTo(i)); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testMixedChunkAndSingleHitAddition() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new 
FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + try { + // Add a chunk (sequence 0-4) + writeChunk(stream, createChunk(0, 5, 0)); + + // Add individual hits for sequence 5-9 in random order + stream.addHitWithSequence(createHit(7), 7); + stream.addHitWithSequence(createHit(5), 5); + stream.addHitWithSequence(createHit(9), 9); + stream.addHitWithSequence(createHit(6), 6); + stream.addHitWithSequence(createHit(8), 8); + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(10)); + + for (int i = 0; i < 10; i++) { + assertThat(getIdFromSource(hits[i]), equalTo(i)); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testNonContiguousSequenceNumbers() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 6, breaker); + + try { + // Chunks with gaps in sequence + writeChunk(stream, createChunkWithSequence(0, 2, 0)); // id 0,1 -> seq 0, 1 + writeChunk(stream, createChunkWithSequence(2, 2, 10)); // id 2,3 -> seq 10, 11 + writeChunk(stream, createChunkWithSequence(4, 2, 5)); // id 4,5 -> seq 5, 6 + + FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(6)); + + // source ids: 0, 1, 4, 5, 2, 3 + assertThat(getIdFromSource(hits[0]), equalTo(0)); // seq 0 + assertThat(getIdFromSource(hits[1]), equalTo(1)); // seq 1 + assertThat(getIdFromSource(hits[2]), equalTo(4)); // seq 5 + assertThat(getIdFromSource(hits[3]), equalTo(5)); // seq 6 + assertThat(getIdFromSource(hits[4]), equalTo(2)); // seq 10 + assertThat(getIdFromSource(hits[5]), equalTo(3)); // seq 11 + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + // ==================== Circuit Breaker Tests ==================== + + public void 
testCircuitBreakerBytesTracked() throws IOException { + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + try { + long bytesBefore = breaker.getUsed(); + assertThat(bytesBefore, equalTo(0L)); + + FetchPhaseResponseChunk chunk1 = createChunkWithSourceSize(0, 5, 0, 1024); + long chunk1Bytes = chunk1.getBytesLength(); + writeChunk(stream, chunk1); + + long bytesAfterChunk1 = breaker.getUsed(); + assertThat("Circuit breaker should track chunk1 bytes", bytesAfterChunk1, equalTo(chunk1Bytes)); + + FetchPhaseResponseChunk chunk2 = createChunkWithSourceSize(5, 5, 5, 1024); + long chunk2Bytes = chunk2.getBytesLength(); + writeChunk(stream, chunk2); + + long bytesAfterChunk2 = breaker.getUsed(); + assertThat("Circuit breaker should track both chunks' bytes", bytesAfterChunk2, equalTo(chunk1Bytes + chunk2Bytes)); + } finally { + stream.decRef(); + } + } + + public void testCircuitBreakerBytesReleasedOnClose() throws IOException { + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + FetchPhaseResponseChunk chunk1 = createChunkWithSourceSize(0, 5, 0, 1024); + FetchPhaseResponseChunk chunk2 = createChunkWithSourceSize(5, 5, 5, 1024); + long expectedBytes = chunk1.getBytesLength() + chunk2.getBytesLength(); + + writeChunk(stream, chunk1); + writeChunk(stream, chunk2); + + long bytesBeforeClose = breaker.getUsed(); + assertThat("Should have bytes tracked", bytesBeforeClose, equalTo(expectedBytes)); + + stream.decRef(); + + long bytesAfterClose = breaker.getUsed(); + assertThat("All breaker bytes should be released", bytesAfterClose, equalTo(0L)); + } + + public void testCircuitBreakerTrips() throws IOException { + FetchPhaseResponseChunk testChunk = createChunkWithSourceSize(0, 5, 0, 2048); + long chunkSize = 
testChunk.getBytesLength(); + + // Set limit smaller than chunk size + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(chunkSize - 1)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + try { + FetchPhaseResponseChunk chunk = createChunkWithSourceSize(0, 5, 0, 2048); + expectThrows(CircuitBreakingException.class, () -> writeChunk(stream, chunk)); + } finally { + stream.decRef(); + } + } + + public void testCircuitBreakerTripsOnSecondChunk() throws IOException { + FetchPhaseResponseChunk chunk1 = createChunkWithSourceSize(0, 5, 0, 1024); + FetchPhaseResponseChunk chunk2 = createChunkWithSourceSize(5, 5, 5, 1024); + long chunk1Size = chunk1.getBytesLength(); + long chunk2Size = chunk2.getBytesLength(); + + // Set limit to allow first chunk but not second + long limit = chunk1Size + (chunk2Size / 2); + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(limit)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + try { + writeChunk(stream, createChunkWithSourceSize(0, 5, 0, 1024)); + assertThat("First chunk should be tracked", breaker.getUsed(), greaterThan(0L)); + + expectThrows(CircuitBreakingException.class, () -> { writeChunk(stream, createChunkWithSourceSize(5, 5, 5, 1024)); }); + } finally { + stream.decRef(); + } + } + + public void testCircuitBreakerReleasedOnCloseWithoutBuildingResult() throws IOException { + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 10, breaker); + + // Write chunks but don't call buildFinalResult + writeChunk(stream, createChunkWithSourceSize(0, 5, 0, 1024)); + writeChunk(stream, createChunkWithSourceSize(5, 5, 5, 1024)); + + long bytesBeforeClose = breaker.getUsed(); + assertThat("Should have bytes tracked", bytesBeforeClose, greaterThan(0L)); + + stream.decRef(); + + assertThat("All breaker bytes should 
be released", breaker.getUsed(), equalTo(0L)); + } + + // ==================== Reference Counting Tests ==================== + + public void testHitsIncRefOnWrite() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, breaker); + + try { + FetchPhaseResponseChunk chunk = createChunk(0, 5, 0); + writeChunk(stream, chunk); + + FetchSearchResult result = buildFinalResult(stream); + + try { + // Hits should still have references after writeChunk + for (SearchHit hit : result.hits().getHits()) { + assertTrue("Hit should have references", hit.hasReferences()); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + // ==================== Score Handling Tests ==================== + + public void testMaxScoreCalculation() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, breaker); + + try { + float[] scores = { 1.5f, 3.2f, 2.1f, 4.8f, 0.9f }; + FetchPhaseResponseChunk chunk = createChunkWithScores(0, scores, 0); + writeChunk(stream, chunk); + + FetchSearchResult result = buildFinalResult(stream); + + try { + assertThat(result.hits().getMaxScore(), equalTo(4.8f)); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testMaxScoreWithNaN() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 3, breaker); + + try { + float[] scores = { Float.NaN, Float.NaN, Float.NaN }; + FetchPhaseResponseChunk chunk = createChunkWithScores(0, scores, 0); + writeChunk(stream, chunk); + + FetchSearchResult result = stream.buildFinalResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", TEST_SHARD_ID, null), + null + ); + + try { + 
assertTrue(Float.isNaN(result.hits().getMaxScore())); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testMaxScoreWithMixedNaNAndValid() throws IOException { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 4, breaker); + + try { + float[] scores = { Float.NaN, 2.5f, Float.NaN, 1.8f }; + FetchPhaseResponseChunk chunk = createChunkWithScores(0, scores, 0); + writeChunk(stream, chunk); + + FetchSearchResult result = buildFinalResult(stream); + + try { + assertThat(result.hits().getMaxScore(), equalTo(2.5f)); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + /** + * Test concurrent chunk writes from multiple threads. + * Verifies thread-safety of the ConcurrentLinkedQueue usage. + * Simulates shards + */ + public void testConcurrentChunkWrites() throws Exception { + CircuitBreaker breaker = new NoopCircuitBreaker("test"); + int numThreads = 10; + int hitsPerThread = 10; + int totalHits = numThreads * hitsPerThread; + + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, totalHits, breaker); + + try { + CyclicBarrier barrier = new CyclicBarrier(numThreads); + CountDownLatch done = new CountDownLatch(numThreads); + AtomicBoolean error = new AtomicBoolean(false); + + for (int t = 0; t < numThreads; t++) { + final int threadId = t; + new Thread(() -> { + try { + barrier.await(); + // Each thread writes its own chunk with distinct sequence range + int startId = threadId * hitsPerThread; + long sequenceStart = threadId * hitsPerThread; + FetchPhaseResponseChunk chunk = createChunk(startId, hitsPerThread, sequenceStart); + writeChunk(stream, chunk); + } catch (Exception e) { + error.set(true); + } finally { + done.countDown(); + } + }).start(); + } + + assertTrue("All threads should complete", done.await(10, TimeUnit.SECONDS)); + assertFalse("No errors should occur", error.get()); + + 
FetchSearchResult result = buildFinalResult(stream); + + try { + SearchHit[] hits = result.hits().getHits(); + assertThat(hits.length, equalTo(totalHits)); + + for (int i = 0; i < totalHits; i++) { + assertThat("Hit at position " + i + " should have correct id in source", getIdFromSource(hits[i]), equalTo(i)); + } + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + } + } + + public void testReleasableClosedOnSuccess() throws IOException { + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, new NoopCircuitBreaker("test")); + + try { + AtomicBoolean releasableClosed = new AtomicBoolean(false); + Releasable releasable = () -> releasableClosed.set(true); + + stream.writeChunk(createChunk(0, 5, 0), releasable); + + assertTrue("Releasable should be closed after successful write", releasableClosed.get()); + } finally { + stream.decRef(); + } + } + + public void testReleasableNotClosedOnFailure() throws IOException { + FetchPhaseResponseChunk testChunk = createChunkWithSourceSize(0, 5, 0, 10000); + long chunkSize = testChunk.getBytesLength(); + + // Set limit smaller than chunk size to guarantee trip + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(chunkSize / 2)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, breaker); + + try { + AtomicBoolean releasableClosed = new AtomicBoolean(false); + Releasable releasable = () -> releasableClosed.set(true); + + expectThrows( + CircuitBreakingException.class, + () -> { stream.writeChunk(createChunkWithSourceSize(0, 5, 0, 10000), releasable); } + ); + + assertFalse("Releasable should not be closed on failure", releasableClosed.get()); + } finally { + stream.decRef(); + } + } + + public void testWriteChunkWithCircuitBreakerTripPreservesAccountingAndPropagates() throws IOException { + FetchPhaseResponseChunk chunk = createChunkWithSourceSize(0, 5, 0, 4096); + long chunkSize = chunk.getBytesLength(); + + CircuitBreaker breaker = 
newLimitedBreaker(ByteSizeValue.ofBytes(chunkSize - 1)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, 5, breaker); + AtomicBoolean releasableClosed = new AtomicBoolean(false); + + try { + CircuitBreakingException e = expectThrows( + CircuitBreakingException.class, + () -> stream.writeChunk(chunk, () -> releasableClosed.set(true)) + ); + + assertFalse("Releasable should not be closed on failure", releasableClosed.get()); + assertThat("No bytes should be tracked on breaker trip", breaker.getUsed(), equalTo(0L)); + + FetchSearchResult result = buildFinalResult(stream); + try { + assertThat("No hits should be accumulated after breaker trip", result.hits().getHits().length, equalTo(0)); + } finally { + result.decRef(); + } + } finally { + stream.decRef(); + assertThat("No breaker bytes should remain after close", breaker.getUsed(), equalTo(0L)); + } + } + + public void testConcurrentWriteChunkAndBuildFinalResultNoHitLeaks() throws Exception { + CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + int numThreads = 8; + int hitsPerThread = 8; + int totalHits = numThreads * hitsPerThread; + + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(SHARD_INDEX, totalHits, breaker); + CountDownLatch startSignal = new CountDownLatch(1); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + SearchHit[] resultHits = null; + FetchSearchResult result = null; + try { + List> writerFutures = IntStream.range(0, numThreads) + .mapToObj(threadId -> CompletableFuture.runAsync(() -> { + try { + assertTrue("Writer should be released to start", startSignal.await(5, TimeUnit.SECONDS)); + int startId = threadId * hitsPerThread; + long sequenceStart = threadId * hitsPerThread; + FetchPhaseResponseChunk chunk = createChunk(startId, hitsPerThread, sequenceStart); + try { + writeChunk(stream, chunk); + } finally { + chunk.close(); + } + } catch (Exception e) { + throw new AssertionError("Writer failed", 
e); + } + }, executor)) + .toList(); + startSignal.countDown(); + + CompletableFuture.allOf(writerFutures.toArray(new CompletableFuture[0])).get(10, TimeUnit.SECONDS); + + result = buildFinalResult(stream); + resultHits = result.hits().getHits().clone(); + assertThat(resultHits.length, equalTo(totalHits)); + + for (int i = 0; i < totalHits; i++) { + assertThat(getIdFromSource(resultHits[i]), equalTo(i)); + } + } finally { + executor.shutdown(); + assertTrue("Executor should terminate", executor.awaitTermination(10, TimeUnit.SECONDS)); + if (result != null) { + result.decRef(); + } + stream.decRef(); + } + + assertNotNull(resultHits); + assertThat("All breaker bytes should be released after stream close", breaker.getUsed(), equalTo(0L)); + } + + public void testChunkMetadata() throws IOException { + SearchHit hit = createHit(0); + try { + FetchPhaseResponseChunk chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hit), 1, 10, 0); + + assertThat(chunk.shardId(), equalTo(TEST_SHARD_ID)); + assertThat(chunk.hitCount(), equalTo(1)); + assertThat(chunk.expectedTotalDocs(), equalTo(10)); + assertThat(chunk.sequenceStart(), equalTo(0L)); + assertThat(chunk.getBytesLength(), greaterThan(0L)); + + chunk.close(); + } finally { + hit.decRef(); + } + } + + private FetchSearchResult buildFinalResult(FetchPhaseResponseStream stream) { + return stream.buildFinalResult(new ShardSearchContextId("test", 1), new SearchShardTarget("node1", TEST_SHARD_ID, null), null); + } + + /** + * Extracts the "id" field from a hit's source JSON. 
+ */ + private int getIdFromSource(SearchHit hit) { + Number id = (Number) XContentHelper.convertToMap(hit.getSourceRef(), false, XContentType.JSON).v2().get("id"); + return id.intValue(); + } + + private FetchPhaseResponseChunk createChunk(int startId, int hitCount, long sequenceStart) throws IOException { + SearchHit[] hits = new SearchHit[hitCount]; + for (int i = 0; i < hitCount; i++) { + hits[i] = createHit(startId + i); + } + try { + return new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hits), hitCount, 100, sequenceStart); + } finally { + for (SearchHit hit : hits) { + hit.decRef(); + } + } + } + + private FetchPhaseResponseChunk createChunkWithSequence(int startId, int hitCount, long sequenceStart) throws IOException { + SearchHit[] hits = new SearchHit[hitCount]; + for (int i = 0; i < hitCount; i++) { + hits[i] = createHit(startId + i); + } + try { + return new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hits), hitCount, 100, sequenceStart); + } finally { + for (SearchHit hit : hits) { + hit.decRef(); + } + } + } + + private FetchPhaseResponseChunk createChunkWithSourceSize(int startId, int hitCount, long sequenceStart, int sourceSize) + throws IOException { + SearchHit[] hits = new SearchHit[hitCount]; + for (int i = 0; i < hitCount; i++) { + hits[i] = createHitWithSourceSize(startId + i, sourceSize); + } + try { + return new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hits), hitCount, 100, sequenceStart); + } finally { + for (SearchHit hit : hits) { + hit.decRef(); + } + } + } + + private FetchPhaseResponseChunk createChunkWithScores(int startId, float[] scores, long sequenceStart) throws IOException { + SearchHit[] hits = new SearchHit[scores.length]; + for (int i = 0; i < scores.length; i++) { + hits[i] = createHitWithScore(startId + i, scores[i]); + } + try { + return new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(hits), scores.length, 100, sequenceStart); + } finally { + for (SearchHit hit : hits) { + 
hit.decRef(); + } + } + } + + private SearchHit createHit(int id) { + SearchHit hit = new SearchHit(id); + hit.sourceRef(new BytesArray("{\"id\":" + id + "}")); + return hit; + } + + private SearchHit createHitWithSourceSize(int id, int sourceSize) { + SearchHit hit = new SearchHit(id); + StringBuilder sb = new StringBuilder(); + sb.append("{\"id\":").append(id).append(",\"data\":\""); + int dataSize = Math.max(0, sourceSize - 20); + for (int i = 0; i < dataSize; i++) { + sb.append('x'); + } + sb.append("\"}"); + hit.sourceRef(new BytesArray(sb.toString())); + return hit; + } + + private SearchHit createHitWithScore(int id, float score) { + SearchHit hit = new SearchHit(id); + hit.sourceRef(new BytesArray("{\"id\":" + id + "}")); + hit.score(score); + return hit; + } + + private BytesReference serializeHits(SearchHit... hits) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + for (SearchHit hit : hits) { + hit.writeTo(out); + } + return out.bytes(); + } + } + + private void writeChunk(FetchPhaseResponseStream stream, FetchPhaseResponseChunk chunk) throws IOException { + stream.writeChunk(chunk, () -> {}); + } + +} diff --git a/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationActionTests.java b/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationActionTests.java new file mode 100644 index 0000000000000..6f8e9bb9001b7 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseCoordinationActionTests.java @@ -0,0 +1,518 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.VersionInformation; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.search.RescoreDocIds; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.Task; +import 
org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.BytesTransportRequest; +import org.elasticsearch.transport.TransportResponseHandler; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.action.search.SearchTransportService.FETCH_ID_ACTION_NAME; +import static org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; + +/** + * Unit tests for {@link TransportFetchPhaseCoordinationAction}. 
+ */ +public class TransportFetchPhaseCoordinationActionTests extends ESTestCase { + + private static final ShardId TEST_SHARD_ID = new ShardId(new Index("test-index", "test-uuid"), 0); + + private ThreadPool threadPool; + private MockTransportService transportService; + private ActiveFetchPhaseTasks activeFetchPhaseTasks; + private NamedWriteableRegistry namedWriteableRegistry; + private TransportFetchPhaseCoordinationAction action; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + transportService = MockTransportService.createNewService( + Settings.EMPTY, + VersionInformation.CURRENT, + CHUNKED_FETCH_PHASE, + threadPool + ); + transportService.start(); + transportService.acceptIncomingRequests(); + + activeFetchPhaseTasks = new ActiveFetchPhaseTasks(); + namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + + action = new TransportFetchPhaseCoordinationAction( + transportService, + new ActionFilters(Set.of()), + activeFetchPhaseTasks, + new NoneCircuitBreakerService(), + namedWriteableRegistry + ); + new TransportFetchPhaseResponseChunkAction(transportService, activeFetchPhaseTasks, namedWriteableRegistry); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + if (transportService != null) { + transportService.close(); + } + if (threadPool != null) { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + + public void testActionType() { + assertThat(TransportFetchPhaseCoordinationAction.TYPE.name(), equalTo("internal:data/read/search/fetch/coordination")); + } + + public void testDoExecuteSetsCoordinatorNodeAndTaskIdOnRequest() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + latch.countDown(); + FetchSearchResult result = 
createFetchSearchResult(); + try { + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + ShardFetchSearchRequest shardFetchRequest = createShardFetchSearchRequest(); + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + shardFetchRequest, + transportService.getLocalNode(), + Collections.emptyMap() + ); + + long taskId = 123L; + Task task = createTask(taskId); + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(task, request, future); + + assertTrue("Request handler should be called", latch.await(10, TimeUnit.SECONDS)); + assertThat(shardFetchRequest.getCoordinatingNode(), equalTo(transportService.getLocalNode())); + assertThat(shardFetchRequest.getCoordinatingTaskId(), equalTo(taskId)); + } + + public void testDoExecuteWithParentTaskId() throws Exception { + AtomicReference capturedParentTaskId = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + capturedParentTaskId.set(req.getParentTask()); + latch.countDown(); + FetchSearchResult result = null; + try { + result = createFetchSearchResult(); + channel.sendResponse(result); + } finally { + if (result != null) { + result.decRef(); + } + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + TaskId parentTaskId = new TaskId("parent-node", 999L); + Task task = createTaskWithParent(123L, parentTaskId); + PlainActionFuture future = new PlainActionFuture<>(); + + action.doExecute(task, request, future); + + assertTrue("Request handler should be called", latch.await(10, TimeUnit.SECONDS)); + assertThat(capturedParentTaskId.get(), 
equalTo(parentTaskId)); + } + + public void testDoExecuteWithHeaders() throws Exception { + AtomicReference capturedHeader = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + ThreadContext threadContext = threadPool.getThreadContext(); + capturedHeader.set(threadContext.getHeader("X-Test-Header")); + latch.countDown(); + FetchSearchResult result = createFetchSearchResult(); + try { + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Map.of("X-Test-Header", "test-value", "X-Another-Header", "another-value") + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + + assertTrue("Request handler should be called", latch.await(10, TimeUnit.SECONDS)); + assertThat(capturedHeader.get(), equalTo("test-value")); + } + + public void testDoExecuteReturnsResponseOnSuccess() { + FetchSearchResult expectedResult = createFetchSearchResult(); + + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + channel.sendResponse(expectedResult); + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + TransportFetchPhaseCoordinationAction.Response response = future.actionGet(10, TimeUnit.SECONDS); + + try { + assertThat(response, 
notNullValue()); + assertThat(response.getResult(), notNullValue()); + assertEquals(response.getResult().getContextId(), expectedResult.getContextId()); + } finally { + expectedResult.decRef(); + } + } + + public void testDoExecuteHandlesFailure() { + RuntimeException expectedException = new RuntimeException("Test failure"); + + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + channel.sendResponse(expectedException); + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + Exception caughtException = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + assertThat(caughtException.getMessage(), equalTo("Test failure")); + } + + public void testDoExecuteProcessesLastChunkInResponse() { + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + FetchSearchResult result = createFetchSearchResult(); + try { + + BytesStreamOutput out = new BytesStreamOutput(); + SearchHit hit = createHit(0); + hit.writeTo(out); + hit.decRef(); + + result.setLastChunkBytes(out.bytes(), 1); + result.setLastChunkSequenceStart(0L); + + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + 
TransportFetchPhaseCoordinationAction.Response response = future.actionGet(10, TimeUnit.SECONDS); + + assertThat(response, notNullValue()); + assertThat(response.getResult(), notNullValue()); + } + + public void testDoExecuteIgnoresLastChunkBytesWhenHitCountIsZero() { + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + FetchSearchResult result = createFetchSearchResult(); + try { + // Bytes are intentionally not a serialized SearchHit payload. + // They must be ignored when hitCount == 0. + result.setLastChunkBytes(new BytesArray(new byte[] { 1, 2, 3, 4 }), 0); + result.setLastChunkSequenceStart(0L); + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + TransportFetchPhaseCoordinationAction.Response response = future.actionGet(10, TimeUnit.SECONDS); + + assertThat(response, notNullValue()); + assertThat(response.getResult(), notNullValue()); + } + + public void testDoExecuteReleasesRegistrationOnLastChunkDeserializationFailure() throws Exception { + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + FetchSearchResult result = createFetchSearchResult(); + try { + // Invalid payload for hitCount=1, forcing SearchHit.readFrom() failure. 
+ result.setLastChunkBytes(new BytesArray(new byte[] { 9, 9, 9 }), 1); + result.setLastChunkSequenceStart(0L); + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + long taskId = 456L; + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(taskId), request, future); + expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + + assertBusy(() -> { + expectThrows(ResourceNotFoundException.class, () -> activeFetchPhaseTasks.acquireResponseStream(taskId, TEST_SHARD_ID)); + }); + } + + public void testDoExecutePreservesContextIdInFinalResult() throws Exception { + ShardSearchContextId expectedContextId = new ShardSearchContextId("expected-session", 12345L); + SearchShardTarget expectedShardTarget = new SearchShardTarget("node1", TEST_SHARD_ID, null); + + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + FetchSearchResult result = new FetchSearchResult(expectedContextId, expectedShardTarget); + try { + channel.sendResponse(result); + } finally { + result.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(123L), request, future); + TransportFetchPhaseCoordinationAction.Response response = future.actionGet(10, TimeUnit.SECONDS); + + assertThat(response.getResult().getContextId().getId(), equalTo(expectedContextId.getId())); + assertThat(response.getResult().getSearchShardTarget(), 
equalTo(expectedShardTarget)); + } + + public void testDoExecuteReleasesRegistrationWhenDataNodeFailsAfterChunkStreaming() throws Exception { + transportService.registerRequestHandler( + FETCH_ID_ACTION_NAME, + threadPool.executor(ThreadPool.Names.GENERIC), + ShardFetchSearchRequest::new, + (req, channel, task) -> { + SearchHit streamedHit = createHit(123); + FetchPhaseResponseChunk streamedChunk = null; + ReleasableBytesReference wireBytes = null; + try { + streamedChunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(streamedHit), 1, req.docIds().length, 0L); + wireBytes = streamedChunk.toReleasableBytesReference(req.getCoordinatingTaskId()); + + PlainActionFuture ackFuture = new PlainActionFuture<>(); + transportService.sendRequest( + req.getCoordinatingNode(), + TransportFetchPhaseResponseChunkAction.ZERO_COPY_ACTION_NAME, + new BytesTransportRequest(wireBytes, TransportVersion.current()), + new ActionListenerResponseHandler<>( + ackFuture, + in -> ActionResponse.Empty.INSTANCE, + TransportResponseHandler.TRANSPORT_WORKER + ) + ); + ackFuture.actionGet(10, TimeUnit.SECONDS); + + channel.sendResponse(new RuntimeException("simulated data node failure during chunk streaming")); + } finally { + if (wireBytes != null) { + wireBytes.decRef(); + } + if (streamedChunk != null) { + streamedChunk.close(); + } + streamedHit.decRef(); + } + } + ); + + TransportFetchPhaseCoordinationAction.Request request = new TransportFetchPhaseCoordinationAction.Request( + createShardFetchSearchRequest(), + transportService.getLocalNode(), + Collections.emptyMap() + ); + + long taskId = 789L; + PlainActionFuture future = new PlainActionFuture<>(); + action.doExecute(createTask(taskId), request, future); + + Exception failure = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + assertThat(failure.getMessage(), equalTo("simulated data node failure during chunk streaming")); + + assertBusy(() -> { + expectThrows(ResourceNotFoundException.class, () -> 
activeFetchPhaseTasks.acquireResponseStream(taskId, TEST_SHARD_ID)); + }); + } + + private ShardFetchSearchRequest createShardFetchSearchRequest() { + ShardSearchContextId contextId = new ShardSearchContextId("test", randomLong()); + + OriginalIndices originalIndices = new OriginalIndices( + new String[] { "test-index" }, + IndicesOptions.strictExpandOpenAndForbidClosed() + ); + List docIds = List.of(0, 1, 2, 3, 4); + + ShardSearchRequest shardSearchRequest = new ShardSearchRequest(TEST_SHARD_ID, System.currentTimeMillis(), AliasFilter.EMPTY); + + return new ShardFetchSearchRequest(originalIndices, contextId, shardSearchRequest, docIds, null, null, RescoreDocIds.EMPTY, null); + } + + private FetchSearchResult createFetchSearchResult() { + ShardSearchContextId contextId = new ShardSearchContextId("test", randomLong()); + FetchSearchResult result = new FetchSearchResult(contextId, new SearchShardTarget("node", TEST_SHARD_ID, null)); + result.shardResult(SearchHits.unpooled(new SearchHit[0], null, Float.NaN), null); + return result; + } + + private SearchHit createHit(int id) { + SearchHit hit = new SearchHit(id); + hit.sourceRef(new BytesArray("{\"id\":" + id + "}")); + return hit; + } + + private BytesReference serializeHits(SearchHit... 
hits) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + for (SearchHit hit : hits) { + hit.writeTo(out); + } + return out.bytes(); + } + } + + private Task createTask(long taskId) { + return new Task(taskId, "transport", "action", "description", TaskId.EMPTY_TASK_ID, Collections.emptyMap()); + } + + private Task createTaskWithParent(long taskId, TaskId parentTaskId) { + return new Task(taskId, "transport", "action", "description", parentTaskId, Collections.emptyMap()); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkActionTests.java b/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkActionTests.java new file mode 100644 index 0000000000000..c7667abc4e44f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/fetch/chunk/TransportFetchPhaseResponseChunkActionTests.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.fetch.chunk; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.VersionInformation; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.BytesTransportRequest; +import org.elasticsearch.transport.TransportResponseHandler; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class TransportFetchPhaseResponseChunkActionTests extends ESTestCase { + + private static final ShardId TEST_SHARD_ID = new ShardId(new Index("test-index", "test-uuid"), 0); + + private ThreadPool threadPool; 
+ private MockTransportService transportService; + private ActiveFetchPhaseTasks activeFetchPhaseTasks; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + transportService = MockTransportService.createNewService( + Settings.EMPTY, + VersionInformation.CURRENT, + TransportFetchPhaseCoordinationAction.CHUNKED_FETCH_PHASE, + threadPool + ); + transportService.start(); + transportService.acceptIncomingRequests(); + + activeFetchPhaseTasks = new ActiveFetchPhaseTasks(); + new TransportFetchPhaseResponseChunkAction( + transportService, + activeFetchPhaseTasks, + new NamedWriteableRegistry(Collections.emptyList()) + ); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + if (transportService != null) { + transportService.close(); + } + if (threadPool != null) { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + + public void testProcessChunkWhenWriteChunkThrowsSendsErrorAndReleasesChunk() throws Exception { + final long coordinatingTaskId = 123L; + AtomicReference processedChunk = new AtomicReference<>(); + + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 1, new NoopCircuitBreaker("test")) { + @Override + void writeChunk(FetchPhaseResponseChunk chunk, Releasable releasable) { + processedChunk.set(chunk); + try { + chunk.getHits(); + } catch (IOException e) { + throw new RuntimeException(e); + } + throw new IllegalStateException("simulated writeChunk failure"); + } + }; + + Releasable registration = activeFetchPhaseTasks.registerResponseBuilder(coordinatingTaskId, TEST_SHARD_ID, stream); + SearchHit originalHit = createHit(7); + FetchPhaseResponseChunk chunk = null; + try { + chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(originalHit), 1, 1, 0L); + + ReleasableBytesReference wireBytes = chunk.toReleasableBytesReference(coordinatingTaskId); + PlainActionFuture future = new PlainActionFuture<>(); + + 
transportService.sendRequest( + transportService.getLocalNode(), + TransportFetchPhaseResponseChunkAction.ZERO_COPY_ACTION_NAME, + new BytesTransportRequest(wireBytes, TransportVersion.current()), + new ActionListenerResponseHandler<>(future, in -> ActionResponse.Empty.INSTANCE, TransportResponseHandler.TRANSPORT_WORKER) + ); + + Exception e = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + assertThat(e.getMessage(), containsString("simulated writeChunk failure")); + + assertBusy(() -> { + FetchPhaseResponseChunk seen = processedChunk.get(); + assertNotNull("Chunk should have been processed before failure", seen); + assertEquals("Chunk should be closed on failure", 0L, seen.getBytesLength()); + }); + } finally { + if (chunk != null) { + chunk.close(); + } + registration.close(); + stream.decRef(); + originalHit.decRef(); + } + } + + public void testProcessChunkSuccessWritesChunkAndReturnsAck() throws Exception { + final long coordinatingTaskId = 321L; + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 1, new NoopCircuitBreaker("test")); + Releasable registration = activeFetchPhaseTasks.registerResponseBuilder(coordinatingTaskId, TEST_SHARD_ID, stream); + SearchHit originalHit = createHit(9); + FetchPhaseResponseChunk chunk = null; + ReleasableBytesReference wireBytes = null; + try { + chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(originalHit), 1, 1, 0L); + wireBytes = chunk.toReleasableBytesReference(coordinatingTaskId); + + PlainActionFuture future = sendChunk(wireBytes); + assertSame(ActionResponse.Empty.INSTANCE, future.actionGet(10, TimeUnit.SECONDS)); + + FetchSearchResult finalResult = stream.buildFinalResult( + new ShardSearchContextId("ctx", 1L), + new SearchShardTarget("node-0", TEST_SHARD_ID, null), + null + ); + try { + SearchHit[] hits = finalResult.hits().getHits(); + assertThat(hits.length, equalTo(1)); + assertThat(hits[0].getSourceRef().utf8ToString(), containsString("\"id\":9")); + 
} finally { + finalResult.decRef(); + } + } finally { + if (wireBytes != null) { + wireBytes.decRef(); + } + if (chunk != null) { + chunk.close(); + } + registration.close(); + stream.decRef(); + originalHit.decRef(); + } + } + + public void testProcessChunkForUnknownTaskReturnsResourceNotFound() throws Exception { + final long unknownTaskId = randomLongBetween(10_000L, 20_000L); + SearchHit originalHit = createHit(1); + FetchPhaseResponseChunk chunk = null; + ReleasableBytesReference wireBytes = null; + try { + chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(originalHit), 1, 1, 0L); + wireBytes = chunk.toReleasableBytesReference(unknownTaskId); + + PlainActionFuture future = sendChunk(wireBytes); + Exception e = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + assertThat(e.getMessage(), containsString("fetch task [" + unknownTaskId + "] not found")); + } finally { + if (wireBytes != null) { + wireBytes.decRef(); + } + if (chunk != null) { + chunk.close(); + } + originalHit.decRef(); + } + } + + public void testProcessChunkForLateChunkReturnsResourceNotFound() throws Exception { + final long coordinatingTaskId = 777L; + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 1, new NoopCircuitBreaker("test")); + Releasable registration = activeFetchPhaseTasks.registerResponseBuilder(coordinatingTaskId, TEST_SHARD_ID, stream); + + registration.close(); + stream.decRef(); + + SearchHit originalHit = createHit(3); + FetchPhaseResponseChunk chunk = null; + ReleasableBytesReference wireBytes = null; + try { + chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(originalHit), 1, 1, 0L); + wireBytes = chunk.toReleasableBytesReference(coordinatingTaskId); + + PlainActionFuture future = sendChunk(wireBytes); + Exception e = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS)); + assertThat(e.getMessage(), containsString("fetch task [" + coordinatingTaskId + "] not found")); + } 
finally { + if (wireBytes != null) { + wireBytes.decRef(); + } + if (chunk != null) { + chunk.close(); + } + originalHit.decRef(); + } + } + + public void testProcessChunkTracksAndReleasesCircuitBreakerBytes() throws Exception { + final long coordinatingTaskId = 222L; + var breaker = newLimitedBreaker(ByteSizeValue.ofBytes(Long.MAX_VALUE)); + FetchPhaseResponseStream stream = new FetchPhaseResponseStream(0, 1, breaker); + Releasable registration = activeFetchPhaseTasks.registerResponseBuilder(coordinatingTaskId, TEST_SHARD_ID, stream); + SearchHit originalHit = createHit(12); + FetchPhaseResponseChunk chunk = null; + ReleasableBytesReference wireBytes = null; + try { + chunk = new FetchPhaseResponseChunk(TEST_SHARD_ID, serializeHits(originalHit), 1, 1, 0L); + long expectedBytes = chunk.getBytesLength(); + wireBytes = chunk.toReleasableBytesReference(coordinatingTaskId); + + PlainActionFuture future = sendChunk(wireBytes); + future.actionGet(10, TimeUnit.SECONDS); + assertThat(breaker.getUsed(), equalTo(expectedBytes)); + } finally { + if (wireBytes != null) { + wireBytes.decRef(); + } + if (chunk != null) { + chunk.close(); + } + registration.close(); + stream.decRef(); + originalHit.decRef(); + } + + assertThat("breaker bytes should be released when stream is closed", breaker.getUsed(), equalTo(0L)); + } + + private SearchHit createHit(int id) { + SearchHit hit = new SearchHit(id); + hit.sourceRef(new BytesArray("{\"id\":" + id + "}")); + return hit; + } + + private PlainActionFuture sendChunk(ReleasableBytesReference wireBytes) { + PlainActionFuture future = new PlainActionFuture<>(); + transportService.sendRequest( + transportService.getLocalNode(), + TransportFetchPhaseResponseChunkAction.ZERO_COPY_ACTION_NAME, + new BytesTransportRequest(wireBytes, TransportVersion.current()), + new ActionListenerResponseHandler<>(future, in -> ActionResponse.Empty.INSTANCE, TransportResponseHandler.TRANSPORT_WORKER) + ); + return future; + } + + private BytesReference 
serializeHits(SearchHit... hits) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + for (SearchHit hit : hits) { + hit.writeTo(out); + } + return out.bytes(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java index 06ebfba2ff896..472761bdd3add 100644 --- a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java +++ b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java @@ -706,8 +706,6 @@ public void testMaxClause() throws Exception { terminate(executor); } } - } finally { - terminate(executor); } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTestHelper.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTestHelper.java index 47c14418599a2..83098d7099c7d 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTestHelper.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTestHelper.java @@ -150,6 +150,9 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.crossproject.CrossProjectModeDecider; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.fetch.chunk.ActiveFetchPhaseTasks; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseCoordinationAction; +import org.elasticsearch.search.fetch.chunk.TransportFetchPhaseResponseChunkAction; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.test.client.NoOpClient; @@ -694,6 +697,8 @@ public RecyclerBytesStreamOutput newNetworkBytesStream(@Nullable CircuitBreaker ); final ActionFilters actionFilters = new ActionFilters(emptySet()); + final ActiveFetchPhaseTasks activeFetchPhaseTasks = new ActiveFetchPhaseTasks(); + new 
TransportFetchPhaseResponseChunkAction(transportService, activeFetchPhaseTasks, namedWriteableRegistry); Map, TransportAction> actions = new HashMap<>(); // Inject initialization from subclass which may be needed by initializations after this point. @@ -760,6 +765,7 @@ public RecyclerBytesStreamOutput newNetworkBytesStream(@Nullable CircuitBreaker client, SearchExecutionStatsCollector.makeWrapper(responseCollectorService) ); + searchTransportService.setSearchService(searchService); indicesClusterStateService = new IndicesClusterStateService( settings, @@ -951,6 +957,16 @@ public boolean clusterHasFeature(ClusterState state, NodeFeature feature) { CrossProjectModeDecider.NOOP ) ); + actions.put( + TransportFetchPhaseCoordinationAction.TYPE, + new TransportFetchPhaseCoordinationAction( + transportService, + actionFilters, + activeFetchPhaseTasks, + new NoneCircuitBreakerService(), + namedWriteableRegistry + ) + ); actions.put( TransportRestoreSnapshotAction.TYPE, new TransportRestoreSnapshotAction( diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 5ee4f6f0c2640..29f66ddceadc5 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1040,6 +1040,7 @@ public void run() { } public void testSuccessfulSnapshotWithConcurrentDynamicMappingUpdates() { + setupTestCluster(randomFrom(1, 3, 5), randomIntBetween(2, 10)); String repoName = "repo"; diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java index 5ecb2f24acb32..40ed521fd6f49 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java +++ 
b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java @@ -280,11 +280,16 @@ protected List initSearchShardBlockingPlugin() { public static class SearchShardBlockingPlugin extends Plugin { private final AtomicReference> runOnPreQueryPhase = new AtomicReference<>(); + private final AtomicReference> runOnPreFetchPhase = new AtomicReference<>(); public void setRunOnPreQueryPhase(Consumer consumer) { runOnPreQueryPhase.set(consumer); } + public void setRunOnPreFetchPhase(Consumer consumer) { + runOnPreFetchPhase.set(consumer); + } + @Override public void onIndexModule(IndexModule indexModule) { super.onIndexModule(indexModule); @@ -295,6 +300,13 @@ public void onPreQueryPhase(SearchContext c) { runOnPreQueryPhase.get().accept(c); } } + + @Override + public void onPreFetchPhase(SearchContext c) { + if (runOnPreFetchPhase.get() != null) { + runOnPreFetchPhase.get().accept(c); + } + } }); } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 7063865a68133..17d5302ec0116 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -668,6 +668,7 @@ public class Constants { "internal:cluster/coordination_diagnostics/info", "internal:cluster/formation/info", "internal:cluster/snapshot/update_snapshot_status", + "internal:data/read/search/fetch/coordination", "internal:gateway/local/started_shards", "internal:admin/indices/prevalidate_shard_path", "internal:index/metadata/migration_version/update",