diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java index 87d8f839dfca1..a1bf70dacf741 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java @@ -105,11 +105,12 @@ public LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCtx, // when post collecting then we have already replaced the leaf readers on the aggregator level have already been // replaced with the next leaf readers and then post collection pushes docids of the previous segment, which // then causes assertions to trip or incorrect top docs to be computed. + var leafCollectors = this.leafCollectors; if (leafCollectors != null) { + this.leafCollectors = null; // set to null, just in case the new allocation below fails leafCollectors.close(); - leafCollectors = null; // set to null, just in case the new allocation below fails } - leafCollectors = new LongObjectPagedHashMap<>(1, bigArrays); + final var currentLeafCollectors = this.leafCollectors = new LongObjectPagedHashMap<>(1, bigArrays); return new LeafBucketCollectorBase(sub, null) { Scorable scorer; @@ -118,13 +119,21 @@ public LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCtx, public void setScorer(Scorable scorer) throws IOException { this.scorer = scorer; super.setScorer(scorer); - for (Cursor leafCollector : leafCollectors) { + for (Cursor leafCollector : currentLeafCollectors) { leafCollector.value.setScorer(scorer); } } @Override public void collect(int docId, long bucket) throws IOException { + LeafCollector leafCollector = currentLeafCollectors.get(bucket); + if (leafCollector == null) { + leafCollector = initLeafCollector(bucket); + } + leafCollector.collect(docId); + } + + private LeafCollector initLeafCollector(long bucket) throws IOException { Collectors collectors = topDocsCollectors.get(bucket); if (collectors == null) { SortAndFormats sort = subSearchContext.sort(); @@ -138,29 +147,25 @@ public void collect(int docId, long bucket) throws IOException { // but here we create collectors ourselves and we need prevent OOM because of crazy an offset and size. topN = Math.min(topN, subSearchContext.searcher().getIndexReader().maxDoc()); if (sort == null) { - TopScoreDocCollector topScoreDocCollector = new TopScoreDocCollectorManager(topN, null, Integer.MAX_VALUE, false) + TopScoreDocCollector topScoreDocCollector = new TopScoreDocCollectorManager(topN, null, Integer.MAX_VALUE) .newCollector(); collectors = new Collectors(topScoreDocCollector, null); } else { // TODO: can we pass trackTotalHits=subSearchContext.trackTotalHits(){ // Note that this would require to catch CollectionTerminatedException collectors = new Collectors( - new TopFieldCollectorManager(sort.sort, topN, null, Integer.MAX_VALUE, false).newCollector(), + new TopFieldCollectorManager(sort.sort, topN, null, Integer.MAX_VALUE).newCollector(), subSearchContext.trackScores() ? new MaxScoreCollector() : null ); } topDocsCollectors.put(bucket, collectors); } - - LeafCollector leafCollector = leafCollectors.get(bucket); - if (leafCollector == null) { - leafCollector = collectors.collector.getLeafCollector(aggCtx.getLeafReaderContext()); - if (scorer != null) { - leafCollector.setScorer(scorer); - } - leafCollectors.put(bucket, leafCollector); + LeafCollector leafCollector = collectors.collector.getLeafCollector(aggCtx.getLeafReaderContext()); + if (scorer != null) { + leafCollector.setScorer(scorer); } - leafCollector.collect(docId); + currentLeafCollectors.put(bucket, leafCollector); + return leafCollector; } }; }