Skip to content

Commit bda1c73

Browse files
committed
add more tests
1 parent e1abb0c commit bda1c73

File tree

3 files changed

+166
-3
lines changed

3 files changed

+166
-3
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,7 +2551,7 @@ private Query createKnnBitQuery(
25512551
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector))
25522552
: createExactKnnBitQuery(queryVector);
25532553
knnQuery = filter == null
2554-
? createExactKnnBitQuery(queryVector)
2554+
? exactKnnQuery
25552555
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
25562556
.add(filter, BooleanClause.Occur.FILTER)
25572557
.build();
@@ -2592,7 +2592,7 @@ private Query createKnnByteQuery(
25922592
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector))
25932593
: createExactKnnByteQuery(queryVector);
25942594
knnQuery = filter == null
2595-
? createExactKnnByteQuery(queryVector)
2595+
? exactKnnQuery
25962596
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
25972597
.add(filter, BooleanClause.Occur.FILTER)
25982598
.build();
@@ -2658,7 +2658,7 @@ && isNotUnitVector(squaredMagnitude)) {
26582658
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector))
26592659
: createExactKnnFloatQuery(queryVector);
26602660
knnQuery = filter == null
2661-
? createExactKnnFloatQuery(queryVector)
2661+
? exactKnnQuery
26622662
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
26632663
.add(filter, BooleanClause.Occur.FILTER)
26642664
.build();

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.search.DocValueFormat;
2727
import org.elasticsearch.search.vectors.DenseVectorQuery;
2828
import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
29+
import org.elasticsearch.search.vectors.DiversifyingParentBlockQueryTests;
2930
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
3031
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
3132
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.IndexSearcher;
13+
import org.apache.lucene.search.join.ScoreMode;
14+
import org.apache.lucene.search.join.ToParentBlockJoinQuery;
15+
import org.elasticsearch.common.bytes.BytesReference;
16+
import org.elasticsearch.index.mapper.MapperServiceTestCase;
17+
import org.elasticsearch.index.mapper.ParsedDocument;
18+
import org.elasticsearch.index.mapper.SourceToParse;
19+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
20+
import org.elasticsearch.xcontent.XContentBuilder;
21+
import org.elasticsearch.xcontent.XContentType;
22+
23+
import java.io.IOException;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import java.util.Locale;
27+
import java.util.Set;
28+
import java.util.TreeMap;
29+
30+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
31+
import static org.hamcrest.Matchers.equalTo;
32+
import static org.hamcrest.Matchers.instanceOf;
33+
34+
public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase {
35+
private static String getMapping(int dim) {
36+
return String.format(Locale.ROOT, """
37+
{
38+
"_doc": {
39+
"properties": {
40+
"id": {
41+
"type": "keyword",
42+
"store": true
43+
},
44+
"nested": {
45+
"type": "nested",
46+
"properties": {
47+
"emb": {
48+
"type": "dense_vector",
49+
"dims": %d,
50+
"similarity": "l2_norm",
51+
"index_options": {
52+
"type": "flat"
53+
}
54+
}
55+
}
56+
}
57+
}
58+
}
59+
}
60+
}
61+
""", dim);
62+
}
63+
64+
public void testRandom() throws IOException {
65+
int dims = randomIntBetween(3, 10);
66+
var mapperService = createMapperService(getMapping(dims));
67+
var fieldType = (DenseVectorFieldMapper.DenseVectorFieldType) mapperService.fieldType("nested.emb");
68+
var nestedParent = mapperService.mappingLookup().nestedLookup().getNestedMappers().get("nested");
69+
70+
int numQueries = randomIntBetween(1, 3);
71+
float[][] queries = new float[numQueries][];
72+
List<TreeMap<Float, String>> expectedTopDocs = new ArrayList<>();
73+
for (int i = 0; i < numQueries; i++) {
74+
queries[i] = randomVector(dims);
75+
expectedTopDocs.add(new TreeMap<>((o1, o2) -> -Float.compare(o1, o2)));
76+
}
77+
78+
withLuceneIndex(mapperService, iw -> {
79+
int numDocs = randomIntBetween(10, 50);
80+
for (int i = 0; i < numDocs; i++) {
81+
int numVectors = randomIntBetween(0, 5);
82+
float[][] vectors = new float[numVectors][];
83+
for (int j = 0; j < numVectors; j++) {
84+
vectors[j] = randomVector(dims);
85+
}
86+
87+
for (int k = 0; k < numQueries; k++) {
88+
float maxScore = Float.MIN_VALUE;
89+
for (int j = 0; j < numVectors; j++) {
90+
float score = EUCLIDEAN.compare(vectors[j], queries[k]);
91+
maxScore = Math.max(score, maxScore);
92+
}
93+
expectedTopDocs.get(k).put(maxScore, Integer.toString(i));
94+
}
95+
96+
SourceToParse source = randomSource(Integer.toString(i), vectors);
97+
ParsedDocument doc = mapperService.documentMapper().parse(source);
98+
iw.addDocuments(doc.docs());
99+
100+
if (randomBoolean()) {
101+
int numEmpty = randomIntBetween(1, 3);
102+
for (int l = 0; l < numEmpty; l++) {
103+
source = randomSource(randomAlphaOfLengthBetween(5, 10), new float[0][]);
104+
doc = mapperService.documentMapper().parse(source);
105+
iw.addDocuments(doc.docs());
106+
}
107+
}
108+
}
109+
}, ir -> {
110+
var storedFields = ir.storedFields();
111+
var searcher = new IndexSearcher(wrapInMockESDirectoryReader(ir));
112+
var context = createSearchExecutionContext(mapperService);
113+
var bitSetproducer = context.bitsetFilter(nestedParent.parentTypeFilter());
114+
for (int i = 0; i < numQueries; i++) {
115+
var knnQuery = fieldType.createKnnQuery(
116+
VectorData.fromFloats(queries[i]),
117+
10,
118+
10,
119+
null,
120+
null,
121+
null,
122+
bitSetproducer,
123+
DenseVectorFieldMapper.FilterHeuristic.ACORN
124+
);
125+
assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class));
126+
var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total);
127+
var topDocs = searcher.search(nestedQuery, 10);
128+
for (var doc : topDocs.scoreDocs) {
129+
var entry = expectedTopDocs.get(i).pollFirstEntry();
130+
assertNotNull(entry);
131+
assertThat(doc.score, equalTo(entry.getKey()));
132+
var storedDoc = storedFields.document(doc.doc, Set.of("id"));
133+
assertThat(storedDoc.getField("id").binaryValue().utf8ToString(), equalTo(entry.getValue()));
134+
}
135+
}
136+
});
137+
}
138+
139+
private SourceToParse randomSource(String id, float[][] vectors) throws IOException {
140+
try (var builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
141+
builder.startObject();
142+
builder.field("id", id);
143+
builder.startArray("nested");
144+
for (int i = 0; i < vectors.length; i++) {
145+
builder.startObject();
146+
builder.field("emb", vectors[i]);
147+
builder.endObject();
148+
}
149+
builder.endArray();
150+
builder.endObject();
151+
return new SourceToParse(id, BytesReference.bytes(builder), XContentType.JSON);
152+
}
153+
}
154+
155+
private float[] randomVector(int dim) {
156+
float[] vector = new float[dim];
157+
for (int i = 0; i < vector.length; i++) {
158+
vector[i] = randomFloat();
159+
}
160+
return vector;
161+
}
162+
}

0 commit comments

Comments
 (0)