@@ -28,10 +28,6 @@
import java.nio.ByteOrder;
import java.util.Arrays;

import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;

/**
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
* partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ long[] buildAndWritePostingsLists(
// write the posting lists
final long[] offsets = new long[centroidSupplier.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();

DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
ES91OSQVectorsScorer.BULK_SIZE,
quantizer,
floatVectorValues,
postingsOutput
);
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
binarizedByteVectorValues.centroid = centroid;
// TODO: add back in sorting vectors by distance to centroid
int[] cluster = assignmentsByCluster[c];
// TODO align???
@@ -75,7 +74,7 @@ long[] buildAndWritePostingsLists(
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
}

if (logger.isDebugEnabled()) {
@@ -115,54 +114,6 @@ private static void printClusterQualityStatistics(int[][] clusters) {
);
}

private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
throws IOException {
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
int cidx = 0;
OptimizedScalarQuantizer.QuantizationResult[] corrections =
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int ord = cluster[cidx + j];
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
// write vector
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
}
// write corrections
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int targetComponentSum = corrections[j].quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
postingsOutput.writeShort((short) targetComponentSum);
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
}
}
// write tail
for (; cidx < cluster.length; cidx++) {
int ord = cluster[cidx];
// write vector
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
writeQuantizedValue(postingsOutput, binaryValue, correction);
binarizedByteVectorValues.getCorrectiveTerms(ord);
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval()));
postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval()));
postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff;
postingsOutput.writeShort((short) correction.quantizedComponentSum());
Comment on lines -155 to -162 (Member Author):
We were actually writing all vectors TWICE in the tail of the postings lists. This (just a tad) negatively impacted recall. It is also why the skewed-index test started failing: the expected value there apparently relied on the duplicated writes :/
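For reference, the removed tail loop above appears to have needed only the writeQuantizedValue call, since that helper already writes the packed vector together with its corrective terms; the extra writeBytes/writeInt/writeShort calls after it emitted every tail vector a second time. A minimal sketch of the single-write tail, using the names from the removed code:

for (; cidx < cluster.length; cidx++) {
    int ord = cluster[cidx];
    byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
    OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
    // writeQuantizedValue emits the packed bits plus the correction terms in one pass
    writeQuantizedValue(postingsOutput, binaryValue, correction);
}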

}
}

@Override
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
@@ -295,47 +246,6 @@ CentroidAssignments calculateAndWriteCentroids(
}
}

// TODO unify with OSQ format
static class BinarizedFloatVectorValues {
private OptimizedScalarQuantizer.QuantizationResult corrections;
private final byte[] binarized;
private final byte[] initQuantized;
private float[] centroid;
private final FloatVectorValues values;
private final OptimizedScalarQuantizer quantizer;

private int lastOrd = -1;

BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
this.values = delegate;
this.quantizer = quantizer;
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
this.initQuantized = new byte[delegate.dimension()];
}

public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
if (ord != lastOrd) {
throw new IllegalStateException(
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
);
}
return corrections;
}

public byte[] vectorValue(int ord) throws IOException {
if (ord != lastOrd) {
binarize(ord);
lastOrd = ord;
}
return binarized;
}

private void binarize(int ord) throws IOException {
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
packAsBinary(initQuantized, binarized);
}
}

static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
throws IOException {
indexOutput.writeBytes(binaryValue, binaryValue.length);
@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.IntToIntFunction;

import java.io.IOException;

import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;

/**
* Base class for bulk writers that write vectors to disk using the BBQ encoding.
* This class provides the structure for writing vectors in bulk, with specific
 * implementations for different bit-size strategies.
*/
public abstract class DiskBBQBulkWriter {
protected final int bulkSize;
protected final OptimizedScalarQuantizer quantizer;
protected final IndexOutput out;
protected final FloatVectorValues fvv;

protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
this.bulkSize = bulkSize;
this.quantizer = quantizer;
this.out = out;
this.fvv = fvv;
}

public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;

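// Bulk corrections are laid out field-by-field: all lower intervals, then all upper
// intervals, then the quantized component sums, then the additional corrections.
// The single-vector path below writes one record at a time instead.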
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
}
}

private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException {
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
int targetComponentSum = correction.quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
out.writeShort((short) targetComponentSum);
}

public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
private final byte[] binarized;
private final byte[] initQuantized;
private final OptimizedScalarQuantizer.QuantizationResult[] corrections;

public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
super(bulkSize, quantizer, fvv, out);
this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
this.initQuantized = new byte[fvv.dimension()];
this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
}

@Override
public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
int limit = count - bulkSize + 1;
int i = 0;
for (; i < limit; i += bulkSize) {
for (int j = 0; j < bulkSize; j++) {
int ord = ords.apply(i + j);
float[] fv = fvv.vectorValue(ord);
corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
packAsBinary(initQuantized, binarized);
out.writeBytes(binarized, binarized.length);
}
writeCorrections(corrections, out);
}
// write tail
for (; i < count; ++i) {
int ord = ords.apply(i);
float[] fv = fvv.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
packAsBinary(initQuantized, binarized);
out.writeBytes(binarized, binarized.length);
writeCorrection(correction, out);
}
}
}
}
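For orientation, a minimal sketch of how the new writer is driven, mirroring the buildAndWritePostingsLists change above (quantizer, floatVectorValues, postingsOutput, centroidSupplier, and assignmentsByCluster are the caller's existing state):

DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
    ES91OSQVectorsScorer.BULK_SIZE,
    quantizer,
    floatVectorValues,
    postingsOutput
);
for (int c = 0; c < centroidSupplier.size(); c++) {
    int[] cluster = assignmentsByCluster[c];
    // quantizes each vector against its cluster centroid, packs it to single-bit,
    // and writes the packed bits and correction terms in BULK_SIZE batches plus a tail
    bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroidSupplier.centroid(c));
}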
@@ -300,7 +300,7 @@ public void testSkewedIndex() throws IOException {
);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "10", results.scoreDocs[0].doc);
assertIdMatches(reader, "8", results.scoreDocs[7].doc);
assertIdMatches(reader, "6", results.scoreDocs[7].doc);
}
}
}