Skip to content

Commit 2e18147

Browse files
committed
simplify collector creation and fix tests
1 parent 7ff51d6 commit 2e18147

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public Factory(
6363
List<SortBuilder<?>> sorts,
6464
boolean needsScore
6565
) {
66-
super(contexts, weightFunction(queryFunction, sorts), dataPartitioning, taskConcurrency, limit, needsScore);
66+
super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
6767
this.maxPageSize = maxPageSize;
6868
this.sorts = sorts;
6969
}
@@ -309,13 +309,13 @@ static final class ScoringPerShardCollector extends PerShardCollector {
309309
}
310310
}
311311

312-
private static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Query> queryFunction, List<SortBuilder<?>> sorts) {
312+
private static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Query> queryFunction, List<SortBuilder<?>> sorts, boolean needsScore) {
313313
return ctx -> {
314314
final var query = queryFunction.apply(ctx);
315315
final var searcher = ctx.searcher();
316316
try {
317317
// we create a collector with a limit of 1 to determine the appropriate score mode to use.
318-
var scoreMode = newPerShardCollector(ctx, sorts, false, 1).collector.scoreMode();
318+
var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
319319
return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
320320
} catch (IOException e) {
321321
throw new UncheckedIOException(e);
@@ -332,20 +332,17 @@ private static PerShardCollector newPerShardCollector(ShardContext context, List
332332
if (needsScore == false) {
333333
return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
334334
}
335-
SortField[] sortFields = sortAndFormats.get().sort.getSort();
336-
if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
335+
Sort sort = sortAndFormats.get().sort;
336+
if (Sort.RELEVANCE.equals(sort)) {
337337
// SORT _score DESC
338338
return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
339339
}
340340

341341
// SORT ..., _score, ...
342-
var sort = new Sort();
343-
if (sortFields != null) {
344-
var l = new ArrayList<>(Arrays.asList(sortFields));
345-
l.add(SortField.FIELD_DOC);
346-
l.add(SortField.FIELD_SCORE);
347-
sort = new Sort(l.toArray(SortField[]::new));
348-
}
342+
var l = new ArrayList<>(Arrays.asList(sort.getSort()));
343+
l.add(SortField.FIELD_DOC);
344+
l.add(SortField.FIELD_SCORE);
345+
sort = new Sort(l.toArray(SortField[]::new));
349346
return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
350347
}
351348
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
3232
import org.elasticsearch.index.mapper.MapperServiceTestCase;
3333
import org.elasticsearch.node.Node;
34+
import org.elasticsearch.plugins.ExtensiblePlugin;
35+
import org.elasticsearch.plugins.Plugin;
3436
import org.elasticsearch.search.internal.AliasFilter;
3537
import org.elasticsearch.search.internal.ContextIndexSearcher;
3638
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
@@ -46,18 +48,22 @@
4648
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
4749
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
4850
import org.elasticsearch.xpack.esql.session.Configuration;
51+
import org.elasticsearch.xpack.spatial.SpatialPlugin;
4952
import org.hamcrest.Matcher;
5053
import org.junit.After;
5154

5255
import java.io.IOException;
5356
import java.util.ArrayList;
57+
import java.util.Collection;
58+
import java.util.Collections;
5459
import java.util.List;
5560
import java.util.Map;
5661

5762
import static org.hamcrest.Matchers.equalTo;
5863
import static org.hamcrest.Matchers.lessThanOrEqualTo;
5964

6065
public class LocalExecutionPlannerTests extends MapperServiceTestCase {
66+
6167
@ParametersFactory
6268
public static Iterable<Object[]> parameters() throws Exception {
6369
List<Object[]> params = new ArrayList<>();
@@ -78,6 +84,19 @@ public LocalExecutionPlannerTests(@Name("estimatedRowSizeIsHuge") boolean estima
7884
this.estimatedRowSizeIsHuge = estimatedRowSizeIsHuge;
7985
}
8086

87+
@Override
88+
protected Collection<Plugin> getPlugins() {
89+
var plugin = new SpatialPlugin();
90+
plugin.loadExtensions(new ExtensiblePlugin.ExtensionLoader() {
91+
@Override
92+
public <T> List<T> loadExtensions(Class<T> extensionPointType) {
93+
return List.of();
94+
}
95+
});
96+
97+
return Collections.singletonList(plugin);
98+
}
99+
81100
@After
82101
public void closeIndex() throws IOException {
83102
IOUtils.close(reader, directory, () -> Releasables.close(releasables), releasables::clear);
@@ -253,7 +272,9 @@ private List<EsPhysicalOperationProviders.ShardContext> createShardContexts() th
253272
shardContexts.add(
254273
new EsPhysicalOperationProviders.DefaultShardContext(
255274
i,
256-
createSearchExecutionContext(createMapperService(mapping(b -> {})), searcher),
275+
createSearchExecutionContext(createMapperService(mapping(b -> {
276+
b.startObject("point").field("type", "geo_point").endObject();
277+
})), searcher),
257278
AliasFilter.EMPTY
258279
)
259280
);

0 commit comments

Comments
 (0)