@@ -68,35 +68,23 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
         return new CentroidQueryScorer() {
             int currentCentroid = -1;
-            private final float[] centroid = new float[fieldInfo.getVectorDimension()];
-            private final float[] centroidCorrectiveValues = new float[3];
-            private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
-            private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES;
+            long postingListOffset;
+            private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
+                + Short.BYTES);
 
             @Override
             public int size() {
                 return numCentroids;
             }
 
-            @Override
-            public float[] centroid(int centroidOrdinal) throws IOException {
-                readDataIfNecessary(centroidOrdinal);
-                return centroid;
-            }
-
             @Override
             public long postingListOffset(int centroidOrdinal) throws IOException {
-                readDataIfNecessary(centroidOrdinal);
-                return postingListOffset;
-            }
-
-            private void readDataIfNecessary(int centroidOrdinal) throws IOException {
                 if (centroidOrdinal != currentCentroid) {
-                    centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
-                    centroids.readFloats(centroid, 0, centroid.length);
+                    centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
                     postingListOffset = centroids.readLong();
                     currentCentroid = centroidOrdinal;
                 }
+                return postingListOffset;
             }
 
             public void bulkScore(NeighborQueue queue) throws IOException {
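
Net effect on the reader: postingListOffset(...) now assumes a centroids file where one long per centroid sits directly after the quantized-centroid block. A minimal sketch of that seek arithmetic, with hypothetical names (quantizedBlockBytes, offsetPosition) that do not appear in the PR:

    // Illustrative only: where the rewritten postingListOffset(...) expects to find
    // the offset for a given centroid ordinal. Names here are hypothetical.
    class CentroidOffsetSketch {
        // the quantized-centroid block: per centroid, one byte per dimension,
        // three corrective floats, and one short (matches quantizeCentroidsLength)
        static long quantizedBlockBytes(int numCentroids, int dim) {
            return (long) numCentroids * (dim + 3 * Float.BYTES + Short.BYTES);
        }

        // posting-list offsets are packed longs appended after that block
        static long offsetPosition(int numCentroids, int dim, int centroidOrdinal) {
            return quantizedBlockBytes(numCentroids, dim) + (long) Long.BYTES * centroidOrdinal;
        }

        public static void main(String[] args) {
            // e.g. 1_000 centroids of 768 dims, centroid 42
            System.out.println(offsetPosition(1_000, 768, 42)); // 782336
        }
    }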
@@ -193,7 +181,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         int vectors;
         boolean quantized = false;
         float centroidDp;
-        float[] centroid;
+        final float[] centroid;
         long slicePos;
         OptimizedScalarQuantizer.QuantizationResult queryCorrections;
         DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -217,7 +205,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
             this.entry = entry;
             this.fieldInfo = fieldInfo;
             this.needsScoring = needsScoring;
-
+            centroid = new float[fieldInfo.getVectorDimension()];
             scratch = new float[target.length];
             quantizationScratch = new int[target.length];
             final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
@@ -229,12 +217,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
         }
 
         @Override
-        public int resetPostingsScorer(long offset, float[] centroid) throws IOException {
+        public int resetPostingsScorer(long offset) throws IOException {
             quantized = false;
             indexInput.seek(offset);
-            vectors = indexInput.readVInt();
+            indexInput.readFloats(centroid, 0, centroid.length);
             centroidDp = Float.intBitsToFloat(indexInput.readInt());
-            this.centroid = centroid;
+            vectors = indexInput.readVInt();
             // read the doc ids
             docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
             docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
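
Since resetPostingsScorer(long) no longer receives the centroid, each posting list now has to begin with it. A hedged sketch of the header the visitor reads, mirroring the read order in the hunk above (class and method names are ours, and doc-id decoding is elided):

    import org.apache.lucene.store.IndexInput;

    // Per-posting-list layout implied by the new resetPostingsScorer (illustrative):
    //   float[dim] raw centroid      (used to quantize the query per posting list)
    //   int        centroidDp        (float bits of dot(centroid, centroid))
    //   vint       vector count
    //   ...        doc ids (DocIdsWriter), then the quantized vectors
    class PostingListHeaderSketch {
        static int readHeader(IndexInput in, long offset, float[] centroid) throws java.io.IOException {
            in.seek(offset);
            in.readFloats(centroid, 0, centroid.length);
            float centroidDp = Float.intBitsToFloat(in.readInt());
            int vectors = in.readVInt();
            // doc ids and quantized vectors follow at this position
            return vectors;
        }
    }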
@@ -92,19 +92,25 @@ LongValues buildAndWritePostingsLists(
             fieldInfo.getVectorDimension(),
             new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
+        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
-            // TODO align???
-            offsets.add(postingsOutput.getFilePointer());
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            buffer.asFloatBuffer().put(centroid);
+            // write raw centroid for quantizing the query vectors
+            postingsOutput.writeBytes(buffer.array(), buffer.array().length);
+            // write centroid dot product for quantizing the query vectors
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             int size = cluster.length;
+            // write docIds
             postingsOutput.writeVInt(size);
-            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            // write vectors
             bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
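
The // TODO align??? is addressed with Lucene's IndexOutput.alignFilePointer(Float.BYTES), which pads the file with zero bytes up to the requested alignment and returns the aligned pointer, so every recorded offset starts the raw-centroid floats float-aligned (handy when the postings file is memory mapped). The overspill-aware loop below gets the same treatment. The rounding itself is just (illustrative re-implementation, not the Lucene source):

    // Equivalent of the alignment that alignFilePointer performs on the offsets
    // recorded above, assuming a power-of-two alignment.
    class AlignSketch {
        static long alignOffset(long filePointer, int alignment) {
            // round up to the next multiple of alignment
            return (filePointer + alignment - 1) & -alignment;
        }

        public static void main(String[] args) {
            System.out.println(alignOffset(13, Float.BYTES)); // 16
            System.out.println(alignOffset(16, Float.BYTES)); // 16
        }
    }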

@@ -209,20 +215,26 @@ LongValues buildAndWritePostingsLists(
         );
         DocIdsWriter docIdsWriter = new DocIdsWriter();
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
             boolean[] isOverspill = isOverspillByCluster[c];
-            offsets.add(postingsOutput.getFilePointer());
+            offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
+            // write raw centroid for quantizing the query vectors
+            buffer.asFloatBuffer().put(centroid);
+            postingsOutput.writeBytes(buffer.array(), buffer.array().length);
+            // write centroid dot product for quantizing the query vectors
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            // write docIds
             int size = cluster.length;
-            // TODO align???
             postingsOutput.writeVInt(size);
-            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
             offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+            // write vectors
             bulkWriter.writeVectors(offHeapQuantizedVectors);
         }

@@ -298,13 +310,8 @@ void writeCentroids(
             }
             writeQuantizedValue(centroidOutput, quantized, result);
         }
-        final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         // write the centroid offsets at the end of the file
         for (int i = 0; i < centroidSupplier.size(); i++) {
-            float[] centroid = centroidSupplier.centroid(i);
-            buffer.asFloatBuffer().put(centroid);
-            // write the centroids
-            centroidOutput.writeBytes(buffer.array(), buffer.array().length);
             // write the offset of this posting list
             centroidOutput.writeLong(offsets.get(i));
         }
     }
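
On the writer side, the centroids file now ends with only the posting-list offsets; the raw centroids live in the postings file. A sketch of the resulting per-field size, assuming writeQuantizedValue keeps the usual OSQ record of quantized bytes plus three corrective floats plus a short component sum (that breakdown is inferred from the size expression, not spelled out in this diff):

    // Illustrative: total bytes the new writeCentroids emits for the centroid data
    // (any file header or footer is outside this diff and omitted here)
    class CentroidsFileSizeSketch {
        static long centroidsFileBytes(int numCentroids, int dim) {
            long quantized = (long) numCentroids * (dim      // 1 byte per dimension
                + 3 * Float.BYTES                            // corrective values
                + Short.BYTES);                              // quantized component sum
            long offsets = (long) numCentroids * Long.BYTES; // posting-list offsets
            return quantized + offsets;
        }
    }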
@@ -266,10 +266,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             int centroidOrdinal = centroidQueue.pop();
             // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
             // is enough?
-            expectedDocs += scorer.resetPostingsScorer(
-                centroidQueryScorer.postingListOffset(centroidOrdinal),
-                centroidQueryScorer.centroid(centroidOrdinal)
-            );
+            expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
             actualDocs += scorer.visit(knnCollector);
         }
         if (acceptDocs != null) {
@@ -278,10 +275,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
             float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
             while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
                 int centroidOrdinal = centroidQueue.pop();
-                scorer.resetPostingsScorer(
-                    centroidQueryScorer.postingListOffset(centroidOrdinal),
-                    centroidQueryScorer.centroid(centroidOrdinal)
-                );
+                scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
                 actualDocs += scorer.visit(knnCollector);
             }
         }
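
The filtered pass caps extra work at expectedScored = min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f), where expectedDocs accumulates the counts that resetPostingsScorer now returns. A worked example with made-up numbers:

    // Hypothetical numbers, only to make the cap concrete:
    class ExpectedScoredExample {
        public static void main(String[] args) {
            long filteredVectors = 10_000;        // vectors matching the filter
            float unfilteredRatioVisited = 0.05f; // fraction visited in the unfiltered pass
            long expectedDocs = 30_000;           // sum of resetPostingsScorer(...) results
            float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
            System.out.println(expectedScored);   // min(1000.0, 15000.0) -> 1000.0
        }
    }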
@@ -332,8 +326,6 @@ abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postin
     interface CentroidQueryScorer {
         int size();
 
-        float[] centroid(int centroidOrdinal) throws IOException;
-
         long postingListOffset(int centroidOrdinal) throws IOException;
 
         void bulkScore(NeighborQueue queue) throws IOException;
@@ -343,7 +335,7 @@ interface PostingVisitor {
         // TODO maybe we can not specifically pass the centroid...
 
         /** returns the number of documents in the posting list */
-        int resetPostingsScorer(long offset, float[] centroid) throws IOException;
+        int resetPostingsScorer(long offset) throws IOException;
 
         /** returns the number of scored documents */
         int visit(KnnCollector collector) throws IOException;