Skip to content

Commit 17c7b3e

Browse files
iverasemridula-s109
authored and committed
Handle soar assignments when vector and centroid are very close (elastic#130206)
1 parent 20fe219 commit 17c7b3e

File tree

4 files changed

+63
-44
lines changed

4 files changed

+63
-44
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,23 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
import org.apache.lucene.internal.hppc.IntArrayList;
13-
1412
final class CentroidAssignments {
1513

1614
private final int numCentroids;
1715
private final float[][] cachedCentroids;
18-
private final IntArrayList[] assignmentsByCluster;
16+
private final int[][] assignmentsByCluster;
1917

20-
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) {
18+
private CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[][] assignmentsByCluster) {
2119
this.numCentroids = numCentroids;
2220
this.cachedCentroids = cachedCentroids;
2321
this.assignmentsByCluster = assignmentsByCluster;
2422
}
2523

26-
CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) {
24+
CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
2725
this(centroids.length, centroids, assignmentsByCluster);
2826
}
2927

30-
CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) {
28+
CentroidAssignments(int numCentroids, int[][] assignmentsByCluster) {
3129
this(numCentroids, null, assignmentsByCluster);
3230
}
3331

@@ -40,7 +38,7 @@ public float[][] cachedCentroids() {
4038
return cachedCentroids;
4139
}
4240

43-
public IntArrayList[] assignmentsByCluster() {
41+
public int[][] assignmentsByCluster() {
4442
return assignmentsByCluster;
4543
}
4644
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.apache.lucene.index.FloatVectorValues;
1515
import org.apache.lucene.index.MergeState;
1616
import org.apache.lucene.index.SegmentWriteState;
17-
import org.apache.lucene.internal.hppc.IntArrayList;
1817
import org.apache.lucene.store.IndexInput;
1918
import org.apache.lucene.store.IndexOutput;
2019
import org.apache.lucene.util.VectorUtil;
@@ -27,6 +26,7 @@
2726
import java.io.IOException;
2827
import java.nio.ByteBuffer;
2928
import java.nio.ByteOrder;
29+
import java.util.Arrays;
3030

3131
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
3232
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
@@ -53,7 +53,7 @@ long[] buildAndWritePostingsLists(
5353
CentroidSupplier centroidSupplier,
5454
FloatVectorValues floatVectorValues,
5555
IndexOutput postingsOutput,
56-
IntArrayList[] assignmentsByCluster
56+
int[][] assignmentsByCluster
5757
) throws IOException {
5858
// write the posting lists
5959
final long[] offsets = new long[centroidSupplier.size()];
@@ -65,16 +65,16 @@ long[] buildAndWritePostingsLists(
6565
float[] centroid = centroidSupplier.centroid(c);
6666
binarizedByteVectorValues.centroid = centroid;
6767
// TODO: add back in sorting vectors by distance to centroid
68-
IntArrayList cluster = assignmentsByCluster[c];
68+
int[] cluster = assignmentsByCluster[c];
6969
// TODO align???
7070
offsets[c] = postingsOutput.getFilePointer();
71-
int size = cluster.size();
71+
int size = cluster.length;
7272
postingsOutput.writeVInt(size);
7373
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
7474
// TODO we might want to consider putting the docIds in a separate file
7575
// to aid with only having to fetch vectors from slower storage when they are required
7676
// keeping them in the same file indicates we pull the entire file into cache
77-
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
77+
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
7878
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
7979
}
8080

@@ -85,23 +85,23 @@ long[] buildAndWritePostingsLists(
8585
return offsets;
8686
}
8787

88-
private static void printClusterQualityStatistics(IntArrayList[] clusters) {
88+
private static void printClusterQualityStatistics(int[][] clusters) {
8989
float min = Float.MAX_VALUE;
9090
float max = Float.MIN_VALUE;
9191
float mean = 0;
9292
float m2 = 0;
9393
// iteratively compute the variance & mean
9494
int count = 0;
95-
for (IntArrayList cluster : clusters) {
95+
for (int[] cluster : clusters) {
9696
count += 1;
9797
if (cluster == null) {
9898
continue;
9999
}
100-
float delta = cluster.size() - mean;
100+
float delta = cluster.length - mean;
101101
mean += delta / count;
102-
m2 += delta * (cluster.size() - mean);
103-
min = Math.min(min, cluster.size());
104-
max = Math.max(max, cluster.size());
102+
m2 += delta * (cluster.length - mean);
103+
min = Math.min(min, cluster.length);
104+
max = Math.max(max, cluster.length);
105105
}
106106
float variance = m2 / (clusters.length - 1);
107107
logger.debug(
@@ -115,16 +115,16 @@ private static void printClusterQualityStatistics(IntArrayList[] clusters) {
115115
);
116116
}
117117

118-
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
118+
private void writePostingList(int[] cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
119119
throws IOException {
120-
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
120+
int limit = cluster.length - ES91OSQVectorsScorer.BULK_SIZE + 1;
121121
int cidx = 0;
122122
OptimizedScalarQuantizer.QuantizationResult[] corrections =
123123
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
124124
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
125125
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
126126
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
127-
int ord = cluster.get(cidx + j);
127+
int ord = cluster[cidx + j];
128128
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
129129
// write vector
130130
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
@@ -147,8 +147,8 @@ private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput,
147147
}
148148
}
149149
// write tail
150-
for (; cidx < cluster.size(); cidx++) {
151-
int ord = cluster.get(cidx);
150+
for (; cidx < cluster.length; cidx++) {
151+
int ord = cluster[cidx];
152152
// write vector
153153
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
154154
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
@@ -261,23 +261,31 @@ CentroidAssignments calculateAndWriteCentroids(
261261
logger.debug("final centroid count: {}", centroids.length);
262262
}
263263

264-
IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
265-
for (int c = 0; c < centroids.length; c++) {
266-
IntArrayList cluster = new IntArrayList(vectorPerCluster);
267-
for (int j = 0; j < assignments.length; j++) {
268-
if (assignments[j] == c) {
269-
cluster.add(j);
270-
}
264+
int[] centroidVectorCount = new int[centroids.length];
265+
for (int i = 0; i < assignments.length; i++) {
266+
centroidVectorCount[assignments[i]]++;
267+
// if soar assignments are present, count them as well
268+
if (soarAssignments.length > i && soarAssignments[i] != -1) {
269+
centroidVectorCount[soarAssignments[i]]++;
271270
}
271+
}
272272

273-
for (int j = 0; j < soarAssignments.length; j++) {
274-
if (soarAssignments[j] == c) {
275-
cluster.add(j);
273+
int[][] assignmentsByCluster = new int[centroids.length][];
274+
for (int c = 0; c < centroids.length; c++) {
275+
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
276+
}
277+
Arrays.fill(centroidVectorCount, 0);
278+
279+
for (int i = 0; i < assignments.length; i++) {
280+
int c = assignments[i];
281+
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
282+
// if soar assignments are present, add them to the cluster as well
283+
if (soarAssignments.length > i) {
284+
int s = soarAssignments[i];
285+
if (s != -1) {
286+
assignmentsByCluster[s][centroidVectorCount[s]++] = i;
276287
}
277288
}
278-
279-
cluster.trimToSize();
280-
assignmentsByCluster[c] = cluster;
281289
}
282290

283291
if (cacheCentroids) {

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.apache.lucene.index.Sorter;
2424
import org.apache.lucene.index.VectorEncoding;
2525
import org.apache.lucene.index.VectorSimilarityFunction;
26-
import org.apache.lucene.internal.hppc.IntArrayList;
2726
import org.apache.lucene.search.DocIdSetIterator;
2827
import org.apache.lucene.store.IOContext;
2928
import org.apache.lucene.store.IndexInput;
@@ -140,7 +139,7 @@ abstract long[] buildAndWritePostingsLists(
140139
CentroidSupplier centroidSupplier,
141140
FloatVectorValues floatVectorValues,
142141
IndexOutput postingsOutput,
143-
IntArrayList[] assignmentsByCluster
142+
int[][] assignmentsByCluster
144143
) throws IOException;
145144

146145
abstract CentroidSupplier createCentroidSupplier(

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
*/
2727
class KMeansLocal {
2828

29+
// the minimum squared distance between a vector and its assigned centroid that is considered "far enough" to compute the soar distance.
30+
// For vectors that are closer than this distance to the centroid, we use the squared distance to find the
31+
// second closest centroid.
32+
private static final float SOAR_MIN_DISTANCE = 1e-16f;
33+
2934
final int sampleSize;
3035
final int maxIterations;
3136
final int clustersPerNeighborhood;
@@ -190,15 +195,18 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
190195

191196
int currAssignment = assignments[i];
192197
float[] currentCentroid = centroids[currAssignment];
193-
for (int j = 0; j < vectors.dimension(); j++) {
194-
float diff = vector[j] - currentCentroid[j];
195-
diffs[j] = diff;
196-
}
197198

198199
// TODO: cache these?
199200
// float vectorCentroidDist = assignmentDistances[i];
200201
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
201202

203+
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
204+
for (int j = 0; j < vectors.dimension(); j++) {
205+
float diff = vector[j] - currentCentroid[j];
206+
diffs[j] = diff;
207+
}
208+
}
209+
202210
int bestAssignment = -1;
203211
float minSoar = Float.MAX_VALUE;
204212
assert neighborhoods.get(currAssignment) != null;
@@ -207,13 +215,19 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
207215
continue;
208216
}
209217
float[] neighborCentroid = centroids[neighbor];
210-
float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
218+
final float soar;
219+
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
220+
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
221+
} else {
222+
// if the vector is very close to the centroid, we look for the second-nearest centroid
223+
soar = VectorUtil.squareDistance(vector, neighborCentroid);
224+
}
211225
if (soar < minSoar) {
212226
bestAssignment = neighbor;
213227
minSoar = soar;
214228
}
215229
}
216-
230+
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
217231
spilledAssignments[i] = bestAssignment;
218232
}
219233

0 commit comments

Comments
 (0)