@@ -9,10 +9,11 @@

package org.elasticsearch.index.codec.vectors;

record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {

CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
this(centroids.length, centroids, assignmentsByCluster);
assert centroids.length == assignmentsByCluster.length;
CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
this(centroids.length, centroids, assignments, overspillAssignments);
assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
: "overspillAssignments must be empty or the same length as assignments";
}
}
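
For context, a minimal sketch (made-up values, not from the PR) of the two shapes the reshaped record now accepts: one flat centroid assignment per vector ordinal, plus a parallel overspill array where -1 marks "no SOAR assignment", or an empty array when SOAR is disabled:

// Sketch with hypothetical values; the -1 sentinel and the empty-array case
// follow the assertion above and the call sites below.
float[][] centroids = { { 0.1f, 0.2f }, { 0.9f, 0.8f } };
int[] assignments = { 0, 1, 0 };   // vector ordinal -> centroid ordinal
int[] overspill = { 1, -1, -1 };   // ordinal 0 also spills into centroid 1
CentroidAssignments withSoar = new CentroidAssignments(centroids, assignments, overspill);
CentroidAssignments withoutSoar = new CentroidAssignments(centroids, assignments, new int[0]);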
@@ -14,9 +14,11 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
@@ -49,32 +51,58 @@ long[] buildAndWritePostingsLists(
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
int[][] assignmentsByCluster
int[] assignments,
int[] overspillAssignments
) throws IOException {
int[] centroidVectorCount = new int[centroidSupplier.size()];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
centroidVectorCount[overspillAssignments[i]]++;
}
}

int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
for (int c = 0; c < centroidSupplier.size(); c++) {
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
}
Arrays.fill(centroidVectorCount, 0);

for (int i = 0; i < assignments.length; i++) {
int c = assignments[i];
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
// if soar assignments are present, add them to the cluster as well
if (overspillAssignments.length > i) {
int s = overspillAssignments[i];
if (s != -1) {
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
}
}
}
// write the posting lists
final long[] offsets = new long[centroidSupplier.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
ES91OSQVectorsScorer.BULK_SIZE,
quantizer,
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
floatVectorValues,
postingsOutput
fieldInfo.getVectorDimension(),
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
);
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
// TODO: add back in sorting vectors by distance to centroid
int[] cluster = assignmentsByCluster[c];
// TODO align???
offsets[c] = postingsOutput.getFilePointer();
int size = cluster.length;
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
// TODO we might want to consider putting the docIds in a separate file
// so that we only need to fetch vectors from slower storage when they are required;
// keeping them in the same file means we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
bulkWriter.writeVectors(onHeapQuantizedVectors);
}

if (logger.isDebugEnabled()) {
@@ -84,6 +112,124 @@ long[] buildAndWritePostingsLists(
return offsets;
}

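Aside: the grouping at the top of the flush path above is a two-pass counting sort over the flat assignment arrays. As a self-contained illustration, the same pattern extracted into a hypothetical helper (not part of the PR):

// Pass 1 counts per-cluster sizes (overspill included); pass 2 fills the buckets,
// reusing the count array as a per-cluster write cursor.
static int[][] groupByCluster(int numClusters, int[] assignments, int[] overspill) {
    int[] counts = new int[numClusters];
    for (int i = 0; i < assignments.length; i++) {
        counts[assignments[i]]++;
        if (overspill.length > i && overspill[i] != -1) {
            counts[overspill[i]]++;
        }
    }
    int[][] byCluster = new int[numClusters][];
    for (int c = 0; c < numClusters; c++) {
        byCluster[c] = new int[counts[c]];
    }
    Arrays.fill(counts, 0); // reset so counts can serve as write cursors
    for (int i = 0; i < assignments.length; i++) {
        byCluster[assignments[i]][counts[assignments[i]]++] = i;
        if (overspill.length > i && overspill[i] != -1) {
            byCluster[overspill[i]][counts[overspill[i]]++] = i;
        }
    }
    return byCluster;
}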
@Override
long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState,
int[] assignments,
int[] overspillAssignments
) throws IOException {
// first, quantize all the vectors into a temporary file
String quantizedVectorsTempName = null;
IndexOutput quantizedVectorsTemp = null;
boolean success = false;
try {
quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
quantizedVectorsTempName = quantizedVectorsTemp.getName();
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
int[] quantized = new int[fieldInfo.getVectorDimension()];
byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
for (int i = 0; i < assignments.length; i++) {
int c = assignments[i];
float[] centroid = centroidSupplier.centroid(c);
float[] vector = floatVectorValues.vectorValue(i);
boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
// if overspilling, we quantize twice, and quantization mutates the in-memory
// representation of the vector, so make a copy to avoid mutating the original
if (overspill) {
System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
}

OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
BQVectorUtils.packAsBinary(quantized, binary);
writeQuantizedValue(quantizedVectorsTemp, binary, result);
if (overspill) {
int s = overspillAssignments[i];
// write the overspill vector as well
result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
BQVectorUtils.packAsBinary(quantized, binary);
writeQuantizedValue(quantizedVectorsTemp, binary, result);
} else {
// write a zero vector for the overspill
Arrays.fill(binary, (byte) 0);
OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
}
}
// close the temporary file so we can read it later
quantizedVectorsTemp.close();
success = true;
} finally {
if (success == false && quantizedVectorsTemp != null) {
mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
}
}
int[] centroidVectorCount = new int[centroidSupplier.size()];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
centroidVectorCount[overspillAssignments[i]]++;
}
}

int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
for (int c = 0; c < centroidSupplier.size(); c++) {
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
}
Arrays.fill(centroidVectorCount, 0);

for (int i = 0; i < assignments.length; i++) {
int c = assignments[i];
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
// if soar assignments are present, add them to the cluster as well
if (overspillAssignments.length > i) {
int s = overspillAssignments[i];
if (s != -1) {
assignmentsByCluster[s][centroidVectorCount[s]] = i;
isOverspillByCluster[s][centroidVectorCount[s]++] = true;
}
}
}
// now we can read the quantized vectors from the temporary file
try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
final long[] offsets = new long[centroidSupplier.size()];
OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
quantizedVectorsInput,
fieldInfo.getVectorDimension()
);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
boolean[] isOverspill = isOverspillByCluster[c];
// TODO align???
offsets[c] = postingsOutput.getFilePointer();
int size = cluster.length;
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
// TODO we might want to consider putting the docIds in a separate file
// so that we only need to fetch vectors from slower storage when they are required;
// keeping them in the same file means we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
bulkWriter.writeVectors(offHeapQuantizedVectors);
}

if (logger.isDebugEnabled()) {
printClusterQualityStatistics(assignmentsByCluster);
}
return offsets;
}
}

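Note that the zero filler written in the merge path above keeps the temp file at exactly two fixed-size records per vector ordinal (primary assignment first, then the overspill slot), which is what lets readQuantizedVector further down seek by pure arithmetic. A sketch of the implied layout math, with an assumed dimension and ordinal for illustration:

// Each temp-file record is: packed 1-bit vector + 3 float corrections + short bit sum.
// BQVectorUtils.discretize(dim, 64) pads the dimension up to a multiple of 64 bits.
int dim = 128;  // assumed dimension
int ord = 42;   // example vector ordinal
int packedBytes = BQVectorUtils.discretize(dim, 64) / 8;
int recordSize = packedBytes + 3 * Float.BYTES + Short.BYTES;
// matches the seek in OffHeapQuantizedVectors.readQuantizedVector:
long primaryOffset = (long) ord * (recordSize * 2L);
long overspillOffset = primaryOffset + recordSize;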
private static void printClusterQualityStatistics(int[][] clusters) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
@@ -210,33 +356,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
int[] centroidVectorCount = new int[centroids.length];
for (int i = 0; i < assignments.length; i++) {
centroidVectorCount[assignments[i]]++;
// if soar assignments are present, count them as well
if (soarAssignments.length > i && soarAssignments[i] != -1) {
centroidVectorCount[soarAssignments[i]]++;
}
}

int[][] assignmentsByCluster = new int[centroids.length][];
for (int c = 0; c < centroids.length; c++) {
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
}
Arrays.fill(centroidVectorCount, 0);

for (int i = 0; i < assignments.length; i++) {
int c = assignments[i];
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
// if soar assignments are present, add them to the cluster as well
if (soarAssignments.length > i) {
int s = soarAssignments[i];
if (s != -1) {
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
}
}
}
return new CentroidAssignments(centroids, assignmentsByCluster);
return new CentroidAssignments(centroids, assignments, soarAssignments);
}

static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +401,132 @@ public float[] centroid(int centroidOrdinal) throws IOException {
return scratch;
}
}

interface QuantizedVectorValues {
int count();

byte[] next() throws IOException;

OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
}

interface IntToBooleanFunction {
boolean apply(int ord);
}

static class OnHeapQuantizedVectors implements QuantizedVectorValues {
private final FloatVectorValues vectorValues;
private final OptimizedScalarQuantizer quantizer;
private final byte[] quantizedVector;
private final int[] quantizedVectorScratch;
private OptimizedScalarQuantizer.QuantizationResult corrections;
private float[] currentCentroid;
private IntToIntFunction ordTransformer = null;
private int currOrd = -1;
private int count;

OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
this.vectorValues = vectorValues;
this.quantizer = quantizer;
this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
this.quantizedVectorScratch = new int[dimension];
this.corrections = null;
}

private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
this.currentCentroid = centroid;
this.ordTransformer = ordTransformer;
this.currOrd = -1;
this.count = count;
}

@Override
public int count() {
return count;
}

@Override
public byte[] next() throws IOException {
if (currOrd >= count() - 1) {
throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
}
currOrd++;
int ord = ordTransformer.apply(currOrd);
float[] vector = vectorValues.vectorValue(ord);
corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
return quantizedVector;
}

@Override
public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
if (currOrd == -1) {
throw new IllegalStateException("No vector read yet, call next first");
}
return corrections;
}
}

static class OffHeapQuantizedVectors implements QuantizedVectorValues {
private final IndexInput quantizedVectorsInput;
private final byte[] binaryScratch;
private final float[] corrections = new float[3];

private final int vectorByteSize;
private short bitSum;
private int currOrd = -1;
private int count;
private IntToBooleanFunction isOverspill = null;
private IntToIntFunction ordTransformer = null;

OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
this.quantizedVectorsInput = quantizedVectorsInput;
this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
}

private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
this.count = count;
this.isOverspill = isOverspill;
this.ordTransformer = ordTransformer;
this.currOrd = -1;
}

@Override
public int count() {
return count;
}

@Override
public byte[] next() throws IOException {
if (currOrd >= count - 1) {
throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
}
currOrd++;
int ord = ordTransformer.apply(currOrd);
boolean isOverspill = this.isOverspill.apply(currOrd);
return getVector(ord, isOverspill);
}

@Override
public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
if (currOrd == -1) {
throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
}
return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
}

byte[] getVector(int ord, boolean isOverspill) throws IOException {
readQuantizedVector(ord, isOverspill);
return binaryScratch;
}

public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
quantizedVectorsInput.seek(offset);
quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
quantizedVectorsInput.readFloats(corrections, 0, 3);
bitSum = quantizedVectorsInput.readShort();
}
}
}
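
The consumer of this cursor-style interface (DiskBBQBulkWriter.writeVectors) is outside the diff. A hedged sketch of the protocol only — count() bounds the loop, next() advances the cursor and returns the packed vector, getCorrections() returns that same vector's corrections — the loop shape is assumed and the real bulk writer batches and lays out data differently:

// Protocol sketch only, not the bulk writer's actual implementation.
static void drain(QuantizedVectorValues vectors, IndexOutput out) throws IOException {
    for (int i = 0; i < vectors.count(); i++) {
        byte[] packed = vectors.next();  // advance cursor, get packed bits
        OptimizedScalarQuantizer.QuantizationResult corrections = vectors.getCorrections();
        writeQuantizedValue(out, packed, corrections);  // reuse this class's helper for illustration
    }
}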