Skip to content

Commit a174c6d

Browse files
committed
Improve PIT context relocation
1 parent a15c735 commit a174c6d

File tree

13 files changed

+382
-52
lines changed

13 files changed

+382
-52
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/scroll/SearchScrollIT.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
import org.elasticsearch.ExceptionsHelper;
1313
import org.elasticsearch.action.search.ClearScrollResponse;
14+
import org.elasticsearch.action.search.ParsedScrollId;
1415
import org.elasticsearch.action.search.SearchPhaseExecutionException;
1516
import org.elasticsearch.action.search.SearchRequestBuilder;
1617
import org.elasticsearch.action.search.SearchResponse;
18+
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
1719
import org.elasticsearch.action.search.SearchType;
18-
import org.elasticsearch.action.search.ShardSearchFailure;
1920
import org.elasticsearch.cluster.metadata.IndexMetadata;
2021
import org.elasticsearch.common.Priority;
2122
import org.elasticsearch.common.bytes.BytesReference;
@@ -28,6 +29,7 @@
2829
import org.elasticsearch.index.query.RangeQueryBuilder;
2930
import org.elasticsearch.rest.RestStatus;
3031
import org.elasticsearch.search.SearchHit;
32+
import org.elasticsearch.search.internal.ShardSearchContextId;
3133
import org.elasticsearch.search.sort.FieldSortBuilder;
3234
import org.elasticsearch.search.sort.SortOrder;
3335
import org.elasticsearch.test.ESIntegTestCase;
@@ -703,13 +705,15 @@ public void testRestartDataNodesDuringScrollSearch() throws Exception {
703705
} finally {
704706
respFromProdIndex.decRef();
705707
}
706-
SearchPhaseExecutionException error = expectThrows(
707-
SearchPhaseExecutionException.class,
708-
client().prepareSearchScroll(respFromDemoIndexScrollId)
708+
SearchScrollRequestBuilder searchScrollRequestBuilder = client().prepareSearchScroll(respFromDemoIndexScrollId);
709+
SearchPhaseExecutionException error = expectThrows(SearchPhaseExecutionException.class, searchScrollRequestBuilder);
710+
assertEquals(1, error.shardFailures().length);
711+
ParsedScrollId parsedScrollId = searchScrollRequestBuilder.request().parseScrollId();
712+
ShardSearchContextId shardSearchContextId = parsedScrollId.getContext()[0].getSearchContextId();
713+
assertThat(
714+
error.shardFailures()[0].getCause().getMessage(),
715+
containsString("No search context found for id [" + shardSearchContextId + "]")
709716
);
710-
for (ShardSearchFailure shardSearchFailure : error.shardFailures()) {
711-
assertThat(shardSearchFailure.getCause().getMessage(), containsString("No search context found for id [1]"));
712-
}
713717
client().prepareSearchScroll(respFromProdIndexScrollId).get().decRef();
714718
}
715719

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.util.SetOnce;
1414
import org.elasticsearch.ElasticsearchException;
1515
import org.elasticsearch.ExceptionsHelper;
16+
import org.elasticsearch.TransportVersion;
1617
import org.elasticsearch.action.ActionListener;
1718
import org.elasticsearch.action.NoShardAvailableActionException;
1819
import org.elasticsearch.action.OriginalIndices;
@@ -31,6 +32,7 @@
3132
import org.elasticsearch.search.SearchPhaseResult;
3233
import org.elasticsearch.search.SearchShardTarget;
3334
import org.elasticsearch.search.builder.PointInTimeBuilder;
35+
import org.elasticsearch.search.builder.SearchSourceBuilder;
3436
import org.elasticsearch.search.internal.AliasFilter;
3537
import org.elasticsearch.search.internal.SearchContext;
3638
import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -39,8 +41,10 @@
3941

4042
import java.util.ArrayList;
4143
import java.util.Arrays;
44+
import java.util.HashMap;
4245
import java.util.List;
4346
import java.util.Map;
47+
import java.util.Objects;
4448
import java.util.concurrent.ConcurrentHashMap;
4549
import java.util.concurrent.ConcurrentLinkedQueue;
4650
import java.util.concurrent.Executor;
@@ -93,6 +97,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9397
private final Map<String, PendingExecutions> pendingExecutionsPerNode;
9498
private final AtomicBoolean requestCancelled = new AtomicBoolean();
9599
private final int skippedCount;
100+
private final TransportVersion mintransportVersion;
96101

97102
// protected for tests
98103
protected final SubscribableListener<Void> doneFuture = new SubscribableListener<>();
@@ -149,6 +154,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
149154
this.nodeIdToConnection = nodeIdToConnection;
150155
this.concreteIndexBoosts = concreteIndexBoosts;
151156
this.clusterStateVersion = clusterState.version();
157+
this.mintransportVersion = clusterState.getMinTransportVersion();
152158
this.aliasFilter = aliasFilter;
153159
this.results = resultConsumer;
154160
// register the release of the query consumer to free up the circuit breaker memory
@@ -416,6 +422,7 @@ protected final void onShardFailure(final int shardIndex, SearchShardTarget shar
416422
onShardGroupFailure(shardIndex, shard, e);
417423
}
418424
if (lastShard == false) {
425+
logger.debug("Retrying shard [{}] with target [{}]", shard.getShardId(), nextShard);
419426
performPhaseOnShard(shardIndex, shardIt, nextShard);
420427
} else {
421428
// count down outstanding shards, we're done with this shard as there's no more copies to try
@@ -607,10 +614,70 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
607614
}
608615

609616
protected BytesReference buildSearchContextId(ShardSearchFailure[] failures) {
610-
var source = request.source();
611-
return source != null && source.pointInTimeBuilder() != null && source.pointInTimeBuilder().singleSession() == false
612-
? source.pointInTimeBuilder().getEncodedId()
613-
: null;
617+
SearchSourceBuilder source = request.source();
618+
// only (re-)build a search context id if we have a point in time
619+
if (source != null && source.pointInTimeBuilder() != null && source.pointInTimeBuilder().singleSession() == false) {
620+
// we want to change node ids in the PIT id if any shards and its PIT context have moved
621+
return maybeReEncodeNodeIds(
622+
source.pointInTimeBuilder(),
623+
results.getAtomicArray().asList(),
624+
failures,
625+
namedWriteableRegistry,
626+
mintransportVersion
627+
);
628+
} else {
629+
return null;
630+
}
631+
}
632+
633+
static <Result extends SearchPhaseResult> BytesReference maybeReEncodeNodeIds(
634+
PointInTimeBuilder originalPit,
635+
List<Result> results,
636+
ShardSearchFailure[] failures,
637+
NamedWriteableRegistry namedWriteableRegistry,
638+
TransportVersion mintransportVersion
639+
) {
640+
SearchContextId original = originalPit.getSearchContextId(namedWriteableRegistry);
641+
boolean idChanged = false;
642+
Map<ShardId, SearchContextIdForNode> updatedShardMap = null; // only create this if we detect a change
643+
for (Result result : results) {
644+
SearchShardTarget searchShardTarget = result.getSearchShardTarget();
645+
ShardId shardId = searchShardTarget.getShardId();
646+
SearchContextIdForNode originalShard = original.shards().get(shardId);
647+
if (originalShard != null
648+
&& Objects.equals(originalShard.getClusterAlias(), searchShardTarget.getClusterAlias())
649+
&& Objects.equals(originalShard.getSearchContextId(), result.getContextId())) {
650+
// the result's shard and context id match the original entry; check whether the node differs and replace it if so
651+
String originalNode = originalShard.getNode();
652+
if (originalNode != null && originalNode.equals(searchShardTarget.getNodeId()) == false) {
653+
// the target node for this shard entry in the PIT has changed, we need to update it
654+
idChanged = true;
655+
if (updatedShardMap == null) {
656+
updatedShardMap = new HashMap<>(original.shards().size());
657+
}
658+
updatedShardMap.put(
659+
shardId,
660+
new SearchContextIdForNode(
661+
originalShard.getClusterAlias(),
662+
searchShardTarget.getNodeId(),
663+
originalShard.getSearchContextId()
664+
)
665+
);
666+
}
667+
}
668+
}
669+
if (idChanged) {
670+
// we also need to add shards that are not in the results for some reason (e.g. query rewrote to match none) but that
671+
// were part of the original PIT
672+
for (ShardId shardId : original.shards().keySet()) {
673+
if (updatedShardMap.containsKey(shardId) == false) {
674+
updatedShardMap.put(shardId, original.shards().get(shardId));
675+
}
676+
}
677+
return SearchContextId.encode(updatedShardMap, original.aliasFilter(), mintransportVersion, failures);
678+
} else {
679+
return originalPit.getEncodedId();
680+
}
614681
}
615682

616683
/**

server/src/main/java/org/elasticsearch/action/search/SearchContextId.java

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.elasticsearch.common.util.Maps;
2222
import org.elasticsearch.index.shard.ShardId;
2323
import org.elasticsearch.search.SearchPhaseResult;
24-
import org.elasticsearch.search.SearchShardTarget;
2524
import org.elasticsearch.search.internal.AliasFilter;
2625
import org.elasticsearch.search.internal.ShardSearchContextId;
2726
import org.elasticsearch.transport.RemoteClusterAware;
@@ -30,6 +29,7 @@
3029
import java.util.Collections;
3130
import java.util.List;
3231
import java.util.Map;
32+
import java.util.Objects;
3333
import java.util.Set;
3434
import java.util.TreeSet;
3535
import java.util.stream.Collectors;
@@ -62,6 +62,26 @@ public static BytesReference encode(
6262
Map<String, AliasFilter> aliasFilter,
6363
TransportVersion version,
6464
ShardSearchFailure[] shardFailures
65+
) {
66+
Map<ShardId, SearchContextIdForNode> shards = searchPhaseResults.stream()
67+
.collect(
68+
Collectors.toMap(
69+
r -> r.getSearchShardTarget().getShardId(),
70+
r -> new SearchContextIdForNode(
71+
r.getSearchShardTarget().getClusterAlias(),
72+
r.getSearchShardTarget().getNodeId(),
73+
r.getContextId()
74+
)
75+
)
76+
);
77+
return encode(shards, aliasFilter, version, shardFailures);
78+
}
79+
80+
static BytesReference encode(
81+
Map<ShardId, SearchContextIdForNode> shards,
82+
Map<String, AliasFilter> aliasFilter,
83+
TransportVersion version,
84+
ShardSearchFailure[] shardFailures
6585
) {
6686
assert shardFailures.length == 0 || version.onOrAfter(TransportVersions.V_8_16_0)
6787
: "[allow_partial_search_results] cannot be enabled on a cluster that has not been fully upgraded to version ["
@@ -71,12 +91,12 @@ public static BytesReference encode(
7191
out.setTransportVersion(version);
7292
TransportVersion.writeVersion(version, out);
7393
boolean allowNullContextId = out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0);
74-
int shardSize = searchPhaseResults.size() + (allowNullContextId ? shardFailures.length : 0);
94+
int shardSize = shards.size() + (allowNullContextId ? shardFailures.length : 0);
7595
out.writeVInt(shardSize);
76-
for (var searchResult : searchPhaseResults) {
77-
final SearchShardTarget target = searchResult.getSearchShardTarget();
78-
target.getShardId().writeTo(out);
79-
new SearchContextIdForNode(target.getClusterAlias(), target.getNodeId(), searchResult.getContextId()).writeTo(out);
96+
for (ShardId shardId : shards.keySet()) {
97+
shardId.writeTo(out);
98+
SearchContextIdForNode searchContextIdForNode = shards.get(shardId);
99+
searchContextIdForNode.writeTo(out);
80100
}
81101
if (allowNullContextId) {
82102
for (var failure : shardFailures) {
@@ -142,4 +162,23 @@ public String[] getActualIndices() {
142162
}
143163
return indices.toArray(String[]::new);
144164
}
165+
166+
@Override
167+
public boolean equals(Object o) {
168+
if (o == null || getClass() != o.getClass()) return false;
169+
SearchContextId that = (SearchContextId) o;
170+
return Objects.equals(shards, that.shards)
171+
&& Objects.equals(aliasFilter, that.aliasFilter)
172+
&& Objects.equals(contextIds, that.contextIds);
173+
}
174+
175+
@Override
176+
public int hashCode() {
177+
return Objects.hash(shards, aliasFilter, contextIds);
178+
}
179+
180+
@Override
181+
public String toString() {
182+
return "SearchContextId{" + "shards=" + shards + ", aliasFilter=" + aliasFilter + '}';
183+
}
145184
}

server/src/main/java/org/elasticsearch/action/search/SearchContextIdForNode.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.search.internal.ShardSearchContextId;
1818

1919
import java.io.IOException;
20+
import java.util.Objects;
2021

2122
public final class SearchContextIdForNode implements Writeable {
2223
private final String node;
@@ -103,4 +104,18 @@ public String toString() {
103104
+ '\''
104105
+ '}';
105106
}
107+
108+
@Override
109+
public boolean equals(Object o) {
110+
if (o == null || getClass() != o.getClass()) return false;
111+
SearchContextIdForNode that = (SearchContextIdForNode) o;
112+
return Objects.equals(node, that.node)
113+
&& Objects.equals(searchContextId, that.searchContextId)
114+
&& Objects.equals(clusterAlias, that.clusterAlias);
115+
}
116+
117+
@Override
118+
public int hashCode() {
119+
return Objects.hash(node, searchContextId, clusterAlias);
120+
}
106121
}

server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ public void writeTo(StreamOutput out) throws IOException {
386386

387387
public static void registerRequestHandler(TransportService transportService, SearchService searchService) {
388388
final TransportRequestHandler<ScrollFreeContextRequest> freeContextHandler = (request, channel, task) -> {
389-
logger.trace("releasing search context [{}]", request.id());
390389
boolean freed = searchService.freeReaderContext(request.id());
390+
logger.trace("releasing search context [{}], [{}]", request.id(), freed);
391391
channel.sendResponse(SearchFreeContextResponse.of(freed));
392392
};
393393
final Executor freeContextExecutor = buildFreeContextExecutor(transportService);

server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,12 @@ static List<SearchShardIterator> getLocalShardsIteratorFromPointInTime(
19461946
// Prefer executing shard requests on nodes that are part of PIT first.
19471947
if (projectState.cluster().nodes().nodeExists(perNode.getNode())) {
19481948
targetNodes.add(perNode.getNode());
1949+
} else {
1950+
logger.debug(
1951+
"Node [{}] referenced in PIT context id [{}] no longer exists.",
1952+
perNode.getNode(),
1953+
perNode.getSearchContextId()
1954+
);
19491955
}
19501956
if (perNode.getSearchContextId().getSearcherId() != null) {
19511957
for (ShardRouting shard : shards) {

server/src/main/java/org/elasticsearch/search/SearchContextMissingException.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public class SearchContextMissingException extends ElasticsearchException {
2222
private final ShardSearchContextId contextId;
2323

2424
public SearchContextMissingException(ShardSearchContextId contextId) {
25-
super("No search context found for id [" + contextId.getId() + "]");
25+
super("No search context found for id [" + contextId + "]");
2626
this.contextId = contextId;
2727
}
2828

server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
*/
3131
public abstract class SearchPhaseResult extends TransportResponse {
3232

33-
private SearchShardTarget searchShardTarget;
33+
protected SearchShardTarget searchShardTarget;
3434
private int shardIndex = -1;
3535
protected ShardSearchContextId contextId;
3636
private ShardSearchRequest shardSearchRequest;

0 commit comments

Comments
 (0)