Merged
@@ -90,15 +90,13 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
     private final Object shardFailuresMutex = new Object();
     private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
     private final AtomicInteger successfulOps = new AtomicInteger();
-    private final AtomicInteger skippedOps = new AtomicInteger();
     private final SearchTimeProvider timeProvider;
     private final SearchResponse.Clusters clusters;
 
     protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
     protected final GroupShardsIterator<SearchShardIterator> shardsIts;
     private final SearchShardIterator[] shardIterators;
-    private final int expectedTotalOps;
-    private final AtomicInteger totalOps = new AtomicInteger();
+    private final AtomicInteger outstandingShards;
     private final int maxConcurrentRequestsPerNode;
     private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
     private final boolean throttleConcurrentRequests;
@@ -139,18 +137,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
         }
         this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
         this.shardsIts = new GroupShardsIterator<>(iterators);
-
+        outstandingShards = new AtomicInteger(shardsIts.size());
         this.shardIterators = iterators.toArray(new SearchShardIterator[0]);
         // we later compute the shard index based on the natural order of the shards
         // that participate in the search request. This means that this number is
         // consistent between two requests that target the same shards.
         Arrays.sort(shardIterators);
-
-        // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
-        // it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
-        // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
-        // we process hence we add one for the non active partition here.
-        this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
Reviewer comment:
For posterity, this was introduced with faefc77. The intent was to count inactive shards (ones still in the process of being allocated) as 1, so that they are expected to fail; otherwise their failure would make the search complete earlier than expected, missing results from other active shards.

Together with it, SearchWhileCreatingIndexTests was added, which is now called SearchWhileCreatingIndexIT. The test went through a lot of changes and refactoring over time, and it was quite problematic and flaky in the early days. Funnily enough, if I remove the counting of the empty group as 1 (and call totalSize instead), this specific test still succeeds. I may not have run it enough times to cause failures, or perhaps the issue this was fixing no longer manifests. Either way, a lot of other tests fail because more ops end up being executed than expected, since the counting also needs to be adjusted accordingly (which is expected).

In principle, I agree that counting each shard as 1, regardless of how many copies it has (whether that be inactive, primary only, one replica, or multiple replicas), is simpler.
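
To make the trade-off concrete, here is a minimal, self-contained sketch of the two accounting schemes. The ShardGroup record, the class name, and the numbers are invented for illustration; this is not the real GroupShardsIterator/SearchShardIterator API. The removed code summed per-copy ops with a floor of 1 per shard group (the totalSizeWith1ForEmpty behaviour), while the new code counts each shard group exactly once and finishes the phase when the last outstanding group reports in.

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Standalone toy model of the two completion-accounting schemes discussed above.
public class ShardAccountingSketch {

    // One shard group: a shard id plus however many copies (primary/replicas) are active.
    record ShardGroup(int shardId, int activeCopies) {}

    public static void main(String[] args) {
        // Three shards: primary + replica, primary only, and one with no active copy yet.
        List<ShardGroup> groups = List.of(new ShardGroup(0, 2), new ShardGroup(1, 1), new ShardGroup(2, 0));

        // Old scheme (roughly totalSizeWith1ForEmpty): each group contributes its copy
        // count, but at least 1, so an inactive shard is still expected to report once.
        int expectedTotalOps = groups.stream().mapToInt(g -> Math.max(g.activeCopies(), 1)).sum(); // 2 + 1 + 1 = 4

        // New scheme: each shard group counts exactly once, regardless of how many copies it has.
        AtomicInteger outstandingShards = new AtomicInteger(groups.size()); // 3

        System.out.println("old expectedTotalOps = " + expectedTotalOps);
        System.out.println("new outstandingShards = " + outstandingShards.get());

        // New completion check: decrement once per finished group (success, skip, or
        // exhausted failure) and finish the phase when the last group reports in.
        for (ShardGroup group : groups) {
            if (outstandingShards.getAndDecrement() == 1) {
                System.out.println("last group (shard " + group.shardId() + ") done -> onPhaseDone()");
            }
        }
    }
}

For the three groups above, the old scheme expects 4 ops while the new one tracks 3 outstanding groups, which is why tests written against the old expected-ops count need their expectations adjusted along with the counting.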

         this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode;
         // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle
         this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size();
@@ -251,9 +243,8 @@ protected final void run() {

     void skipShard(SearchShardIterator iterator) {
         successfulOps.incrementAndGet();
-        skippedOps.incrementAndGet();
         assert iterator.skip();
-        successfulShardExecution(iterator);
+        successfulShardExecution();
     }
 
     private static boolean assertExecuteOnStartThread() {
@@ -380,7 +371,7 @@ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextP
                 "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})",
                 discrepancy,
                 successfulOps.get(),
-                skippedOps.get(),
+                toSkipShardsIts.size(),
                 getNumShards(),
                 currentPhase
             );
@@ -449,17 +440,14 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
             }
             onShardGroupFailure(shardIndex, shard, e);
         }
-        final int totalOps = this.totalOps.incrementAndGet();
-        if (totalOps == expectedTotalOps) {
-            onPhaseDone();
-        } else if (totalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
+        if (lastShard == false) {
+            performPhaseOnShard(shardIndex, shardIt, nextShard);
         } else {
-            if (lastShard == false) {
-                performPhaseOnShard(shardIndex, shardIt, nextShard);
+            // count down outstanding shards, we're done with this shard as there's no more copies to try
+            final int outstanding = outstandingShards.getAndDecrement();
+            assert outstanding > 0 : "outstanding: " + outstanding;
+            if (outstanding == 1) {
+                onPhaseDone();
             }
         }
     }
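
As a rough illustration of the new failure handling above, the following standalone sketch retries the next copy of a failed shard while copies remain and only counts the whole shard group down against outstandingShards once the last copy has failed. The shard/node names and the FailurePathSketch class are made up; this is not the real onShardFailure signature.

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Toy model of the retry-then-count-down decision in the failure path.
public class FailurePathSketch {

    public static void main(String[] args) {
        // Two shard groups; the second one has copies that all fail.
        List<List<String>> shardCopies = List.of(
            List.of("node-a"),            // shard 0: one copy, succeeds
            List.of("node-a", "node-b")   // shard 1: both copies fail
        );
        AtomicInteger outstandingShards = new AtomicInteger(shardCopies.size());

        // Shard 0 succeeds: its group is counted down once.
        countDown(outstandingShards, "shard 0 succeeded");

        // Shard 1: walk the copies; retry while more copies remain, count down on the last one.
        Iterator<String> copies = shardCopies.get(1).iterator();
        while (copies.hasNext()) {
            String copy = copies.next();
            boolean lastShard = copies.hasNext() == false;
            System.out.println("shard 1 failed on " + copy);
            if (lastShard == false) {
                System.out.println("  -> retry on next copy");
            } else {
                countDown(outstandingShards, "shard 1 exhausted all copies");
            }
        }
    }

    private static void countDown(AtomicInteger outstandingShards, String reason) {
        int outstanding = outstandingShards.getAndDecrement();
        System.out.println(reason + ", outstanding now " + (outstanding - 1));
        if (outstanding == 1) {
            System.out.println("  -> onPhaseDone()");
        }
    }
}
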
@@ -535,10 +523,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
         if (logger.isTraceEnabled()) {
             logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
         }
-        results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
+        results.consumeResult(result, () -> onShardResultConsumed(result));
     }
 
-    private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
+    private void onShardResultConsumed(Result result) {
         successfulOps.incrementAndGet();
         // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
         // so its ok concurrency wise to miss potentially the shard failures being created because of another failure
@@ -552,28 +540,14 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
         // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
         // increment all the "future" shards to update the total ops since we some may work and some may not...
         // and when that happens, we break on total ops, so we must maintain them
-        successfulShardExecution(shardIt);
+        successfulShardExecution();
     }
 
-    private void successfulShardExecution(SearchShardIterator shardsIt) {
-        final int remainingOpsOnIterator;
-        if (shardsIt.skip()) {
-            // It's possible that we're skipping a shard that's unavailable
-            // but its range was available in the IndexMetadata, in that
-            // case the shardsIt.remaining() would be 0, expectedTotalOps
-            // accounts for unavailable shards too.
-            remainingOpsOnIterator = Math.max(shardsIt.remaining(), 1);
Reviewer comment:
definitely good to get rid of this special case.

-        } else {
-            remainingOpsOnIterator = shardsIt.remaining() + 1;
-        }
-        final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
-        if (xTotalOps == expectedTotalOps) {
+    private void successfulShardExecution() {
+        final int outstanding = outstandingShards.getAndDecrement();
+        assert outstanding > 0 : "outstanding: " + outstanding;
+        if (outstanding == 1) {
             onPhaseDone();
-        } else if (xTotalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
         }
     }
 
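The branch removed above guarded one edge case: a shard that is skipped even though it has no active copy, so remaining() is 0 while expectedTotalOps had still counted it as 1. The sketch below uses toy values (not the real SearchShardIterator) to show that clamp next to the new per-group countdown, which needs no special case because a skipped group is simply one outstanding shard.

import java.util.concurrent.atomic.AtomicInteger;

// Toy illustration of the removed Math.max(remaining, 1) special case versus the
// new outstandingShards countdown, for a skipped shard with no active copy.
public class SkippedShardSpecialCaseSketch {

    public static void main(String[] args) {
        boolean skipped = true;     // the shard was skipped
        int remainingCopies = 0;    // but it has no active copy in its iterator

        // Old scheme: ops contributed by this group on completion, clamped to at least 1
        // so it lines up with the 1-for-empty counted into expectedTotalOps up front.
        int remainingOpsOnIterator = skipped ? Math.max(remainingCopies, 1) : remainingCopies + 1;
        AtomicInteger totalOps = new AtomicInteger();
        int expectedTotalOps = 1; // this group was counted as 1 when the request started
        System.out.println("old: totalOps -> " + totalOps.addAndGet(remainingOpsOnIterator)
            + " (expected " + expectedTotalOps + ")");

        // New scheme: the group is simply one outstanding shard, no clamping needed.
        AtomicInteger outstandingShards = new AtomicInteger(1);
        if (outstandingShards.getAndDecrement() == 1) {
            System.out.println("new: last outstanding shard done -> onPhaseDone()");
        }
    }
}
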
@@ -640,7 +614,7 @@ private SearchResponse buildSearchResponse(
             scrollId,
             getNumShards(),
             numSuccess,
-            skippedOps.get(),
+            toSkipShardsIts.size(),
             buildTookInMillis(),
             failures,
             clusters,
(second changed file)
@@ -144,7 +144,6 @@ protected void run() {
                 hits.decRef();
             }
         } finally {
-            mockSearchPhaseContext.execute(() -> {});
             var resp = mockSearchPhaseContext.searchResponse.get();
             if (resp != null) {
                 resp.decRef();