Skip to content

Commit 4d1b7a6

Browse files
authored
[DISKBBQ] Don't spill vectors that are numerically equivalent to the centroid (#132706)
This commit changes the degenerated case, where the vector is equivalent to the centroid, then the vector does not ] get a soar assignment, which is defined as a -1 in the soar assignments array.
1 parent cd9f620 commit 4d1b7a6

File tree

2 files changed

+83
-39
lines changed

2 files changed

+83
-39
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
class KMeansLocal {
2929

3030
// the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
31-
// For vectors that are closer than this distance to the centroid, we use the squared distance to find the
32-
// second closest centroid.
31+
// For vectors that are closer than this distance to the centroid don't get spilled because they are well represented
32+
// by the centroid itself. In many cases, it indicates a degenerated distribution, e.g the cluster is composed of the
33+
// many equal vectors.
3334
private static final float SOAR_MIN_DISTANCE = 1e-16f;
3435

3536
final int sampleSize;
@@ -281,19 +282,18 @@ private void assignSpilled(
281282
final float[] distances = new float[4];
282283
for (int i = 0; i < vectors.size(); i++) {
283284
float[] vector = vectors.vectorValue(i);
284-
285285
int currAssignment = assignments[i];
286286
float[] currentCentroid = centroids[currAssignment];
287-
288287
// TODO: cache these?
289288
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
290-
291-
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
292-
for (int j = 0; j < vectors.dimension(); j++) {
293-
diffs[j] = vector[j] - currentCentroid[j];
294-
}
289+
if (vectorCentroidDist <= SOAR_MIN_DISTANCE) {
290+
spilledAssignments[i] = -1; // no SOAR assignment
291+
continue;
295292
}
296293

294+
for (int j = 0; j < vectors.dimension(); j++) {
295+
diffs[j] = vector[j] - currentCentroid[j];
296+
}
297297
final int centroidCount;
298298
final IntToIntFunction centroidOrds;
299299
if (neighborhoods != null) {
@@ -310,29 +310,17 @@ private void assignSpilled(
310310
float minSoar = Float.MAX_VALUE;
311311
int j = 0;
312312
for (; j < limit; j += 4) {
313-
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
314-
ESVectorUtil.soarDistanceBulk(
315-
vector,
316-
centroids[centroidOrds.apply(j)],
317-
centroids[centroidOrds.apply(j + 1)],
318-
centroids[centroidOrds.apply(j + 2)],
319-
centroids[centroidOrds.apply(j + 3)],
320-
diffs,
321-
soarLambda,
322-
vectorCentroidDist,
323-
distances
324-
);
325-
} else {
326-
// if the vector is very close to the centroid, we look for the second-nearest centroid
327-
ESVectorUtil.squareDistanceBulk(
328-
vector,
329-
centroids[centroidOrds.apply(j)],
330-
centroids[centroidOrds.apply(j + 1)],
331-
centroids[centroidOrds.apply(j + 2)],
332-
centroids[centroidOrds.apply(j + 3)],
333-
distances
334-
);
335-
}
313+
ESVectorUtil.soarDistanceBulk(
314+
vector,
315+
centroids[centroidOrds.apply(j)],
316+
centroids[centroidOrds.apply(j + 1)],
317+
centroids[centroidOrds.apply(j + 2)],
318+
centroids[centroidOrds.apply(j + 3)],
319+
diffs,
320+
soarLambda,
321+
vectorCentroidDist,
322+
distances
323+
);
336324
for (int k = 0; k < distances.length; k++) {
337325
float soar = distances[k];
338326
if (soar < minSoar) {
@@ -344,13 +332,7 @@ private void assignSpilled(
344332

345333
for (; j < centroidCount; j++) {
346334
int centroidOrd = centroidOrds.apply(j);
347-
float soar;
348-
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
349-
soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
350-
} else {
351-
// if the vector is very close to the centroid, we look for the second-nearest centroid
352-
soar = VectorUtil.squareDistance(vector, centroids[centroidOrd]);
353-
}
335+
float soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
354336
if (soar < minSoar) {
355337
minSoar = soar;
356338
bestAssignment = centroidOrd;

server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import org.apache.lucene.index.IndexReader;
2323
import org.apache.lucene.index.IndexWriter;
2424
import org.apache.lucene.index.LeafReader;
25+
import org.apache.lucene.index.LeafReaderContext;
2526
import org.apache.lucene.index.VectorEncoding;
2627
import org.apache.lucene.index.VectorSimilarityFunction;
28+
import org.apache.lucene.search.TopDocs;
2729
import org.apache.lucene.store.Directory;
2830
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
2931
import org.apache.lucene.tests.util.TestUtil;
@@ -145,6 +147,66 @@ public void testSimpleOffHeapSize() throws IOException {
145147
}
146148
}
147149

150+
public void testFewVectorManyTimes() throws IOException {
151+
int numDifferentVectors = random().nextInt(1, 20);
152+
float[][] vectors = new float[numDifferentVectors][];
153+
int dimensions = random().nextInt(12, 500);
154+
for (int i = 0; i < numDifferentVectors; i++) {
155+
vectors[i] = randomVector(dimensions);
156+
}
157+
int numDocs = random().nextInt(100, 10_000);
158+
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
159+
for (int i = 0; i < numDocs; i++) {
160+
float[] vector = vectors[random().nextInt(numDifferentVectors)];
161+
Document doc = new Document();
162+
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
163+
w.addDocument(doc);
164+
}
165+
w.commit();
166+
if (rarely()) {
167+
w.forceMerge(1);
168+
}
169+
try (IndexReader reader = DirectoryReader.open(w)) {
170+
List<LeafReaderContext> subReaders = reader.leaves();
171+
for (LeafReaderContext r : subReaders) {
172+
LeafReader leafReader = r.reader();
173+
float[] vector = randomVector(dimensions);
174+
TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
175+
assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
176+
}
177+
178+
}
179+
}
180+
}
181+
182+
public void testOneRepeatedVector() throws IOException {
183+
int dimensions = random().nextInt(12, 500);
184+
float[] repeatedVector = randomVector(dimensions);
185+
int numDocs = random().nextInt(100, 10_000);
186+
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
187+
for (int i = 0; i < numDocs; i++) {
188+
float[] vector = random().nextInt(3) == 0 ? repeatedVector : randomVector(dimensions);
189+
Document doc = new Document();
190+
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
191+
w.addDocument(doc);
192+
}
193+
w.commit();
194+
if (rarely()) {
195+
w.forceMerge(1);
196+
}
197+
try (IndexReader reader = DirectoryReader.open(w)) {
198+
List<LeafReaderContext> subReaders = reader.leaves();
199+
for (LeafReaderContext r : subReaders) {
200+
LeafReader leafReader = r.reader();
201+
float[] vector = randomVector(dimensions);
202+
TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE);
203+
assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
204+
}
205+
206+
}
207+
}
208+
}
209+
148210
// this is a modified version of lucene's TestSearchWithThreads test case
149211
public void testWithThreads() throws Exception {
150212
final int numThreads = random().nextInt(2, 5);

0 commit comments

Comments
 (0)