Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ static class OnHeapQuantizedVectors implements QuantizedVectorValues {
private final OptimizedScalarQuantizer quantizer;
private final byte[] quantizedVector;
private final int[] quantizedVectorScratch;
private final float[] floatVectorScratch;
private OptimizedScalarQuantizer.QuantizationResult corrections;
private float[] currentCentroid;
private IntToIntFunction ordTransformer = null;
Expand All @@ -430,6 +431,7 @@ static class OnHeapQuantizedVectors implements QuantizedVectorValues {
this.vectorValues = vectorValues;
this.quantizer = quantizer;
this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
this.floatVectorScratch = new float[dimension];
this.quantizedVectorScratch = new int[dimension];
this.corrections = null;
}
Expand All @@ -454,7 +456,10 @@ public byte[] next() throws IOException {
currOrd++;
int ord = ordTransformer.apply(currOrd);
float[] vector = vectorValues.vectorValue(ord);
corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
// Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice
// due to overspill, so we copy the vector to a scratch array
System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length);
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 1, currentCentroid);
BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
return quantizedVector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,20 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
return new KMeansIntermediate();
}

// if we have a small number of vectors pick one and output that as the centroid
// if we have a small number of vectors calculate the centroid directly
if (vectors.size() <= targetSize) {
float[] centroid = new float[dimension];
System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension);
// sum the vectors
for (int i = 0; i < vectors.size(); i++) {
float[] vector = vectors.vectorValue(i);
for (int j = 0; j < dimension; j++) {
centroid[j] += vector[j];
}
}
// average the vectors
for (int j = 0; j < dimension; j++) {
centroid[j] /= vectors.size();
}
return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
}

Expand Down