Skip to content

Commit f6e50bd

Browse files
author
afoucret
committed
Fix some tests.
1 parent 4558c84 commit f6e50bd

File tree

3 files changed

+59
-7
lines changed

3 files changed

+59
-7
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
554554
mock(ProjectResolver.class),
555555
mock(IndexNameExpressionResolver.class),
556556
null,
557-
new InferenceService(mock(Client.class)),
557+
new InferenceService(createMockClient()),
558558
new BlockFactoryProvider(PlannerUtils.NON_BREAKING_BLOCK_FACTORY),
559559
TEST_PLANNER_SETTINGS,
560560
new CrossProjectModeDecider(Settings.EMPTY)
@@ -563,6 +563,7 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
563563
private static ClusterService createMockClusterService() {
564564
var service = mock(ClusterService.class);
565565
doReturn(new ClusterName("test-cluster")).when(service).getClusterName();
566+
doReturn(Settings.EMPTY).when(service).getSettings();
566567
return service;
567568
}
568569

@@ -578,6 +579,13 @@ private static ThreadPool createMockThreadPool() {
578579
return threadPool;
579580
}
580581

582+
private static Client createMockClient() {
583+
var client = mock(Client.class);
584+
doReturn(Settings.EMPTY).when(client).settings();
585+
doReturn(createMockThreadPool()).when(client).threadPool();
586+
return client;
587+
}
588+
581589
private EsqlTestUtils() {}
582590

583591
public static Configuration configuration(QueryPragmas pragmas, String query, EsqlStatement statement) {
@@ -732,6 +740,8 @@ private static <T> List<T> toList(Iterator<T> iterator) {
732740
public static List<String> withDefaultLimitWarning(List<String> warnings) {
733741
List<String> result = warnings == null ? new ArrayList<>() : new ArrayList<>(warnings);
734742
result.add("No limit defined, adding default limit of [1000]");
743+
744+
result.add("No limit defined, adding default limit of [100]");
735745
return result;
736746
}
737747

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8943,13 +8943,54 @@ public void testKnnWithStats() {
89438943
}
89448944

89458945
public void testKnnWithRerankAmdTopN() {
8946-
assertThat(typesError("""
8946+
var query = """
89478947
from types metadata _score
89488948
| where knn(dense_vector, [0, 1, 2])
8949+
| LIMIT 100
89498950
| rerank "some text" on text with { "inference_id" : "reranking-inference-id" }
89508951
| sort _score desc
89518952
| limit 10
8952-
"""), containsString("Knn function must be used with a LIMIT clause"));
8953+
""";
8954+
8955+
var optimized = planTypes(query);
8956+
8957+
var topN = as(optimized, TopN.class);
8958+
assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
8959+
8960+
var rerank = as(topN.child(), Rerank.class);
8961+
var rerankLimit = as(rerank.child(), Limit.class);
8962+
assertThat(rerankLimit.limit().fold(FoldContext.small()), equalTo(100));
8963+
var filter = as(rerankLimit.child(), Filter.class);
8964+
8965+
// KNN is using the first limit (100)
8966+
var knn = as(filter.condition(), Knn.class);
8967+
assertThat(knn.implicitK(), equalTo(100));
8968+
}
8969+
8970+
public void testKnnWithRerankImplicitLimitAmdTopN() {
8971+
var query = """
8972+
from types metadata _score
8973+
| where knn(dense_vector, [0, 1, 2])
8974+
| rerank "some text" on text with { "inference_id" : "reranking-inference-id" }
8975+
| sort _score desc
8976+
| limit 10
8977+
""";
8978+
8979+
var optimized = planTypes(query);
8980+
8981+
var topN = as(optimized, TopN.class);
8982+
assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
8983+
8984+
var rerank = as(topN.child(), Rerank.class);
8985+
8986+
// RERANK implicit limit is set by setting (1000)
8987+
var rerankLimit = as(rerank.child(), Limit.class);
8988+
assertThat(rerankLimit.limit().fold(FoldContext.small()), equalTo(1000));
8989+
var filter = as(rerankLimit.child(), Filter.class);
8990+
8991+
// KNN is using the implicit limit of RERANK (1000)
8992+
var knn = as(filter.condition(), Knn.class);
8993+
assertThat(knn.implicitK(), equalTo(1000));
89538994
}
89548995

89558996
public void testKnnWithRerankAmdLimit() {
@@ -8964,7 +9005,8 @@ public void testKnnWithRerankAmdLimit() {
89649005

89659006
var rerank = as(optimized, Rerank.class);
89669007
var limit = as(rerank.child(), Limit.class);
8967-
assertThat(limit.limit().fold(FoldContext.small()), equalTo(100));
9008+
assertThat(limit.limit().fold(FoldContext.small()), equalTo(1_000));
9009+
89689010
var filter = as(limit.child(), Filter.class);
89699011
var knn = as(filter.condition(), Knn.class);
89709012
assertThat(knn.implicitK(), equalTo(100));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/EnforceRowLimitsTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ public String toString() {
6060

6161
public static class CompletionTestCase extends TestCase {
6262
public String name() {
63-
return "Rerank";
63+
return "Completion";
6464
}
6565

6666
public Completion createPlan(LogicalPlan child, Expression rowLimit) {
6767
Source source = EMPTY;
6868
Expression prompt = Literal.keyword(source, "test prompt");
69-
Attribute targetField = new ReferenceAttribute(source, null, "completion", DataType.KEYWORD);
69+
Attribute targetField = new ReferenceAttribute(source, "completion", DataType.KEYWORD);
7070
return new Completion(source, child, Literal.keyword(source, "test-inference-id"), prompt, targetField, rowLimit);
7171
}
7272
}
@@ -80,7 +80,7 @@ public String name() {
8080
public Rerank createPlan(LogicalPlan child, Expression rowLimit) {
8181
Source source = EMPTY;
8282
Expression queryText = Literal.keyword(source, "test query");
83-
Attribute scoreAttribute = new ReferenceAttribute(source, null, "score", DataType.DOUBLE);
83+
Attribute scoreAttribute = new ReferenceAttribute(source, "score", DataType.DOUBLE);
8484
List<Alias> rerankFields = List.of();
8585
return new Rerank(
8686
source,

0 commit comments

Comments
 (0)