Skip to content

Commit 8317911

Browse files
committed
Add scoring
1 parent 3c4c401 commit 8317911

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xpack.esql.EsqlTestUtils;
1516
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
1617
import org.junit.Before;
1718

@@ -29,14 +30,34 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
2930

3031
public void testKnn() {
3132
var query = """
32-
FROM test
33-
| WHERE knn(vector, [1.0, 2.0, 3.0])
34-
| KEEP id, floats
33+
FROM test METADATA _score
34+
| WHERE knn(vector, [1.0, 1.0, 1.0])
35+
| KEEP id, floats, _score, vector
36+
| SORT _score DESC
3537
""";
3638

3739
try (var resp = run(query)) {
38-
assertColumnNames(resp.columns(), List.of("id", "floats"));
39-
assertColumnTypes(resp.columns(), List.of("integer", "double"));
40+
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
41+
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
42+
43+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
44+
assertEquals(indexedVectors.size(), valuesList.size());
45+
for (int i = 0; i < valuesList.size(); i++) {
46+
List<Object> row = valuesList.get(i);
47+
// Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
48+
assertEquals(i, row.getFirst());
49+
@SuppressWarnings("unchecked")
50+
// Vectors should be the same
51+
List<Double> floats = (List<Double>)row.get(1);
52+
for(int j = 0; j < floats.size(); j++) {
53+
assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f);
54+
}
55+
var score = (Double) row.get(2);
56+
assertNotNull(score);
57+
assertTrue(score > 0.0);
58+
// dense_vector is null for now
59+
assertNull(row.get(3));
60+
}
4061
}
4162
}
4263

@@ -67,7 +88,7 @@ public void setup() throws IOException {
6788
var CreateRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
6889
assertAcked(CreateRequest);
6990

70-
int numDocs = randomIntBetween(10, 100);
91+
int numDocs = 10;
7192
int numDims = 3;
7293
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
7394
float value = 0.0f;
@@ -76,7 +97,7 @@ public void setup() throws IOException {
7697
for (int j = 0; j < numDims; j++) {
7798
vector.add(value++);
7899
}
79-
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
100+
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector);
80101
indexedVectors.put(i, vector);
81102
}
82103

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,9 @@ public static ElementType toElementType(DataType dataType, MappedFieldType.Field
299299
case GEO_POINT, CARTESIAN_POINT -> fieldExtractPreference == DOC_VALUES ? ElementType.LONG : ElementType.BYTES_REF;
300300
case GEO_SHAPE, CARTESIAN_SHAPE -> fieldExtractPreference == EXTRACT_SPATIAL_BOUNDS ? ElementType.INT : ElementType.BYTES_REF;
301301
case PARTIAL_AGG, AGGREGATE_METRIC_DOUBLE -> ElementType.COMPOSITE;
302-
case SHORT, BYTE, DATE_PERIOD, TIME_DURATION, OBJECT, FLOAT, HALF_FLOAT, SCALED_FLOAT, DENSE_VECTOR ->
302+
// Can't throw IAE as this is used to estimate row size
303+
case DENSE_VECTOR -> ElementType.NULL;
304+
case SHORT, BYTE, DATE_PERIOD, TIME_DURATION, OBJECT, FLOAT, HALF_FLOAT, SCALED_FLOAT ->
303305
throw EsqlIllegalArgumentException.illegalDataType(dataType);
304306
};
305307
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,9 @@ public boolean equals(Object o) {
4747
public int hashCode() {
4848
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query));
4949
}
50+
51+
@Override
52+
public boolean scorable() {
53+
return true;
54+
}
5055
}

0 commit comments

Comments
 (0)