Skip to content

Commit 7d083b7

Browse files
committed
Fix test, add coverage for byte element types
1 parent 9c9773f commit 7d083b7

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
3535
public void profile(QueryProfiler queryProfiler) {
3636
queryProfiler.addVectorOpsCount(vectorOpsCount);
3737
}
38+
39+
public Integer kParam() {
40+
return kParam;
41+
}
3842
}

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
2424
import org.elasticsearch.search.DocValueFormat;
2525
import org.elasticsearch.search.vectors.DenseVectorQuery;
26+
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
2627
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
2728
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
2829
import org.elasticsearch.search.vectors.VectorData;
@@ -409,10 +410,11 @@ public void testByteCreateKnnQuery() {
409410
}
410411

411412
public void testRescoreOversampleUsedWithoutQuantization() {
413+
DenseVectorFieldMapper.ElementType elementType = randomFrom(FLOAT, BYTE);
412414
DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType(
413415
"f",
414416
IndexVersion.current(),
415-
randomFrom(FLOAT, BYTE),
417+
elementType,
416418
3,
417419
true,
418420
VectorSimilarity.COSINE,
@@ -430,9 +432,15 @@ public void testRescoreOversampleUsedWithoutQuantization() {
430432
null
431433
);
432434

433-
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
434-
assertThat(esKnnQuery.getK(), is(100));
435-
assertThat(esKnnQuery.kParam(), is(10));
435+
if (elementType == BYTE) {
436+
ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
437+
assertThat(esKnnQuery.getK(), is(100));
438+
assertThat(esKnnQuery.kParam(), is(10));
439+
} else {
440+
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
441+
assertThat(esKnnQuery.getK(), is(100));
442+
assertThat(esKnnQuery.kParam(), is(10));
443+
}
436444
}
437445

438446
public void testRescoreOversampleModifiesKnnParams() {

0 commit comments

Comments
 (0)