Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;
Expand Down Expand Up @@ -61,14 +61,14 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
targetQueryCopy,
scratch,
(byte) 4,
(byte) 7,
fieldEntry.globalCentroid()
);
final byte[] quantized = new byte[targetQuery.length];
for (int i = 0; i < quantized.length; i++) {
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
centroids.seek(0L);
int numParents = centroids.readVInt();
if (numParents > 0) {
Expand All @@ -90,7 +90,7 @@ private static CentroidIterator getCentroidIteratorNoParent(
FieldInfo fieldInfo,
IndexInput centroids,
int numCentroids,
ES91Int4VectorsScorer scorer,
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp
Expand All @@ -105,7 +105,7 @@ private static CentroidIterator getCentroidIteratorNoParent(
queryParams,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction(),
new float[ES91Int4VectorsScorer.BULK_SIZE]
new float[ES92Int7VectorsScorer.BULK_SIZE]
);
long offset = centroids.getFilePointer();
return new CentroidIterator() {
Expand All @@ -128,7 +128,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
IndexInput centroids,
int numParents,
int numCentroids,
ES91Int4VectorsScorer scorer,
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp
Expand All @@ -140,7 +140,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
// score the parents
final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE];
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
score(
parentsQueue,
numParents,
Expand All @@ -152,7 +152,7 @@ private static CentroidIterator getCentroidIteratorWithParents(
fieldInfo.getVectorSimilarityFunction(),
scores
);
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
final long offset = centroids.getFilePointer();
final long childrenOffset = offset + (long) Long.BYTES * numParents;
// populate the children's queue by reading parents one by one
Expand Down Expand Up @@ -227,7 +227,7 @@ private static void populateOneChildrenGroup(
long childrenOffset,
long centroidQuantizeSize,
FieldInfo fieldInfo,
ES91Int4VectorsScorer scorer,
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp,
Expand All @@ -254,16 +254,16 @@ private static void score(
NeighborQueue neighborQueue,
int size,
int scoresOffset,
ES91Int4VectorsScorer scorer,
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
float centroidDp,
VectorSimilarityFunction similarityFunction,
float[] scores
) throws IOException {
int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1;
int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) {
for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) {
scorer.scoreBulk(
quantizeQuery,
queryCorrections.lowerInterval(),
Expand All @@ -274,7 +274,7 @@ private static void score(
centroidDp,
scores
);
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
neighborQueue.add(scoresOffset + i + j, scores[j]);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -315,8 +315,8 @@ private void writeCentroidsWithParents(
LongValues offsets,
IndexOutput centroidOutput
) throws IOException {
DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter(
ES91Int4VectorsScorer.BULK_SIZE,
DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
ES92Int7VectorsScorer.BULK_SIZE,
centroidOutput
);
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
Expand Down Expand Up @@ -365,8 +365,8 @@ private void writeCentroidsWithoutParents(
IndexOutput centroidOutput
) throws IOException {
centroidOutput.writeVInt(0);
DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter(
ES91Int4VectorsScorer.BULK_SIZE,
DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
ES92Int7VectorsScorer.BULK_SIZE,
centroidOutput
);
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
Expand Down Expand Up @@ -571,7 +571,7 @@ public byte[] next() throws IOException {
// Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice
// due to overspill, so we copy the vector to a scratch array
System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length);
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 4, centroid);
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid);
for (int i = 0; i < quantizedVectorScratch.length; i++) {
quantizedVector[i] = (byte) quantizedVectorScratch[i];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,6 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {

abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;

private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
}
}

private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}

static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;

Expand All @@ -73,22 +47,48 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx
corrections[j] = qvv.getCorrections();
out.writeBytes(qv, qv.length);
}
writeCorrections(corrections, out);
writeCorrections(corrections);
}
// write tail
for (; i < qvv.count(); ++i) {
byte[] qv = qvv.next();
OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
out.writeBytes(qv, qv.length);
writeCorrection(correction, out);
writeCorrection(correction);
}
}

private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException {
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
}
}

private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
}

static class FourBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;

FourBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
super(bulkSize, out);
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
}
Expand All @@ -103,15 +103,37 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx
corrections[j] = qvv.getCorrections();
out.writeBytes(qv, qv.length);
}
writeCorrections(corrections, out);
writeCorrections(corrections);
}
// write tail
for (; i < qvv.count(); ++i) {
byte[] qv = qvv.next();
OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
out.writeBytes(qv, qv.length);
writeCorrection(correction, out);
writeCorrection(correction);
}
}

private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException {
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(correction.quantizedComponentSum());
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
}
}

private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
out.writeInt(correction.quantizedComponentSum());
}
}
}