Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ record CmdLineArgs(
int quantizeBits,
VectorEncoding vectorEncoding,
int dimensions,
boolean earlyTermination
boolean earlyTermination,
String mergePolicy
) implements ToXContentObject {

static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
Expand All @@ -79,6 +80,7 @@ record CmdLineArgs(
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
static final ParseField SEED_FIELD = new ParseField("seed");
static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy");

static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
Builder builder = PARSER.apply(parser, null);
Expand Down Expand Up @@ -112,6 +114,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD);
}

@Override
Expand Down Expand Up @@ -179,6 +182,7 @@ static class Builder {
private boolean earlyTermination;
private float filterSelectivity = 1f;
private long seed = 1751900822751L;
private String mergePolicy = null;

public Builder setDocVectors(List<String> docVectors) {
if (docVectors == null || docVectors.isEmpty()) {
Expand Down Expand Up @@ -304,6 +308,11 @@ public Builder setSeed(long seed) {
return this;
}

/**
 * Sets the merge-policy identifier to use when building the index.
 * Known values are parsed by the tester's main (see the "tmp" / "lbmp" / "no"
 * branches): tiered, log-byte-size, or no merging; null or empty keeps
 * Lucene's default merge policy.
 *
 * @param mergePolicy merge policy name, may be null
 * @return this builder for chaining
 */
public Builder setMergePolicy(String mergePolicy) {
this.mergePolicy = mergePolicy;
return this;
}

public CmdLineArgs build() {
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
Expand Down Expand Up @@ -337,7 +346,8 @@ public CmdLineArgs build() {
quantizeBits,
vectorEncoding,
dimensions,
earlyTermination
earlyTermination,
mergePolicy
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.index.LogByteSizeMergePolicy;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.TieredMergePolicy;
import org.elasticsearch.cli.ProcessInfo;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.logging.LogConfigurator;
Expand Down Expand Up @@ -196,6 +200,16 @@ public static void main(String[] args) throws Exception {
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
Codec codec = createCodec(cmdLineArgs);
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
MergePolicy mergePolicy = null;
if (cmdLineArgs.mergePolicy() != null && cmdLineArgs.mergePolicy().isEmpty() == false) {
if ("tmp".equalsIgnoreCase(cmdLineArgs.mergePolicy())) {
mergePolicy = new TieredMergePolicy();
} else if ("lbmp".equalsIgnoreCase(cmdLineArgs.mergePolicy())) {
mergePolicy = new LogByteSizeMergePolicy();
} else if ("no".equalsIgnoreCase(cmdLineArgs.mergePolicy())) {
mergePolicy = NoMergePolicy.INSTANCE;
}
}
if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) {
KnnIndexer knnIndexer = new KnnIndexer(
cmdLineArgs.docVectors(),
Expand All @@ -205,7 +219,8 @@ public static void main(String[] args) throws Exception {
cmdLineArgs.vectorEncoding(),
cmdLineArgs.dimensions(),
cmdLineArgs.vectorSpace(),
cmdLineArgs.numDocs()
cmdLineArgs.numDocs(),
mergePolicy
);
if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) {
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FSDirectory;
Expand Down Expand Up @@ -64,6 +65,7 @@ class KnnIndexer {
private final List<Path> docsPath;
private final Path indexPath;
private final VectorEncoding vectorEncoding;
private final MergePolicy mergePolicy;
private int dim;
private final VectorSimilarityFunction similarityFunction;
private final Codec codec;
Expand All @@ -78,7 +80,8 @@ class KnnIndexer {
VectorEncoding vectorEncoding,
int dim,
VectorSimilarityFunction similarityFunction,
int numDocs
int numDocs,
MergePolicy mergePolicy
) {
this.docsPath = docsPath;
this.indexPath = indexPath;
Expand All @@ -88,6 +91,7 @@ class KnnIndexer {
this.dim = dim;
this.similarityFunction = similarityFunction;
this.numDocs = numDocs;
this.mergePolicy = mergePolicy;
}

void numSegments(KnnIndexTester.Results result) {
Expand All @@ -104,6 +108,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE
iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
iwc.setUseCompoundFile(false);

if (mergePolicy != null) {
iwc.setMergePolicy(mergePolicy);
}
iwc.setMaxFullFlushMergeWaitMillis(0);

iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.function.IntPredicate;

Expand All @@ -48,7 +49,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
}

@Override
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery, int nProbe)
throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
Expand All @@ -68,6 +69,8 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
return new CentroidQueryScorer() {
int currentCentroid = -1;
long postingListOffset;
float diff = Float.NaN;

private final float[] centroidCorrectiveValues = new float[3];
private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
+ Short.BYTES);
Expand All @@ -90,11 +93,30 @@ public long postingListOffset(int centroidOrdinal) throws IOException {
/**
 * Scores every centroid against the query and pushes (ordinal, score) pairs
 * into {@code queue}. When there are more centroids than {@code nProbe}, it
 * also records the relative gap between the best centroid score and the score
 * at the nProbe cut-off into {@code diff} (exposed via scoreRatioAtNprobe()),
 * which the caller uses to decide whether to widen the probe count.
 */
public void bulkScore(NeighborQueue queue) throws IOException {
    // TODO: bulk score centroids like we do with posting lists
    centroids.seek(0L);
    // Scratch buffer for all centroid scores, needed only when a cut-off
    // at nProbe exists.
    float[] centroidsScratch = null;
    if (numCentroids > nProbe) {
        centroidsScratch = new float[numCentroids];
    }
    for (int i = 0; i < numCentroids; i++) {
        float score = score();
        queue.add(i, score);
        if (numCentroids > nProbe) {
            centroidsScratch[i] = score;
        }
    }
    if (numCentroids > nProbe) {
        Arrays.sort(centroidsScratch);
        // Arrays.sort is ascending, so the best (highest) score sits at the
        // END of the array and the nProbe-th best at numCentroids - nProbe.
        // The previous indices ([nProbe - 1] and [0]) measured the spread of
        // the WORST nProbe scores instead of the gap at the probe cut-off.
        // NOTE(review): assumes higher score == better (Lucene similarity
        // convention) — confirm against score(); if lower-is-better the
        // original indexing was intentional.
        float topScore = centroidsScratch[numCentroids - 1];
        float nprobeScore = centroidsScratch[numCentroids - nProbe];
        // If topScore is 0 this yields NaN/Infinity; the caller's
        // "< threshold" comparison is then false, so behavior stays benign.
        diff = (topScore - nprobeScore) / topScore;
    }
}

@Override
public float scoreRatioAtNprobe() {
// Relative score spread computed by the last bulkScore() call; remains
// Float.NaN until bulkScore() runs, or when numCentroids <= nProbe (the
// ratio is never set in that case). NaN compares false against the
// caller's threshold, so "no data" safely means "don't widen nProbe".
return diff;
}

private float score() throws IOException {
final float qcDist = scorer.int4DotProduct(quantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
}
}

abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target, int nProbe)
throws IOException;

private static IndexInput openDataInput(
Expand Down Expand Up @@ -236,22 +236,34 @@ public final void search(String field, float[] target, KnnCollector knnCollector
}

FieldEntry entry = fields.get(fieldInfo.number);
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
fieldInfo,
entry.numCentroids,
entry.centroidSlice(ivfCentroids),
target
);
int numCentroids = entry.numCentroids;
if (nProbe == DYNAMIC_NPROBE) {
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
// scaling by the number of centroids vs. the nearest neighbors requested
// not perfect, but a comparative heuristic.
// we might want to utilize the total vector count as well, but this is a good start
nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
nProbe = (int) Math.round(Math.log10(numCentroids) * Math.sqrt(knnCollector.k()));
// clip to be between 1 and the number of centroids
nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
nProbe = Math.max(Math.min(nProbe, numCentroids), 1);
}

CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
fieldInfo,
numCentroids,
entry.centroidSlice(ivfCentroids),
target,
nProbe
);

final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);

if (centroidQueue.size() > 2 && numCentroids > nProbe) {
// If the difference is small, increase nprobe to search more centroids
if (centroidQueryScorer.scoreRatioAtNprobe() < 0.001f) {
nProbe = (int) Math.min(nProbe * 2.0, numCentroids);
}
}

PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
int centroidsVisited = 0;
long expectedDocs = 0;
Expand Down Expand Up @@ -329,6 +341,9 @@ interface CentroidQueryScorer {
long postingListOffset(int centroidOrdinal) throws IOException;

void bulkScore(NeighborQueue queue) throws IOException;

float scoreRatioAtNprobe();

}

interface PostingVisitor {
Expand Down