Skip to content

Commit 8e7a03c

Browse files
author
elasticsearchmachine
committed
Fix tests
1 parent d661af3 commit 8e7a03c

File tree

5 files changed

+40
-52
lines changed

5 files changed

+40
-52
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ public class Knn extends SingleFieldFullTextFunction implements OptionalArgument
6666

6767
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
6868

69-
// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
70-
private final transient Integer k;
69+
// Implicit k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
70+
private final transient Integer implicitK;
7171
// Expressions to be used as prefilters in knn query
7272
private final List<Expression> filterExpressions;
7373

@@ -169,12 +169,12 @@ public Knn(
169169
Expression field,
170170
Expression query,
171171
Expression options,
172-
Integer k,
172+
Integer implicitK,
173173
QueryBuilder queryBuilder,
174174
List<Expression> filterExpressions
175175
) {
176176
super(source, field, query, options, expressionList(field, query, options), queryBuilder);
177-
this.k = k;
177+
this.implicitK = implicitK;
178178
this.filterExpressions = filterExpressions;
179179
}
180180

@@ -188,15 +188,15 @@ private static List<Expression> expressionList(Expression field, Expression quer
188188
return result;
189189
}
190190

191-
public Integer k() {
192-
return k;
191+
public Integer implicitK() {
192+
return implicitK;
193193
}
194194

195195
public List<Expression> filterExpressions() {
196196
return filterExpressions;
197197
}
198198

199-
public Knn replaceK(Integer k) {
199+
public Knn withImplicitK(Integer k) {
200200
Check.notNull(k, "k must not be null");
201201
return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
202202
}
@@ -214,7 +214,7 @@ public List<Number> queryAsObject() {
214214

215215
@Override
216216
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
217-
return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
217+
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder, filterExpressions());
218218
}
219219

220220
@Override
@@ -230,7 +230,7 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
230230

231231
@Override
232232
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
233-
assert k() != null : "Knn function must have a k value set before translation";
233+
assert implicitK() != null : "Knn function must have a k value set before translation";
234234
var fieldAttribute = fieldAsFieldAttribute(field());
235235

236236
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
@@ -249,7 +249,10 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
249249
}
250250
}
251251

252-
return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
252+
Map<String, Object> options = queryOptions();
253+
Integer explicitK = (Integer) options.get(K_FIELD.getPreferredName());
254+
255+
return new KnnQuery(source(), fieldName, queryAsFloats, explicitK != null ? explicitK : implicitK(), options, filterQueries);
253256
}
254257

255258
private float[] queryAsFloats() {
@@ -262,7 +265,7 @@ private float[] queryAsFloats() {
262265
}
263266

264267
public Expression withFilters(List<Expression> filterExpressions) {
265-
return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
268+
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder(), filterExpressions);
266269
}
267270

268271
private Map<String, Object> queryOptions() throws InvalidArgumentException {
@@ -287,7 +290,7 @@ protected QueryBuilder evaluatorQueryBuilder() {
287290
@Override
288291
public void postOptimizationVerification(Failures failures) {
289292
// Check that a k has been set
290-
if (k() == null) {
293+
if (implicitK() == null) {
291294
failures.add(
292295
Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find")
293296
);
@@ -301,15 +304,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
301304
newChildren.get(0),
302305
newChildren.get(1),
303306
newChildren.size() > 2 ? newChildren.get(2) : null,
304-
k(),
307+
implicitK(),
305308
queryBuilder(),
306309
filterExpressions()
307310
);
308311
}
309312

310313
@Override
311314
protected NodeInfo<? extends Expression> info() {
312-
return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
315+
return NodeInfo.create(this, Knn::new, field(), query(), options(), implicitK(), queryBuilder(), filterExpressions());
313316
}
314317

315318
@Override
@@ -357,12 +360,14 @@ public boolean equals(Object o) {
357360
// ignore options when comparing two Knn functions
358361
if (o == null || getClass() != o.getClass()) return false;
359362
Knn knn = (Knn) o;
360-
return super.equals(knn) && Objects.equals(k(), knn.k()) && Objects.equals(filterExpressions(), knn.filterExpressions());
363+
return super.equals(knn)
364+
&& Objects.equals(implicitK(), knn.implicitK())
365+
&& Objects.equals(filterExpressions(), knn.filterExpressions());
361366
}
362367

363368
@Override
364369
public int hashCode() {
365-
return Objects.hash(field(), query(), queryBuilder(), k(), filterExpressions());
370+
return Objects.hash(field(), query(), queryBuilder(), implicitK(), filterExpressions());
366371
}
367372

368373
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
6161
private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
6262
return condition.transformDown(exp -> {
6363
if (exp instanceof Knn knn) {
64-
return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
64+
return knn.withImplicitK((Integer) limit.limit().fold(ctx.foldCtx()));
6565
}
6666
return exp;
6767
});

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.Objects;
2222

2323
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
24-
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
2524
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
2625
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VISIT_PERCENTAGE_FIELD;
2726

@@ -53,10 +52,9 @@ protected QueryBuilder asBuilder() {
5352
if (oversample != null) {
5453
rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
5554
}
56-
Integer k = (Integer) options.get(K_FIELD);
5755
Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
5856
Integer minCandidates = (Integer) options.get(Knn.MIN_CANDIDATES_OPTION);
59-
Float visitPercentage = (Float) options.get(VISIT_PERCENTAGE_FIELD);
57+
Float visitPercentage = (Float) options.get(VISIT_PERCENTAGE_FIELD.getPreferredName());
6058
minCandidates = minCandidates == null ? null : Math.max(minCandidates, k);
6159

6260
// TODO: expose visit_percentage in ESQL

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,8 @@ public void testKnnOptionsPushDown() {
12861286
String query = """
12871287
from test
12881288
| where KNN(dense_vector, [0.1, 0.2, 0.3],
1289-
{ "similarity": 0.001, "min_candidates": 5000, "rescore_oversample": 7, "boost": 3.5 })
1289+
{"k": 10, "min_candidates": 20, "rescore_oversample": 1.5, "similarity": 0.5, "boost": 2.0, "visit_percentage": 0.25})
1290+
| limit 50
12901291
""";
12911292
var analyzer = makeAnalyzer("mapping-all-types.json");
12921293
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
@@ -1297,12 +1298,12 @@ public void testKnnOptionsPushDown() {
12971298
var expectedQuery = new KnnVectorQueryBuilder(
12981299
"dense_vector",
12991300
new float[] { 0.1f, 0.2f, 0.3f },
1300-
5000,
1301-
5000,
1302-
null,
1303-
new RescoreVectorBuilder(7),
1304-
0.001f
1305-
).boost(3.5f);
1301+
10,
1302+
20,
1303+
0.25f,
1304+
new RescoreVectorBuilder(1.5f),
1305+
0.5f
1306+
).boost(2.0f);
13061307
assertEquals(expectedQuery.toString(), planStr.get());
13071308
}
13081309

@@ -1322,10 +1323,10 @@ public void testKnnUsesLimitForK() {
13221323
assertEquals(expectedQuery.toString(), planStr.get());
13231324
}
13241325

1325-
public void testKnnKAndMinCandidatesLowerK() {
1326+
public void testKnnKOverridesLimitK() {
13261327
String query = """
13271328
from test
1328-
| where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 50})
1329+
| where KNN(dense_vector, [0.1, 0.2, 0.3], {"k": 20})
13291330
| limit 10
13301331
""";
13311332
var analyzer = makeAnalyzer("mapping-all-types.json");
@@ -1334,23 +1335,7 @@ public void testKnnKAndMinCandidatesLowerK() {
13341335
AtomicReference<String> planStr = new AtomicReference<>();
13351336
plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
13361337

1337-
var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null, null);
1338-
assertEquals(expectedQuery.toString(), planStr.get());
1339-
}
1340-
1341-
public void testKnnKAndMinCandidatesHigherK() {
1342-
String query = """
1343-
from test
1344-
| where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 10})
1345-
| limit 50
1346-
""";
1347-
var analyzer = makeAnalyzer("mapping-all-types.json");
1348-
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
1349-
1350-
AtomicReference<String> planStr = new AtomicReference<>();
1351-
plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
1352-
1353-
var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null, null);
1338+
var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 20, null, null, null, null);
13541339
assertEquals(expectedQuery.toString(), planStr.get());
13551340
}
13561341

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8846,7 +8846,7 @@ public void testKnnImplicitLimit() {
88468846
var limit = as(optimized, Limit.class);
88478847
var filter = as(limit.child(), Filter.class);
88488848
var knn = as(filter.condition(), Knn.class);
8849-
assertThat(knn.k(), equalTo(1000));
8849+
assertThat(knn.implicitK(), equalTo(1000));
88508850
}
88518851

88528852
public void testKnnWithLimit() {
@@ -8860,7 +8860,7 @@ public void testKnnWithLimit() {
88608860
var limit = as(optimized, Limit.class);
88618861
var filter = as(limit.child(), Filter.class);
88628862
var knn = as(filter.condition(), Knn.class);
8863-
assertThat(knn.k(), equalTo(10));
8863+
assertThat(knn.implicitK(), equalTo(10));
88648864
}
88658865

88668866
public void testKnnWithTopN() {
@@ -8875,7 +8875,7 @@ public void testKnnWithTopN() {
88758875
var topN = as(optimized, TopN.class);
88768876
var filter = as(topN.child(), Filter.class);
88778877
var knn = as(filter.condition(), Knn.class);
8878-
assertThat(knn.k(), equalTo(10));
8878+
assertThat(knn.implicitK(), equalTo(10));
88798879
}
88808880

88818881
public void testKnnWithMultipleLimitsAfterTopN() {
@@ -8893,7 +8893,7 @@ public void testKnnWithMultipleLimitsAfterTopN() {
88938893
var limit = as(topN.child(), Limit.class);
88948894
var filter = as(limit.child(), Filter.class);
88958895
var knn = as(filter.condition(), Knn.class);
8896-
assertThat(knn.k(), equalTo(20));
8896+
assertThat(knn.implicitK(), equalTo(20));
88978897
}
88988898

88998899
public void testKnnWithMultipleLimitsCombined() {
@@ -8909,7 +8909,7 @@ public void testKnnWithMultipleLimitsCombined() {
89098909
assertThat(limit.limit().fold(FoldContext.small()), equalTo(10));
89108910
var filter = as(limit.child(), Filter.class);
89118911
var knn = as(filter.condition(), Knn.class);
8912-
assertThat(knn.k(), equalTo(10));
8912+
assertThat(knn.implicitK(), equalTo(10));
89138913
}
89148914

89158915
public void testKnnWithMultipleClauses() {
@@ -8965,7 +8965,7 @@ public void testKnnWithRerankAmdLimit() {
89658965
assertThat(limit.limit().fold(FoldContext.small()), equalTo(100));
89668966
var filter = as(limit.child(), Filter.class);
89678967
var knn = as(filter.condition(), Knn.class);
8968-
assertThat(knn.k(), equalTo(100));
8968+
assertThat(knn.implicitK(), equalTo(100));
89698969
}
89708970

89718971
private LogicalPlanOptimizer getCustomRulesLogicalPlanOptimizer(

0 commit comments

Comments
 (0)