Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
return new CentroidQueryScorer() {
int currentCentroid = -1;
long postingListOffset;
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
private final float[] centroidCorrectiveValues = new float[3];
private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension();
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES;

@Override
public int size() {
Expand All @@ -79,12 +80,23 @@ public int size() {

@Override
public float[] centroid(int centroidOrdinal) throws IOException {
readDataIfNecessary(centroidOrdinal);
return centroid;
}

@Override
public long postingListOffset(int centroidOrdinal) throws IOException {
readDataIfNecessary(centroidOrdinal);
return postingListOffset;
}

private void readDataIfNecessary(int centroidOrdinal) throws IOException {
if (centroidOrdinal != currentCentroid) {
centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
centroids.readFloats(centroid, 0, centroid.length);
postingListOffset = centroids.readLong();
currentCentroid = centroidOrdinal;
}
return centroid;
}

public void bulkScore(NeighborQueue queue) throws IOException {
Expand Down Expand Up @@ -217,9 +229,9 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
}

@Override
public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException {
public int resetPostingsScorer(long offset, float[] centroid) throws IOException {
quantized = false;
indexInput.seek(entry.postingListOffsets()[centroidOrdinal]);
indexInput.seek(offset);
vectors = indexInput.readVInt();
centroidDp = Float.intBitsToFloat(indexInput.readInt());
this.centroid = centroid;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,71 +265,66 @@ CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentro
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
}

static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
throws IOException {
@Override
void writeCentroids(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
long[] offsets,
IndexOutput centroidOutput
) throws IOException {

final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
// TODO do we want to store these distances as well for future use?
// TODO: sort centroids by global centroid (was doing so previously here)
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
for (float[] centroid : centroids) {
for (int i = 0; i < centroidSupplier.size(); i++) {
float[] centroid = centroidSupplier.centroid(i);
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
centroidScratch,
quantizedScratch,
(byte) 4,
globalCentroid
);
for (int i = 0; i < quantizedScratch.length; i++) {
quantized[i] = (byte) quantizedScratch[i];
for (int j = 0; j < quantizedScratch.length; j++) {
quantized[j] = (byte) quantizedScratch[j];
}
writeQuantizedValue(centroidOutput, quantized, result);
}
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (float[] centroid : centroids) {
for (int i = 0; i < centroidSupplier.size(); i++) {
float[] centroid = centroidSupplier.centroid(i);
buffer.asFloatBuffer().put(centroid);
// write the centroids
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
// write the offset of this posting list
centroidOutput.writeLong(offsets[i]);
}
}

@Override
CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException {
// TODO: take advantage of prior generated clusters from mergeState in the future
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid);
}

/**
* Calculate the centroids for the given field and write them to the given centroid output.
* Calculate the centroids for the given field.
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param centroidOutput the centroid output
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
* @throws IOException if an I/O error occurs
*/
@Override
CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
float[] globalCentroid
) throws IOException {
CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
throws IOException {

long nanoTime = System.nanoTime();

// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
float[][] centroids = centroidAssignments.centroids();
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
// TODO: push this logic into vector util?
Expand All @@ -342,17 +337,15 @@ CentroidAssignments calculateAndWriteCentroids(
globalCentroid[j] /= centroids.length;
}

// write centroids
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);

if (logger.isDebugEnabled()) {
logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0);
logger.debug("final centroid count: {}", centroids.length);
}
return buildCentroidAssignments(kMeansResult);
return centroidAssignments;
}

static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
Expand All @@ -374,15 +367,13 @@ static class OffHeapCentroidSupplier implements CentroidSupplier {
private final int numCentroids;
private final int dimension;
private final float[] scratch;
private final long rawCentroidOffset;
private int currOrd = -1;

OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
this.centroidsInput = centroidsInput;
this.numCentroids = numCentroids;
this.dimension = info.getVectorDimension();
this.scratch = new float[dimension];
this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids;
}

@Override
Expand All @@ -395,7 +386,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
if (centroidOrdinal == currOrd) {
return scratch;
}
centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES);
centroidsInput.seek((long) centroidOrdinal * dimension * Float.BYTES);
centroidsInput.readFloats(scratch, 0, dimension);
this.currOrd = centroidOrdinal;
return scratch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,6 @@ private void readFields(ChecksumIndexInput meta) throws IOException {
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
final VectorEncoding vectorEncoding = readVectorEncoding(input);
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
final long centroidOffset = input.readLong();
final long centroidLength = input.readLong();
final int numPostingLists = input.readVInt();
final long[] postingListOffsets = new long[numPostingLists];
for (int i = 0; i < numPostingLists; i++) {
postingListOffsets[i] = input.readLong();
}
final float[] globalCentroid = new float[info.getVectorDimension()];
float globalCentroidDp = 0;
if (numPostingLists > 0) {
input.readFloats(globalCentroid, 0, globalCentroid.length);
globalCentroidDp = Float.intBitsToFloat(input.readInt());
}
if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
Expand All @@ -163,12 +150,21 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio
+ info.getVectorSimilarityFunction()
);
}
final int numCentroids = input.readInt();
final long centroidOffset = input.readLong();
final long centroidLength = input.readLong();
final float[] globalCentroid = new float[info.getVectorDimension()];
float globalCentroidDp = 0;
if (centroidLength > 0) {
input.readFloats(globalCentroid, 0, globalCentroid.length);
globalCentroidDp = Float.intBitsToFloat(input.readInt());
}
return new FieldEntry(
similarityFunction,
vectorEncoding,
numCentroids,
centroidOffset,
centroidLength,
postingListOffsets,
globalCentroid,
globalCentroidDp
);
Expand Down Expand Up @@ -242,7 +238,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
FieldEntry entry = fields.get(fieldInfo.number);
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
fieldInfo,
entry.postingListOffsets.length,
entry.numCentroids,
entry.centroidSlice(ivfCentroids),
target
);
Expand Down Expand Up @@ -270,7 +266,10 @@ public final void search(String field, float[] target, KnnCollector knnCollector
int centroidOrdinal = centroidQueue.pop();
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
// is enough?
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
expectedDocs += scorer.resetPostingsScorer(
centroidQueryScorer.postingListOffset(centroidOrdinal),
centroidQueryScorer.centroid(centroidOrdinal)
);
actualDocs += scorer.visit(knnCollector);
}
if (acceptDocs != null) {
Expand All @@ -279,7 +278,10 @@ public final void search(String field, float[] target, KnnCollector knnCollector
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
int centroidOrdinal = centroidQueue.pop();
scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
scorer.resetPostingsScorer(
centroidQueryScorer.postingListOffset(centroidOrdinal),
centroidQueryScorer.centroid(centroidOrdinal)
);
actualDocs += scorer.visit(knnCollector);
}
}
Expand Down Expand Up @@ -313,9 +315,9 @@ public void close() throws IOException {
protected record FieldEntry(
VectorSimilarityFunction similarityFunction,
VectorEncoding vectorEncoding,
int numCentroids,
long centroidOffset,
long centroidLength,
long[] postingListOffsets,
float[] globalCentroid,
float globalCentroidDp
) {
Expand All @@ -332,14 +334,16 @@ interface CentroidQueryScorer {

float[] centroid(int centroidOrdinal) throws IOException;

long postingListOffset(int centroidOrdinal) throws IOException;

void bulkScore(NeighborQueue queue) throws IOException;
}

interface PostingVisitor {
// TODO maybe we can not specifically pass the centroid...

/** returns the number of documents in the posting list */
int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException;
int resetPostingsScorer(long offset, float[] centroid) throws IOException;

/** returns the number of scored documents */
int visit(KnnCollector collector) throws IOException;
Expand Down
Loading