Skip to content

Commit 39e1676

Browse files
committed
Add testing
1 parent f5080a6 commit 39e1676

File tree

2 files changed

+121
-15
lines changed

2 files changed

+121
-15
lines changed

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 117 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@
2020
import org.elasticsearch.index.mapper.FieldTypeTestCase;
2121
import org.elasticsearch.index.mapper.MappedFieldType;
2222
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
23+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
2324
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
2425
import org.elasticsearch.search.DocValueFormat;
2526
import org.elasticsearch.search.vectors.DenseVectorQuery;
27+
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
28+
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
29+
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
2630
import org.elasticsearch.search.vectors.VectorData;
2731

2832
import java.io.IOException;
@@ -31,8 +35,12 @@
3135
import java.util.Set;
3236

3337
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
38+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
39+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT;
3440
import static org.hamcrest.Matchers.containsString;
41+
import static org.hamcrest.Matchers.equalTo;
3542
import static org.hamcrest.Matchers.instanceOf;
43+
import static org.hamcrest.Matchers.is;
3644

3745
public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
3846
private final boolean indexed;
@@ -69,11 +77,27 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
6977
);
7078
}
7179

80+
private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() {
81+
return randomFrom(
82+
new DenseVectorFieldMapper.Int8HnswIndexOptions(
83+
randomIntBetween(1, 100),
84+
randomIntBetween(1, 10_000),
85+
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))
86+
),
87+
new DenseVectorFieldMapper.Int4HnswIndexOptions(
88+
randomIntBetween(1, 100),
89+
randomIntBetween(1, 10_000),
90+
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))
91+
),
92+
new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000))
93+
);
94+
}
95+
7296
private DenseVectorFieldType createFloatFieldType() {
7397
return new DenseVectorFieldType(
7498
"f",
7599
IndexVersion.current(),
76-
DenseVectorFieldMapper.ElementType.FLOAT,
100+
FLOAT,
77101
BBQ_MIN_DIMS,
78102
indexed,
79103
VectorSimilarity.COSINE,
@@ -86,7 +110,7 @@ private DenseVectorFieldType createByteFieldType() {
86110
return new DenseVectorFieldType(
87111
"f",
88112
IndexVersion.current(),
89-
DenseVectorFieldMapper.ElementType.BYTE,
113+
BYTE,
90114
5,
91115
true,
92116
VectorSimilarity.COSINE,
@@ -159,7 +183,7 @@ public void testCreateNestedKnnQuery() {
159183
DenseVectorFieldType field = new DenseVectorFieldType(
160184
"f",
161185
IndexVersion.current(),
162-
DenseVectorFieldMapper.ElementType.FLOAT,
186+
FLOAT,
163187
dims,
164188
true,
165189
VectorSimilarity.COSINE,
@@ -177,7 +201,7 @@ public void testCreateNestedKnnQuery() {
177201
DenseVectorFieldType field = new DenseVectorFieldType(
178202
"f",
179203
IndexVersion.current(),
180-
DenseVectorFieldMapper.ElementType.BYTE,
204+
BYTE,
181205
dims,
182206
true,
183207
VectorSimilarity.COSINE,
@@ -209,7 +233,7 @@ public void testExactKnnQuery() {
209233
DenseVectorFieldType field = new DenseVectorFieldType(
210234
"f",
211235
IndexVersion.current(),
212-
DenseVectorFieldMapper.ElementType.FLOAT,
236+
FLOAT,
213237
dims,
214238
true,
215239
VectorSimilarity.COSINE,
@@ -227,7 +251,7 @@ public void testExactKnnQuery() {
227251
DenseVectorFieldType field = new DenseVectorFieldType(
228252
"f",
229253
IndexVersion.current(),
230-
DenseVectorFieldMapper.ElementType.BYTE,
254+
BYTE,
231255
dims,
232256
true,
233257
VectorSimilarity.COSINE,
@@ -247,7 +271,7 @@ public void testFloatCreateKnnQuery() {
247271
DenseVectorFieldType unindexedField = new DenseVectorFieldType(
248272
"f",
249273
IndexVersion.current(),
250-
DenseVectorFieldMapper.ElementType.FLOAT,
274+
FLOAT,
251275
4,
252276
false,
253277
VectorSimilarity.COSINE,
@@ -271,7 +295,7 @@ public void testFloatCreateKnnQuery() {
271295
DenseVectorFieldType dotProductField = new DenseVectorFieldType(
272296
"f",
273297
IndexVersion.current(),
274-
DenseVectorFieldMapper.ElementType.FLOAT,
298+
FLOAT,
275299
BBQ_MIN_DIMS,
276300
true,
277301
VectorSimilarity.DOT_PRODUCT,
@@ -291,7 +315,7 @@ public void testFloatCreateKnnQuery() {
291315
DenseVectorFieldType cosineField = new DenseVectorFieldType(
292316
"f",
293317
IndexVersion.current(),
294-
DenseVectorFieldMapper.ElementType.FLOAT,
318+
FLOAT,
295319
BBQ_MIN_DIMS,
296320
true,
297321
VectorSimilarity.COSINE,
@@ -310,7 +334,7 @@ public void testCreateKnnQueryMaxDims() {
310334
DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType(
311335
"f",
312336
IndexVersion.current(),
313-
DenseVectorFieldMapper.ElementType.FLOAT,
337+
FLOAT,
314338
4096,
315339
true,
316340
VectorSimilarity.COSINE,
@@ -329,7 +353,7 @@ public void testCreateKnnQueryMaxDims() {
329353
DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType(
330354
"f",
331355
IndexVersion.current(),
332-
DenseVectorFieldMapper.ElementType.BYTE,
356+
BYTE,
333357
4096,
334358
true,
335359
VectorSimilarity.COSINE,
@@ -350,7 +374,7 @@ public void testByteCreateKnnQuery() {
350374
DenseVectorFieldType unindexedField = new DenseVectorFieldType(
351375
"f",
352376
IndexVersion.current(),
353-
DenseVectorFieldMapper.ElementType.BYTE,
377+
BYTE,
354378
3,
355379
false,
356380
VectorSimilarity.COSINE,
@@ -366,7 +390,7 @@ public void testByteCreateKnnQuery() {
366390
DenseVectorFieldType cosineField = new DenseVectorFieldType(
367391
"f",
368392
IndexVersion.current(),
369-
DenseVectorFieldMapper.ElementType.BYTE,
393+
BYTE,
370394
3,
371395
true,
372396
VectorSimilarity.COSINE,
@@ -385,4 +409,84 @@ public void testByteCreateKnnQuery() {
385409
);
386410
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
387411
}
412+
413+
public void testRescoreOversampleUsedWithoutQuantization() {
414+
ElementType elementType = randomFrom(BYTE, FLOAT);
415+
DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType(
416+
"f",
417+
IndexVersion.current(),
418+
elementType,
419+
3,
420+
true,
421+
VectorSimilarity.COSINE,
422+
randomIndexOptionsNonQuantized(),
423+
Collections.emptyMap()
424+
);
425+
426+
Query knnQuery = nonQuantizedField.createKnnQuery(
427+
new VectorData(null, new byte[]{1, 4, 10}),
428+
10,
429+
100,
430+
randomFloatBetween(1.0F, 10.0F, false),
431+
null,
432+
null,
433+
null
434+
);
435+
436+
if (elementType == BYTE) {
437+
ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
438+
assertThat(esKnnQuery.getK(), is(100));
439+
assertThat(esKnnQuery.kParam(), is(10));
440+
} else {
441+
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
442+
assertThat(esKnnQuery.getK(), is(100));
443+
assertThat(esKnnQuery.kParam(), is(10));
444+
}
445+
}
446+
447+
public void testRescoreOversampleModifiesKnnParams() {
448+
DenseVectorFieldType fieldType = new DenseVectorFieldType(
449+
"f",
450+
IndexVersion.current(),
451+
randomFrom(BYTE, FLOAT),
452+
3,
453+
true,
454+
VectorSimilarity.COSINE,
455+
randomIndexOptionsHnswQuantized(),
456+
Collections.emptyMap()
457+
);
458+
459+
// Total results is k, internal k is multiplied by oversample
460+
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 25, 200);
461+
// If numCands < k, update numCands to k
462+
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 25, 25);
463+
// Oversampling limit
464+
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000, 10000);
465+
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000, 10000);
466+
}
467+
468+
private static void checkRescoreQueryParameters(
469+
DenseVectorFieldType fieldType,
470+
int k,
471+
int candidates,
472+
float oversample,
473+
int expectedResults,
474+
int expectedK,
475+
int expectedCandidates
476+
) {
477+
Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10}), k, candidates, oversample, null, null, null);
478+
479+
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
480+
if (fieldType.getElementType() == BYTE) {
481+
ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) rescoreQuery.innerQuery();
482+
assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
483+
assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
484+
assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
485+
} else {
486+
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();
487+
assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
488+
assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
489+
assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
490+
}
491+
}
388492
}

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.elasticsearch.common.io.stream.StreamInput;
2424
import org.elasticsearch.index.mapper.MapperService;
2525
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
26+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource;
27+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
2628
import org.elasticsearch.index.query.InnerHitsRewriteContext;
2729
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
2830
import org.elasticsearch.index.query.QueryBuilder;
@@ -127,8 +129,8 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() {
127129
@Override
128130
protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
129131
if (queryBuilder.rescoreVectorBuilder() != null) {
130-
assertTrue(query instanceof org.apache.lucene.queries.function.FunctionScoreQuery);
131-
query = ((org.apache.lucene.queries.function.FunctionScoreQuery) query).getWrappedQuery();
132+
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
133+
query = rescoreQuery.innerQuery();
132134
}
133135
if (queryBuilder.getVectorSimilarity() != null) {
134136
assertTrue(query instanceof VectorSimilarityQuery);

0 commit comments

Comments
 (0)