Skip to content

Commit e44745e

Browse files
committed
Add testing, fix LuceneQueryEvaluator to pick docs.getPositionCount instead of the docs length to avoid different lengths with non-pushed functions when indexRandom is used
1 parent 1b7f02f commit e44745e

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException
112112
int min = docs.docs().getInt(0);
113113
int max = docs.docs().getInt(docs.getPositionCount() - 1);
114114
int length = max - min + 1;
115-
try (T scoreBuilder = createVectorBuilder(blockFactory, length)) {
115+
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
116116
if (length == docs.getPositionCount() && length > 1) {
117117
return segmentState.scoreDense(scoreBuilder, min, max);
118118
}

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.index.IndexRequestBuilder;
1111
import org.elasticsearch.cluster.metadata.IndexMetadata;
1212
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
1314
import org.elasticsearch.xcontent.XContentBuilder;
1415
import org.elasticsearch.xcontent.XContentFactory;
1516
import org.elasticsearch.xpack.esql.EsqlTestUtils;
@@ -18,30 +19,37 @@
1819

1920
import java.io.IOException;
2021
import java.util.ArrayList;
22+
import java.util.Arrays;
2123
import java.util.HashMap;
2224
import java.util.List;
25+
import java.util.Locale;
2326
import java.util.Map;
2427

2528
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
2629

2730
public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
2831

2932
private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
33+
private int numDocs;
34+
private int numDims;
3035

3136
public void testKnnDefaults() {
32-
var query = """
37+
float[] queryVector = new float[numDims];
38+
Arrays.fill(queryVector, 1.0f);
39+
40+
var query = String.format(Locale.ROOT, """
3341
FROM test METADATA _score
34-
| WHERE knn(vector, [1.0, 1.0, 1.0])
42+
| WHERE knn(vector, %s)
3543
| KEEP id, floats, _score, vector
3644
| SORT _score DESC
37-
""";
45+
""", Arrays.toString(queryVector));
3846

3947
try (var resp = run(query)) {
4048
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
4149
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
4250

4351
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
44-
assertEquals(indexedVectors.size(), valuesList.size());
52+
assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size());
4553
for (int i = 0; i < valuesList.size(); i++) {
4654
List<Object> row = valuesList.get(i);
4755
// Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
@@ -62,12 +70,15 @@ public void testKnnDefaults() {
6270
}
6371

6472
public void testKnnOptions() {
65-
var query = """
73+
float[] queryVector = new float[numDims];
74+
Arrays.fill(queryVector, 1.0f);
75+
76+
var query = String.format(Locale.ROOT, """
6677
FROM test METADATA _score
67-
| WHERE knn(vector, [1.0, 1.0, 1.0], {"k": 5})
78+
| WHERE knn(vector, %s, {"k": 5})
6879
| KEEP id, floats, _score, vector
6980
| SORT _score DESC
70-
""";
81+
""", Arrays.toString(queryVector));
7182

7283
try (var resp = run(query)) {
7384
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
@@ -79,20 +90,24 @@ public void testKnnOptions() {
7990
}
8091

8192
public void testKnnNonPushedDown() {
82-
var query = """
93+
float[] queryVector = new float[numDims];
94+
Arrays.fill(queryVector, 1.0f);
95+
96+
// TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
97+
var query = String.format(Locale.ROOT, """
8398
FROM test METADATA _score
84-
| WHERE knn(vector, [1.0, 1.0, 1.0], {"k": 5}) OR id % 2 == 0
99+
| WHERE knn(vector, %s, {"k": 5}) OR id > 10
85100
| KEEP id, floats, _score, vector
86101
| SORT _score DESC
87-
""";
102+
""", Arrays.toString(queryVector));
88103

89104
try (var resp = run(query)) {
90105
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
91106
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
92107

93108
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
94-
// K = 5, 2 more for % operator, total 7
95-
assertEquals(7, valuesList.size());
109+
// K = 5, 1 more for every id > 10
110+
assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size());
96111
}
97112
}
98113

@@ -120,11 +135,11 @@ public void setup() throws IOException {
120135
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
121136
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
122137

123-
var CreateRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
124-
assertAcked(CreateRequest);
138+
var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
139+
assertAcked(createRequest);
125140

126-
int numDocs = 10;
127-
int numDims = 3;
141+
numDocs = randomIntBetween(10, 20);
142+
numDims = randomIntBetween(3, 10);
128143
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
129144
float value = 0.0f;
130145
for (int i = 0; i < numDocs; i++) {

0 commit comments

Comments
 (0)