@@ -52,9 +52,16 @@ record CmdLineArgs(
int quantizeBits,
VectorEncoding vectorEncoding,
int dimensions,
boolean earlyTermination
boolean earlyTermination,
FILTER_KIND filterKind,
boolean sortIndex
) implements ToXContentObject {

public enum FILTER_KIND {
RANDOM,
RANGE
}

static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
static final ParseField QUERY_VECTORS_FIELD = new ParseField("query_vectors");
static final ParseField NUM_DOCS_FIELD = new ParseField("num_docs");
@@ -79,6 +86,8 @@ record CmdLineArgs(
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
static final ParseField SEED_FIELD = new ParseField("seed");
static final ParseField FILTER_KIND_FIELD = new ParseField("filter_kind");
static final ParseField SORT_INDEX_FIELD = new ParseField("sort_index");

static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
Builder builder = PARSER.apply(parser, null);
@@ -112,6 +121,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
PARSER.declareString(Builder::setFilterKind, FILTER_KIND_FIELD);
PARSER.declareBoolean(Builder::setSortIndex, SORT_INDEX_FIELD);
}

@Override
@@ -179,6 +190,8 @@ static class Builder {
private boolean earlyTermination;
private float filterSelectivity = 1f;
private long seed = 1751900822751L;
private FILTER_KIND filterKind = FILTER_KIND.RANDOM;
private boolean sortIndex = false;

public Builder setDocVectors(List<String> docVectors) {
if (docVectors == null || docVectors.isEmpty()) {
@@ -304,6 +317,16 @@ public Builder setSeed(long seed) {
return this;
}

public Builder setFilterKind(String filterKind) {
this.filterKind = FILTER_KIND.valueOf(filterKind.toUpperCase(Locale.ROOT));
return this;
}

public Builder setSortIndex(boolean sortIndex) {
this.sortIndex = sortIndex;
return this;
}

public CmdLineArgs build() {
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
@@ -337,7 +360,9 @@ public CmdLineArgs build() {
quantizeBits,
vectorEncoding,
dimensions,
earlyTermination
earlyTermination,
filterKind,
sortIndex
);
}
}
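The two new arguments ride on the existing Builder, so they compose like any other setting. A minimal usage sketch under that assumption (the path is illustrative, and other required settings such as query vectors are elided):

    // Hypothetical usage of the two new setters added above; not a complete configuration.
    CmdLineArgs args = new CmdLineArgs.Builder()
        .setDocVectors(List.of("docs.fvec"))   // illustrative path
        .setFilterKind("range")                // parsed case-insensitively into FILTER_KIND.RANGE
        .setSortIndex(true)                    // opt in to building a sorted index
        .build();

In the JSON arguments file these surface as the "filter_kind" and "sort_index" fields declared on the parser above.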
@@ -205,7 +205,8 @@ public static void main(String[] args) throws Exception {
cmdLineArgs.vectorEncoding(),
cmdLineArgs.dimensions(),
cmdLineArgs.vectorSpace(),
cmdLineArgs.numDocs()
cmdLineArgs.numDocs(),
cmdLineArgs.sortIndex()
);
if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) {
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
13 changes: 11 additions & 2 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
@@ -25,6 +25,7 @@
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.ConcurrentMergeScheduler;
import org.apache.lucene.index.DirectoryReader;
@@ -33,6 +34,8 @@
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.elasticsearch.common.io.Channels;
@@ -69,6 +72,7 @@ class KnnIndexer {
private final Codec codec;
private final int numDocs;
private final int numIndexThreads;
private final boolean sortIndex;

KnnIndexer(
List<Path> docsPath,
@@ -78,7 +82,8 @@ class KnnIndexer {
VectorEncoding vectorEncoding,
int dim,
VectorSimilarityFunction similarityFunction,
int numDocs
int numDocs,
boolean sortIndex
) {
this.docsPath = docsPath;
this.indexPath = indexPath;
@@ -88,6 +93,7 @@ class KnnIndexer {
this.dim = dim;
this.similarityFunction = similarityFunction;
this.numDocs = numDocs;
this.sortIndex = sortIndex;
}

void numSegments(KnnIndexTester.Results result) {
Expand All @@ -103,7 +109,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE
iwc.setCodec(codec);
iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
iwc.setUseCompoundFile(false);

if (sortIndex) {
iwc.setIndexSort(new Sort(new SortField(ID_FIELD + "_sort", SortField.Type.LONG, false)));
}
iwc.setMaxFullFlushMergeWaitMillis(0);

iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
Expand Down Expand Up @@ -292,6 +300,7 @@ private void _run() throws IOException {
logger.debug("Done indexing " + (id + 1) + " documents.");
}
doc.add(new StoredField(ID_FIELD, id));
doc.add(new NumericDocValuesField(ID_FIELD + "_sort", id));
iw.addDocument(doc);
}
}
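A note on the pairing above: Lucene can only sort an index on a field that every document stores as doc values, which is why each indexed document now carries a NumericDocValuesField alongside the stored ID. A minimal standalone sketch of the same pattern, with illustrative class and field names:

    import java.nio.file.Files;
    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.NumericDocValuesField;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.search.Sort;
    import org.apache.lucene.search.SortField;
    import org.apache.lucene.store.FSDirectory;

    // Sketch only: mirrors the diff's setIndexSort / NumericDocValuesField pairing.
    class SortedIndexSketch {
        public static void main(String[] args) throws Exception {
            IndexWriterConfig iwc = new IndexWriterConfig();
            // false = ascending sort on the long doc-values field
            iwc.setIndexSort(new Sort(new SortField("id_sort", SortField.Type.LONG, false)));
            try (var dir = FSDirectory.open(Files.createTempDirectory("sorted-index"));
                 IndexWriter writer = new IndexWriter(dir, iwc)) {
                for (long id = 0; id < 10; id++) {
                    Document doc = new Document();
                    doc.add(new NumericDocValuesField("id_sort", id)); // backs the index sort
                    writer.addDocument(doc);
                }
            }
        }
    }

Segments are then written with documents ordered by "id_sort", so ascending doc IDs track ascending vector IDs.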
@@ -115,6 +115,7 @@ class KnnSearcher {
private final float overSamplingFactor;
private final int searchThreads;
private final int numSearchers;
private final CmdLineArgs.FILTER_KIND filterKind;

KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
this.docPath = cmdLineArgs.docVectors();
@@ -137,10 +138,13 @@ class KnnSearcher {
this.numSearchers = cmdLineArgs.numSearchers();
this.randomSeed = cmdLineArgs.seed();
this.selectivity = cmdLineArgs.filterSelectivity();
this.filterKind = cmdLineArgs.filterKind();
}

void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
Query filterQuery = this.selectivity < 1f ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity) : null;
Query filterQuery = this.selectivity < 1f
? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity, filterKind)
: null;
TopDocs[] results = new TopDocs[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][];
long elapsed, totalCpuTimeMS, totalVisited = 0;
@@ -307,14 +311,27 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
}

private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity, CmdLineArgs.FILTER_KIND filterKind)
throws IOException {
FixedBitSet bitSet = new FixedBitSet(size);
for (int i = 0; i < size; i++) {
if (random.nextFloat() < selectivity) {
bitSet.set(i);
} else {
bitSet.clear(i);
}
switch (filterKind) {
case RANDOM:
for (int i = 0; i < size; i++) {
if (random.nextFloat() < selectivity) {
bitSet.set(i);
} else {
bitSet.clear(i);
}
}
break;
case RANGE:
int rangeBound = (int) (size * selectivity);
// set a random range of bits of length rangeBound
int start = random.nextInt(size - rangeBound);
for (int i = start; i < start + rangeBound; i++) {
bitSet.set(i);
}
break;
}

try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
@@ -346,6 +363,7 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQuery
topK,
similarityFunction.ordinal(),
selectivity,
filterKind,
randomSeed
),
36
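For intuition on the two filter kinds: with num_docs = 1,000,000 and filter_selectivity = 0.05, RANDOM keeps each document independently with probability 0.05 (about 50,000 scattered matches), while RANGE keeps exactly (int) (1,000,000 * 0.05) = 50,000 consecutive doc IDs starting at a random offset, presumably the interesting case to pair with the new sort_index option. A standalone sketch of both constructions, assuming selectivity < 1 as the caller already guarantees:

    import java.util.Random;
    import org.apache.lucene.util.FixedBitSet;

    // Sketch only: mirrors the bitset construction in generateRandomQuery above.
    class FilterKindSketch {
        static FixedBitSet buildFilter(Random random, int size, float selectivity, boolean range) {
            FixedBitSet bits = new FixedBitSet(size);
            if (range) {
                int rangeBound = (int) (size * selectivity);   // exact number of docs kept
                int start = random.nextInt(size - rangeBound); // random offset for the run
                bits.set(start, start + rangeBound);           // contiguous doc IDs [start, start + rangeBound)
            } else {
                for (int i = 0; i < size; i++) {
                    if (random.nextFloat() < selectivity) {
                        bits.set(i);                           // each doc kept independently
                    }
                }
            }
            return bits;
        }
    }

Note that filterKind is also folded into the exact-NN cache key in getOrCalculateExactNN, so ground truth computed under one filter kind is never reused for the other.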
@@ -16,6 +16,7 @@
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.GroupVIntUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
@@ -177,14 +178,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
final int[] correctionsSum = new int[BULK_SIZE];
final float[] correctionsAdd = new float[BULK_SIZE];

int[] docIdsScratch = new int[0];
int vectors;
int[] docIdsScratch = new int[0], spilledDocIdsScratch = new int[0];
int vectors, spilledVectors;
boolean quantized = false;
float centroidDp;
final float[] centroid;
long slicePos;
private long slicePos;
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
DocIdsWriter docIdsWriter = new DocIdsWriter();

final float[] scratch;
final int[] quantizationScratch;
@@ -222,18 +222,28 @@ public int resetPostingsScorer(long offset) throws IOException {
indexInput.seek(offset);
indexInput.readFloats(centroid, 0, centroid.length);
centroidDp = Float.intBitsToFloat(indexInput.readInt());
vectors = indexInput.readVInt();
vectors = indexInput.readInt();
spilledVectors = indexInput.readInt();
// read the doc ids
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
spilledDocIdsScratch = spilledVectors > spilledDocIdsScratch.length ? new int[spilledVectors] : spilledDocIdsScratch;
GroupVIntUtil.readGroupVInts(indexInput, docIdsScratch, vectors);
GroupVIntUtil.readGroupVInts(indexInput, spilledDocIdsScratch, spilledVectors);
// reconstitute from the deltas
for (int i = 1; i < vectors; i++) {
docIdsScratch[i] += docIdsScratch[i - 1];
}
for (int i = 1; i < spilledVectors; i++) {
spilledDocIdsScratch[i] += spilledDocIdsScratch[i - 1];
}
slicePos = indexInput.getFilePointer();
return vectors;
}

void scoreIndividually(int offset) throws IOException {
private void scoreIndividually(int offset, long slicePos, int[] docIds) throws IOException {
// score individually, first the quantized byte chunk
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[j + offset];
int doc = docIds[j + offset];
if (doc != -1) {
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
@@ -250,7 +260,7 @@ void scoreIndividually(int offset) throws IOException {
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
// Now apply corrections
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[offset + j];
int doc = docIds[offset + j];
if (doc != -1) {
scores[j] = osqVectorsScorer.score(
queryCorrections.lowerInterval(),
@@ -271,16 +281,25 @@

@Override
public int visit(KnnCollector knnCollector) throws IOException {
int scored = scoreDocs(knnCollector, vectors, 0, docIdsScratch);
if (spilledVectors > 0) {
scored += scoreDocs(knnCollector, spilledVectors, vectors * quantizedByteLength, spilledDocIdsScratch);
}
return scored;
}

private int scoreDocs(KnnCollector knnCollector, int count, long sliceOffset, int[] docIds) throws IOException {
// block processing
int scoredDocs = 0;
int limit = vectors - BULK_SIZE + 1;
int limit = count - BULK_SIZE + 1;
int i = 0;
long slicePos = this.slicePos + (int) sliceOffset;
for (; i < limit; i += BULK_SIZE) {
int docsToScore = BULK_SIZE;
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
int doc = docIds[i + j];
if (needsScoring.test(doc) == false) {
docIdsScratch[i + j] = -1;
docIds[i + j] = -1;
docsToScore--;
}
}
@@ -290,7 +309,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
if (docsToScore < BULK_SIZE / 2) {
scoreIndividually(i);
scoreIndividually(i, slicePos, docIds);
} else {
osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
@@ -304,16 +323,16 @@
);
}
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
int doc = docIds[i + j];
if (doc != -1) {
scoredDocs++;
knnCollector.collect(doc, scores[j]);
}
}
}
// process tail
for (; i < vectors; i++) {
int doc = docIdsScratch[i];
for (; i < count; i++) {
int doc = docIds[i];
if (needsScoring.test(doc)) {
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
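A posting list now carries two blocks of doc IDs, the cluster's own vectors followed by the spilled ones, each written as group varints over deltas rather than through the removed DocIdsWriter. resetPostingsScorer restores each block to absolute IDs with a prefix sum; visit() then scores the first block at slicePos and the spilled block at slicePos + vectors * quantizedByteLength, i.e. immediately past the first block's quantized vectors. A toy illustration of the decoding step, with invented values:

    // After GroupVIntUtil.readGroupVInts, the array holds gaps between consecutive doc IDs.
    int[] docIds = {5, 3, 1, 4};         // gaps, as read from the index
    for (int i = 1; i < docIds.length; i++) {
        docIds[i] += docIds[i - 1];      // prefix sum -> {5, 8, 9, 13}: ascending absolute IDs
    }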