Skip to content

Commit 6f85cd9

Browse files
committed
adding tests
1 parent ed8940a commit 6f85cd9

File tree

6 files changed

+95
-10
lines changed

6 files changed

+95
-10
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import org.apache.lucene.util.VectorUtil;
2121
import org.apache.lucene.util.hnsw.NeighborQueue;
2222
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
23-
import org.elasticsearch.logging.LogManager;
24-
import org.elasticsearch.logging.Logger;
2523
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2624
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
2725
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -42,8 +40,6 @@
4240
*/
4341
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
4442

45-
static Logger logger = LogManager.getLogger(DefaultIVFVectorsReader.class);
46-
4743
// The percentage of centroids that are scored to keep recall
4844
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2;
4945

server/src/main/java/org/elasticsearch/search/vectors/DiversifiedIVFKnnCollectorManager.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111

1212
import org.apache.lucene.index.LeafReaderContext;
1313
import org.apache.lucene.search.IndexSearcher;
14-
import org.apache.lucene.search.KnnCollector;
1514
import org.apache.lucene.search.join.BitSetProducer;
16-
import org.apache.lucene.search.knn.KnnCollectorManager;
1715
import org.apache.lucene.search.knn.KnnSearchStrategy;
1816
import org.apache.lucene.util.BitSet;
1917

@@ -30,7 +28,8 @@ public class DiversifiedIVFKnnCollectorManager extends AbstractIVFKnnVectorQuery
3028
}
3129

3230
@Override
33-
public AbstractMaxScoreKnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
31+
public AbstractMaxScoreKnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context)
32+
throws IOException {
3433
BitSet parentBitSet = parentsFilter.getBitSet(context);
3534
if (parentBitSet == null) {
3635
return null;

server/src/main/java/org/elasticsearch/search/vectors/DiversifyingNearestChildrenKnnCollector.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ public int numCollected() {
105105

106106
@Override
107107
public long getMinCompetitiveDocScore() {
108-
return heap.size() > 0 ? Math.max(NeighborQueue.encodeRaw(heap.topNode(), heap.topScore()), minCompetitiveDocScore) : minCompetitiveDocScore;
108+
return heap.size() > 0
109+
? Math.max(NeighborQueue.encodeRaw(heap.topNode(), heap.topScore()), minCompetitiveDocScore)
110+
: minCompetitiveDocScore;
109111
}
110112

111113
@Override

server/src/main/java/org/elasticsearch/search/vectors/MaxScoreTopKnnCollector.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import org.apache.lucene.search.knn.KnnSearchStrategy;
1616
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
1717

18-
public class MaxScoreTopKnnCollector extends AbstractMaxScoreKnnCollector {
18+
class MaxScoreTopKnnCollector extends AbstractMaxScoreKnnCollector {
1919

2020
private long minCompetitiveDocScore;
2121
private float minCompetitiveSimilarity;
2222
protected final NeighborQueue queue;
2323

24-
public MaxScoreTopKnnCollector(int k, long visitLimit, KnnSearchStrategy searchStrategy) {
24+
MaxScoreTopKnnCollector(int k, long visitLimit, KnnSearchStrategy searchStrategy) {
2525
super(k, visitLimit, searchStrategy);
2626
this.minCompetitiveDocScore = LEAST_COMPETITIVE;
2727
this.minCompetitiveSimilarity = Float.NEGATIVE_INFINITY;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
13+
import org.elasticsearch.test.ESTestCase;
14+
15+
import java.util.concurrent.atomic.LongAccumulator;
16+
17+
public class IVFKnnSearchStrategyTests extends ESTestCase {
18+
19+
public void testMaxScorePropagation() {
20+
LongAccumulator accumulator = new LongAccumulator(Long::max, AbstractMaxScoreKnnCollector.LEAST_COMPETITIVE);
21+
IVFKnnSearchStrategy strategy = new IVFKnnSearchStrategy(0.5f, accumulator);
22+
MaxScoreTopKnnCollector collector = new MaxScoreTopKnnCollector(2, 1000, strategy);
23+
strategy.setCollector(collector);
24+
25+
collector.collect(1, 0.9f);
26+
long competitiveScore = NeighborQueue.encodeRaw(1, 0.9f);
27+
28+
// accumulator should now be updated
29+
strategy.nextVectorsBlock();
30+
assertEquals(competitiveScore, accumulator.get());
31+
assertEquals(competitiveScore, collector.getMinCompetitiveDocScore());
32+
33+
// updated accumulator directly with more competitive score
34+
competitiveScore = NeighborQueue.encodeRaw(2, 1.5f);
35+
accumulator.accumulate(competitiveScore);
36+
assertEquals(competitiveScore, accumulator.get());
37+
strategy.nextVectorsBlock();
38+
assertEquals(competitiveScore, collector.getMinCompetitiveDocScore());
39+
assertEquals(competitiveScore, accumulator.get());
40+
}
41+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
13+
import org.elasticsearch.test.ESTestCase;
14+
15+
import java.util.concurrent.atomic.LongAccumulator;
16+
17+
public class MaxScoreTopKnnCollectorTests extends ESTestCase {
18+
19+
public void testMaxScorePropagation() {
20+
LongAccumulator accumulator = new LongAccumulator(Long::max, AbstractMaxScoreKnnCollector.LEAST_COMPETITIVE);
21+
MaxScoreTopKnnCollector collector = new MaxScoreTopKnnCollector(2, 1000, new IVFKnnSearchStrategy(0.5f, accumulator));
22+
long competitiveScore = NeighborQueue.encodeRaw(1, 0.9f);
23+
24+
collector.updateMinCompetitiveDocScore(competitiveScore);
25+
assertEquals(competitiveScore, collector.getMinCompetitiveDocScore());
26+
// haven't collected k results
27+
assertTrue(Float.NEGATIVE_INFINITY == collector.minCompetitiveSimilarity());
28+
collector.collect(2, 1.5f);
29+
assertTrue(Float.NEGATIVE_INFINITY == collector.minCompetitiveSimilarity());
30+
31+
// we always provide the min competitive that this collector collected
32+
assertEquals(NeighborQueue.encodeRaw(2, 1.5f), collector.getMinCompetitiveDocScore());
33+
collector.collect(3, 1.9f);
34+
35+
// min competitive for this collector is more than global
36+
assertEquals(NeighborQueue.encodeRaw(2, 1.5f), collector.getMinCompetitiveDocScore());
37+
38+
// we have collected k results, min competitive is the min value collected
39+
assertEquals(1.5f, collector.minCompetitiveSimilarity(), 0.0f);
40+
// we update the global min competitive doc score with a new value
41+
competitiveScore = NeighborQueue.encodeRaw(4, 4f);
42+
collector.updateMinCompetitiveDocScore(competitiveScore);
43+
assertEquals(competitiveScore, collector.getMinCompetitiveDocScore());
44+
assertEquals(4f, collector.minCompetitiveSimilarity(), 0.0f);
45+
}
46+
47+
}

0 commit comments

Comments
 (0)