diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 90b7aa89..3acecde5 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -84,6 +84,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.FilterDocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -1426,9 +1427,16 @@ long totalVectorCount() { private static class BitSetQuery extends Query { private final BitSet[] segmentDocs; + private final int[] cardinalities; + private final int hash; BitSetQuery(BitSet[] segmentDocs) { this.segmentDocs = segmentDocs; + this.cardinalities = new int[segmentDocs.length]; + for (int i = 0; i < segmentDocs.length; i++) { + cardinalities[i] = segmentDocs[i].cardinality(); + } + this.hash = Arrays.hashCode(segmentDocs); } @Override @@ -1436,8 +1444,12 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new ConstantScoreWeight(this, boost) { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { var bitSet = segmentDocs[context.ord]; - var cardinality = bitSet.cardinality(); - var scorer = new ConstantScoreScorer(score(), scoreMode, new BitSetIterator(bitSet, cardinality)); + var cardinality = cardinalities[context.ord]; + var scorer = new ConstantScoreScorer( + score(), + scoreMode, + // wrap it to simulate a more realistic query that must iterate its docs + new FilterDocIdSetIterator(new BitSetIterator(bitSet, cardinality))); return new ScorerSupplier() { @Override public Scorer get(long leadCost) throws IOException { @@ -1474,7 +1486,7 @@ public boolean equals(Object other) { @Override public int hashCode() { - return 31 * classHash() + Arrays.hashCode(segmentDocs); + return 31 * classHash() + hash; } } }