Skip to content

Commit 7f5ddde

Browse files
committed
Add first version of KnnTests and generated docs
1 parent 22efe27 commit 7f5ddde

File tree

8 files changed

+138
-16
lines changed

8 files changed

+138
-16
lines changed

docs/reference/query-languages/esql/images/functions/knn.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/reference/query-languages/esql/kibana/definition/functions/knn.json

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/knn.md

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ public boolean equals(Object obj) {
152152
return false;
153153
}
154154

155-
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder);
155+
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder)
156+
&& Objects.equals(query, ((FullTextFunction) obj).query);
156157
}
157158

158159
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
7272
@FunctionInfo(
7373
returnType = "boolean",
7474
preview = true,
75-
description = """
76-
Finds the k nearest vectors to a query vector, as measured by a similarity metric.
77-
knn function finds nearest vectors through approximate search on indexed dense_vectors
78-
""",
75+
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. " +
76+
"knn function finds nearest vectors through approximate search on indexed dense_vectors.",
7977
examples = {
8078
@Example(file = "knn-function", tag = "knn-function"),
8179
@Example(file = "knn-function", tag = "knn-function-options"), },
@@ -96,7 +94,7 @@ public Knn(
9694
name = "boost",
9795
type = "float",
9896
valueHint = { "2.5" },
99-
description = "Floating point number used to decrease or increase the relevance scores of the query. "
97+
description = "Floating point number used to decrease or increase the relevance scores of the query."
10098
+ "Defaults to 1.0."
10199
),
102100
@MapParam.MapParamEntry(
@@ -120,7 +118,7 @@ public Knn(
120118
type = "double",
121119
valueHint = { "0.01" },
122120
description = "The minimum similarity required for a document to be considered a match. "
123-
+ "The similarity value calculated relates to the raw similarity used, not the document score"
121+
+ "The similarity value calculated relates to the raw similarity used, not the document score."
124122
),
125123
@MapParam.MapParamEntry(
126124
name = "rescore_oversample",
@@ -237,12 +235,13 @@ public void writeTo(StreamOutput out) throws IOException {
237235

238236
@Override
239237
public boolean equals(Object o) {
238+
// Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to
239+
// ignore options when comparing two Knn functions
240240
if (o == null || getClass() != o.getClass()) return false;
241-
if (super.equals(o) == false) return false;
242241
Knn knn = (Knn) o;
243-
return Objects.equals(field, knn.field)
244-
&& Objects.equals(query(), knn.query())
245-
&& Objects.equals(queryBuilder(), knn.queryBuilder());
242+
return Objects.equals(field(), knn.field())
243+
&& Objects.equals(query(), knn.query())
244+
&& Objects.equals(queryBuilder(), knn.queryBuilder());
246245
}
247246

248247
@Override

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.index.query.TermQueryBuilder;
2525
import org.elasticsearch.index.query.TermsQueryBuilder;
2626
import org.elasticsearch.index.query.WildcardQueryBuilder;
27+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
2728
import org.elasticsearch.test.EqualsHashCodeTestUtils;
2829
import org.elasticsearch.xpack.esql.core.expression.Expression;
2930
import org.elasticsearch.xpack.esql.expression.ExpressionWritables;
@@ -111,6 +112,7 @@ public static NamedWriteableRegistry writableRegistry() {
111112
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, WildcardQueryBuilder.NAME, WildcardQueryBuilder::new));
112113
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RegexpQueryBuilder.NAME, RegexpQueryBuilder::new));
113114
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new));
115+
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KnnVectorQueryBuilder.NAME, KnnVectorQueryBuilder::new));
114116
entries.add(SingleValueQuery.ENTRY);
115117
entries.addAll(ExpressionWritables.getNamedWriteables());
116118
entries.addAll(PlanWritables.getNamedWriteables());

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,121 @@
1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1212

13+
import org.elasticsearch.index.query.QueryBuilder;
1314
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
16+
import org.elasticsearch.xpack.esql.core.expression.Literal;
17+
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
1418
import org.elasticsearch.xpack.esql.core.tree.Source;
19+
import org.elasticsearch.xpack.esql.core.type.DataType;
20+
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
1521
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
22+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
23+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
24+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
1625

26+
import java.util.ArrayList;
1727
import java.util.List;
1828
import java.util.function.Supplier;
1929

20-
public class KnnTests extends NoneFieldFullTextFunctionTestCase {
30+
import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize;
31+
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
32+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
33+
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
34+
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
35+
import static org.elasticsearch.xpack.esql.planner.TranslatorHandler.TRANSLATOR_HANDLER;
36+
import static org.hamcrest.Matchers.equalTo;
37+
38+
public class KnnTests extends AbstractFunctionTestCase {
39+
2140
public KnnTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
22-
super(testCaseSupplier);
41+
this.testCase = testCaseSupplier.get();
2342
}
2443

2544
@ParametersFactory
2645
public static Iterable<Object[]> parameters() {
27-
return generateParameters();
46+
return parameterSuppliersFromTypedData(addFunctionNamedParams(testCaseSuppliers()));
47+
}
48+
49+
private static List<TestCaseSupplier> testCaseSuppliers() {
50+
List<TestCaseSupplier> suppliers = new ArrayList<>();
51+
52+
suppliers.add(
53+
TestCaseSupplier.testCaseSupplier(
54+
new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR),
55+
new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true),
56+
(d1, d2) -> equalTo("string"),
57+
BOOLEAN,
58+
(o1, o2) -> true
59+
)
60+
);
61+
62+
return suppliers;
63+
}
64+
65+
private static List<Float> randomDenseVector() {
66+
int dimensions = randomIntBetween(64, 128);
67+
List<Float> vector = new ArrayList<>();
68+
for (int i = 0; i < dimensions; i++) {
69+
vector.add(randomFloat());
70+
}
71+
return vector;
72+
}
73+
74+
/**
75+
* Adds function named parameters to all the test case suppliers provided
76+
*/
77+
private static List<TestCaseSupplier> addFunctionNamedParams(List<TestCaseSupplier> suppliers) {
78+
// TODO get to a common class with MatchTests
79+
List<TestCaseSupplier> result = new ArrayList<>();
80+
for (TestCaseSupplier supplier : suppliers) {
81+
List<DataType> dataTypes = new ArrayList<>(supplier.types());
82+
dataTypes.add(UNSUPPORTED);
83+
result.add(new TestCaseSupplier(supplier.name() + ", options", dataTypes, () -> {
84+
List<TestCaseSupplier.TypedData> values = new ArrayList<>(supplier.get().getData());
85+
values.add(
86+
new TestCaseSupplier.TypedData(
87+
new MapExpression(
88+
Source.EMPTY,
89+
List.of(
90+
new Literal(Source.EMPTY, randomAlphaOfLength(10), KEYWORD)
91+
)
92+
),
93+
UNSUPPORTED,
94+
"options"
95+
).forceLiteral()
96+
);
97+
98+
return new TestCaseSupplier.TestCase(values, equalTo("KnnEvaluator"), BOOLEAN, equalTo(true));
99+
}));
100+
}
101+
return result;
28102
}
29103

30104
@Override
31105
protected Expression build(Source source, List<Expression> args) {
32-
return new Kql(source, args.get(0));
106+
Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
107+
// We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
108+
// thus test the serialization methods. But we can only do this if the parameters make sense .
109+
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
110+
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, knn).toQueryBuilder();
111+
knn = (Knn) knn.replaceQueryBuilder(queryBuilder);
112+
}
113+
return knn;
114+
}
115+
116+
/**
117+
* Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match.
118+
*/
119+
@Override
120+
protected Expression serializeDeserializeExpression(Expression expression) {
121+
Expression newExpression = serializeDeserialize(
122+
expression,
123+
PlanStreamOutput::writeNamedWriteable,
124+
in -> in.readNamedWriteable(Expression.class),
125+
testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work
126+
);
127+
// Fields use synthetic sources, which can't be serialized. So we use the originals instead.
128+
return newExpression.replaceChildren(expression.children());
33129
}
34130
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ protected Expression build(Source source, List<Expression> args) {
8282
// thus test the serialization methods. But we can only do this if the parameters make sense .
8383
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
8484
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, match).toQueryBuilder();
85-
match.replaceQueryBuilder(queryBuilder);
85+
match = (Match) match.replaceQueryBuilder(queryBuilder);
8686
}
8787
return match;
8888
}

0 commit comments

Comments
 (0)