@@ -11,11 +11,12 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.ParsedScrollId;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.bytes.BytesReference;
@@ -28,6 +29,7 @@
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESIntegTestCase;
@@ -703,13 +705,15 @@ public void testRestartDataNodesDuringScrollSearch() throws Exception {
} finally {
respFromProdIndex.decRef();
}
-SearchPhaseExecutionException error = expectThrows(
-SearchPhaseExecutionException.class,
-client().prepareSearchScroll(respFromDemoIndexScrollId)
-);
+SearchScrollRequestBuilder searchScrollRequestBuilder = client().prepareSearchScroll(respFromDemoIndexScrollId);
+SearchPhaseExecutionException error = expectThrows(SearchPhaseExecutionException.class, searchScrollRequestBuilder);
assertEquals(1, error.shardFailures().length);
-for (ShardSearchFailure shardSearchFailure : error.shardFailures()) {
-assertThat(shardSearchFailure.getCause().getMessage(), containsString("No search context found for id [1]"));
-}
+ParsedScrollId parsedScrollId = searchScrollRequestBuilder.request().parseScrollId();
+ShardSearchContextId shardSearchContextId = parsedScrollId.getContext()[0].getSearchContextId();
+assertThat(
+error.shardFailures()[0].getCause().getMessage(),
+containsString("No search context found for id [" + shardSearchContextId + "]")
+);
client().prepareSearchScroll(respFromProdIndexScrollId).get().decRef();
}

AbstractSearchAsyncAction.java
@@ -13,6 +13,7 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.OriginalIndices;
@@ -31,6 +32,7 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -39,8 +41,10 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
@@ -93,6 +97,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final Map<String, PendingExecutions> pendingExecutionsPerNode;
private final AtomicBoolean requestCancelled = new AtomicBoolean();
private final int skippedCount;
private final TransportVersion minTransportVersion;

// protected for tests
protected final SubscribableListener<Void> doneFuture = new SubscribableListener<>();
@@ -149,6 +154,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.nodeIdToConnection = nodeIdToConnection;
this.concreteIndexBoosts = concreteIndexBoosts;
this.clusterStateVersion = clusterState.version();
this.minTransportVersion = clusterState.getMinTransportVersion();
this.aliasFilter = aliasFilter;
this.results = resultConsumer;
// register the release of the query consumer to free up the circuit breaker memory
@@ -416,6 +422,7 @@ protected final void onShardFailure(final int shardIndex, SearchShardTarget shar
onShardGroupFailure(shardIndex, shard, e);
}
if (lastShard == false) {
logger.debug("Retrying shard [{}] with target [{}]", shard.getShardId(), nextShard);
performPhaseOnShard(shardIndex, shardIt, nextShard);
} else {
// count down outstanding shards, we're done with this shard as there's no more copies to try
@@ -607,10 +614,77 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
}

protected BytesReference buildSearchContextId(ShardSearchFailure[] failures) {
-var source = request.source();
-return source != null && source.pointInTimeBuilder() != null && source.pointInTimeBuilder().singleSession() == false
-? source.pointInTimeBuilder().getEncodedId()
-: null;
+SearchSourceBuilder source = request.source();
+// only (re-)build a search context id if we have a point in time
+if (source != null && source.pointInTimeBuilder() != null && source.pointInTimeBuilder().singleSession() == false) {
+// we want to change node ids in the PIT id if any shard and its PIT context have moved
+BytesReference bytesReference = maybeReEncodeNodeIds(
+source.pointInTimeBuilder(),
+results.getAtomicArray().asList(),
+failures,
+namedWriteableRegistry,
+minTransportVersion
+);
+// reference comparison is intentional: maybeReEncodeNodeIds returns the original encoded id instance when nothing changed
+if ((bytesReference == source.pointInTimeBuilder().getEncodedId()) == false) {
+logger.info(
+"Changing PIT to: [{}]",
+new PointInTimeBuilder(bytesReference).getSearchContextId(namedWriteableRegistry).toString().replace("},", "\n")
+);
+}
+return bytesReference;
+} else {
+return null;
+}
}

static <Result extends SearchPhaseResult> BytesReference maybeReEncodeNodeIds(
PointInTimeBuilder originalPit,
List<Result> results,
ShardSearchFailure[] failures,
NamedWriteableRegistry namedWriteableRegistry,
TransportVersion minTransportVersion
) {
SearchContextId original = originalPit.getSearchContextId(namedWriteableRegistry);
boolean idChanged = false;
Map<ShardId, SearchContextIdForNode> updatedShardMap = null; // only create this if we detect a change
for (Result result : results) {
SearchShardTarget searchShardTarget = result.getSearchShardTarget();
ShardId shardId = searchShardTarget.getShardId();
SearchContextIdForNode originalShard = original.shards().get(shardId);
if (originalShard != null
&& Objects.equals(originalShard.getClusterAlias(), searchShardTarget.getClusterAlias())
&& Objects.equals(originalShard.getSearchContextId(), result.getContextId())) {
// result shard and context id match the originalShard one, check if the node is different and replace if so
String originalNode = originalShard.getNode();
if (originalNode != null && originalNode.equals(searchShardTarget.getNodeId()) == false) {
// the target node for this shard entry in the PIT has changed, we need to update it
idChanged = true;
if (updatedShardMap == null) {
updatedShardMap = new HashMap<>(original.shards().size());
}
updatedShardMap.put(
shardId,
new SearchContextIdForNode(
originalShard.getClusterAlias(),
searchShardTarget.getNodeId(),
originalShard.getSearchContextId()
)
);
}
}
}
if (idChanged) {
// we also need to add shards that are not in the results for some reason (e.g. the query rewrote to match none) but that
// were part of the original PIT
for (ShardId shardId : original.shards().keySet()) {
if (updatedShardMap.containsKey(shardId) == false) {
updatedShardMap.put(shardId, original.shards().get(shardId));
}
}
return SearchContextId.encode(updatedShardMap, original.aliasFilter(), minTransportVersion, failures);
} else {
return originalPit.getEncodedId();
}
}
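
In essence, the re-encoding decodes the PIT, swaps the node id in each moved per-shard entry, and re-encodes with the original alias filters and the minimum transport version. A minimal sketch of that round trip, for illustration only (pit, shardId, newNodeId, namedWriteableRegistry and minTransportVersion are assumed to be in scope; this is not code from the change):

    // decode the PIT id, point the moved shard at its new node, and re-encode
    SearchContextId original = pit.getSearchContextId(namedWriteableRegistry);
    Map<ShardId, SearchContextIdForNode> updated = new HashMap<>(original.shards());
    SearchContextIdForNode moved = updated.get(shardId);
    updated.put(shardId, new SearchContextIdForNode(moved.getClusterAlias(), newNodeId, moved.getSearchContextId()));
    BytesReference reEncoded = SearchContextId.encode(updated, original.aliasFilter(), minTransportVersion, ShardSearchFailure.EMPTY_ARRAY);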

ClearScrollController.java
@@ -8,20 +8,27 @@
*/
package org.elasticsearch.action.search;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportResponse;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@@ -38,6 +45,7 @@ public final class ClearScrollController implements Runnable {
private final AtomicBoolean hasFailed = new AtomicBoolean(false);
private final AtomicInteger freedSearchContexts = new AtomicInteger(0);
private final Logger logger;
private static final Logger staticLogger = LogManager.getLogger(ClearScrollController.class);
private final Runnable runner;

ClearScrollController(
@@ -148,12 +156,15 @@ private void finish() {
* Closes the given context id and reports the number of freed contexts via the listener
*/
public static void closeContexts(
-DiscoveryNodes nodes,
+ClusterService clusterService,
+ProjectResolver projectResolver,
SearchTransportService searchTransportService,
-Collection<SearchContextIdForNode> contextIds,
+Map<ShardId, SearchContextIdForNode> shards,
ActionListener<Integer> listener
) {
-final Set<String> clusters = contextIds.stream()
+DiscoveryNodes nodes = clusterService.state().nodes();
+final Set<String> clusters = shards.values()
+.stream()
.map(SearchContextIdForNode::getClusterAlias)
.filter(clusterAlias -> Strings.isEmpty(clusterAlias) == false)
.collect(Collectors.toSet());
@@ -166,16 +177,34 @@ public static void closeContexts(
lookupListener.addListener(listener.delegateFailure((l, nodeLookup) -> {
final var successes = new AtomicInteger();
try (RefCountingRunnable refs = new RefCountingRunnable(() -> l.onResponse(successes.get()))) {
-for (SearchContextIdForNode contextId : contextIds) {
+for (Entry<ShardId, SearchContextIdForNode> entry : shards.entrySet()) {
+var contextId = entry.getValue();
if (contextId.getNode() == null) {
// the shard was missing when creating the PIT, ignore.
continue;
}
final DiscoveryNode node = nodeLookup.apply(contextId.getClusterAlias(), contextId.getNode());

Set<DiscoveryNode> targetNodes;
if (node != null) {
targetNodes = Collections.singleton(node);
} else {
staticLogger.info("---> missing node when closing context: " + contextId.getNode());
Member Author: note: We need to close the contexts after moving them when the "old" PIT is used, so if the originally encoded node is gone, we try all the remaining nodes that currently hold that shard (regardless of whether those nodes also hold a PIT context).

// TODO we won't be able to use this with remote clusters
IndexShardRoutingTable indexShardRoutingTable = clusterService.state()
.routingTable(projectResolver.getProjectId())
.shardRoutingTable(entry.getKey());
targetNodes = indexShardRoutingTable.assignedUnpromotableShards()
.stream()
.map(ShardRouting::currentNodeId)
.map(nodeId -> nodeLookup.apply(contextId.getClusterAlias(), nodeId))
.collect(Collectors.toSet());
staticLogger.info("---> trying alternative nodes to close context: " + targetNodes);
}
for (DiscoveryNode targetNode : targetNodes) {
try {
searchTransportService.sendFreeContext(
-searchTransportService.getConnection(contextId.getClusterAlias(), node),
+searchTransportService.getConnection(contextId.getClusterAlias(), targetNode),
contextId.getSearchContextId(),
refs.acquireListener().map(r -> {
if (r.isFreed()) {
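
With the new signature, call sites no longer pass resolved DiscoveryNodes and a flat collection of contexts but the ClusterService, ProjectResolver and the full shard map, so unreachable nodes can be resolved against the current routing table. Roughly, as a sketch with the surrounding variables assumed:

    // before: closeContexts(clusterState.nodes(), searchTransportService, searchContextId.shards().values(), listener)
    ClearScrollController.closeContexts(clusterService, projectResolver, searchTransportService, searchContextId.shards(), listener);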
PITHelper.java
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.action.search;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.shard.ShardId;

import java.io.IOException;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;

public class PITHelper {
Member Author: just for debugging atm


public static SearchContextId decodePITId(String id) {
return decodePITId(new BytesArray(Base64.getUrlDecoder().decode(id)));
}

public static SearchContextId decodePITId(BytesReference id) {
try (var in = id.streamInput()) {
final TransportVersion version = TransportVersion.readVersion(in);
in.setTransportVersion(version);
final Map<ShardId, SearchContextIdForNode> shards = Collections.unmodifiableMap(
in.readCollection(Maps::newHashMapWithExpectedSize, (i, map) -> map.put(new ShardId(in), new SearchContextIdForNode(in)))
);
return new SearchContextId(shards, Collections.emptyMap());
} catch (IOException e) {
assert false : e;
throw new IllegalArgumentException(e);
}
}
}
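
A hypothetical debugging use of the helper, dumping the per-shard entries encoded in a PIT id (assuming the id is taken from an open-PIT response as a BytesReference):

    // decode a PIT id and print where each shard's context lives
    SearchContextId decoded = PITHelper.decodePITId(openPointInTimeResponse.getPointInTimeId());
    decoded.shards().forEach((shardId, ctx) ->
        System.out.println(shardId + " -> node [" + ctx.getNode() + "], context [" + ctx.getSearchContextId() + "]"));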
SearchContextId.java
@@ -21,7 +21,6 @@
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.transport.RemoteClusterAware;
@@ -30,6 +29,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
@@ -62,6 +62,26 @@ public static BytesReference encode(
Map<String, AliasFilter> aliasFilter,
TransportVersion version,
ShardSearchFailure[] shardFailures
) {
Map<ShardId, SearchContextIdForNode> shards = searchPhaseResults.stream()
.collect(
Collectors.toMap(
r -> r.getSearchShardTarget().getShardId(),
r -> new SearchContextIdForNode(
r.getSearchShardTarget().getClusterAlias(),
r.getSearchShardTarget().getNodeId(),
r.getContextId()
)
)
);
return encode(shards, aliasFilter, version, shardFailures);
}

static BytesReference encode(
Map<ShardId, SearchContextIdForNode> shards,
Map<String, AliasFilter> aliasFilter,
TransportVersion version,
ShardSearchFailure[] shardFailures
) {
assert shardFailures.length == 0 || version.onOrAfter(TransportVersions.V_8_16_0)
: "[allow_partial_search_results] cannot be enabled on a cluster that has not been fully upgraded to version ["
@@ -71,12 +91,12 @@ public static BytesReference encode(
out.setTransportVersion(version);
TransportVersion.writeVersion(version, out);
boolean allowNullContextId = out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0);
-int shardSize = searchPhaseResults.size() + (allowNullContextId ? shardFailures.length : 0);
+int shardSize = shards.size() + (allowNullContextId ? shardFailures.length : 0);
out.writeVInt(shardSize);
-for (var searchResult : searchPhaseResults) {
-final SearchShardTarget target = searchResult.getSearchShardTarget();
-target.getShardId().writeTo(out);
-new SearchContextIdForNode(target.getClusterAlias(), target.getNodeId(), searchResult.getContextId()).writeTo(out);
+for (ShardId shardId : shards.keySet()) {
+shardId.writeTo(out);
+SearchContextIdForNode searchContextIdForNode = shards.get(shardId);
+searchContextIdForNode.writeTo(out);
}
if (allowNullContextId) {
for (var failure : shardFailures) {
@@ -142,4 +162,23 @@ public String[] getActualIndices() {
}
return indices.toArray(String[]::new);
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
SearchContextId that = (SearchContextId) o;
return Objects.equals(shards, that.shards)
&& Objects.equals(aliasFilter, that.aliasFilter)
&& Objects.equals(contextIds, that.contextIds);
}

@Override
public int hashCode() {
return Objects.hash(shards, aliasFilter, contextIds);
}

@Override
public String toString() {
return "SearchContextId{" + "shards=" + shards + ", aliasFilter=" + aliasFilter + '}';
}
}
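
The added equals/hashCode presumably make round-trip assertions straightforward in tests; a sketch under that assumption (version and encodedId are placeholders, and PITHelper is the debugging helper added in this PR):

    // an encode → decode round trip should preserve the shard map
    SearchContextId before = PITHelper.decodePITId(encodedId);
    BytesReference reEncoded = SearchContextId.encode(before.shards(), Map.of(), version, ShardSearchFailure.EMPTY_ARRAY);
    assertEquals(before, PITHelper.decodePITId(reEncoded));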