Skip to content

Commit e5da80f

Browse files
authored
Adjust IVF fixup phase to sometimes bypass some of the neighborhood calculations (#130490)
* Improve ivf index time during fixup phase * iter * addressing PR comments
1 parent 8308411 commit e5da80f

File tree

2 files changed

+55
-49
lines changed

2 files changed

+55
-49
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ private static boolean stepLloyd(
8282
float[][] centroids,
8383
float[][] nextCentroids,
8484
int[] assignments,
85-
List<int[]> neighborhoods
85+
List<NeighborHood> neighborhoods
8686
) throws IOException {
8787
boolean changed = false;
8888
int dim = vectors.dimension();
@@ -124,11 +124,20 @@ private static boolean stepLloyd(
124124
return changed;
125125
}
126126

127-
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
127+
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood) {
128128
int bestCentroidOffset = centroidIdx;
129129
assert centroidIdx >= 0 && centroidIdx < centroids.length;
130130
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
131-
for (int offset : centroidOffsets) {
131+
for (int i = 0; i < neighborhood.neighbors.length; i++) {
132+
int offset = neighborhood.neighbors[i];
133+
// float score = neighborhood.scores[i];
134+
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
135+
if (minDsq < neighborhood.maxIntraDistance) {
136+
// if the distance found is smaller than the maximum intra-cluster distance
137+
// we don't consider it for further re-assignment
138+
return bestCentroidOffset;
139+
}
140+
// compute the distance to the centroid
132141
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
133142
if (dsq < minDsq) {
134143
minDsq = dsq;
@@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
151160
return bestCentroidOffset;
152161
}
153162

154-
private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
163+
private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighborhoods, int clustersPerNeighborhood) {
155164
int k = neighborhoods.size();
156165

157166
if (k == 0 || clustersPerNeighborhood <= 0) {
@@ -172,14 +181,24 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
172181

173182
for (int i = 0; i < k; i++) {
174183
NeighborQueue queue = neighborQueues.get(i);
175-
int neighborCount = queue.size();
176-
int[] neighbors = new int[neighborCount];
177-
queue.consumeNodes(neighbors);
178-
neighborhoods.set(i, neighbors);
184+
if (queue.size() == 0) {
185+
// no neighbors, skip
186+
neighborhoods.set(i, NeighborHood.EMPTY);
187+
continue;
188+
}
189+
// consume the queue into the neighbors array and get the maximum intra-cluster distance
190+
int[] neighbors = new int[queue.size()];
191+
float maxIntraDistance = queue.topScore();
192+
int iter = 0;
193+
while (queue.size() > 0) {
194+
neighbors[neighbors.length - ++iter] = queue.pop();
195+
}
196+
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
197+
neighborhoods.set(i, neighborHood);
179198
}
180199
}
181200

182-
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
201+
private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighborhoods, float[][] centroids, int[] assignments)
183202
throws IOException {
184203
// SOAR uses an adjusted distance for assigning spilled documents which is
185204
// given by:
@@ -200,7 +219,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
200219
float[] currentCentroid = centroids[currAssignment];
201220

202221
// TODO: cache these?
203-
// float vectorCentroidDist = assignmentDistances[i];
204222
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
205223

206224
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
@@ -212,24 +230,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
212230

213231
int bestAssignment = -1;
214232
float minSoar = Float.MAX_VALUE;
215-
assert neighborhoods.get(currAssignment) != null;
216-
for (int neighbor : neighborhoods.get(currAssignment)) {
217-
if (neighbor == currAssignment) {
218-
continue;
233+
int centroidCount = centroids.length;
234+
IntToIntFunction centroidOrds = c -> c;
235+
if (neighborhoods != null) {
236+
assert neighborhoods.get(currAssignment) != null;
237+
NeighborHood neighborhood = neighborhoods.get(currAssignment);
238+
centroidCount = neighborhood.neighbors.length;
239+
centroidOrds = c -> neighborhood.neighbors[c];
240+
}
241+
for (int j = 0; j < centroidCount; j++) {
242+
int centroidOrd = centroidOrds.apply(j);
243+
if (centroidOrd == currAssignment) {
244+
continue; // skip the current assignment
219245
}
220-
float[] neighborCentroid = centroids[neighbor];
221-
final float soar;
246+
float[] centroid = centroids[centroidOrd];
247+
float soar;
222248
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
223-
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
249+
soar = ESVectorUtil.soarDistance(vector, centroid, diffs, soarLambda, vectorCentroidDist);
224250
} else {
225251
// if the vector is very close to the centroid, we look for the second-nearest centroid
226-
soar = VectorUtil.squareDistance(vector, neighborCentroid);
252+
soar = VectorUtil.squareDistance(vector, centroid);
227253
}
228254
if (soar < minSoar) {
229-
bestAssignment = neighbor;
230255
minSoar = soar;
256+
bestAssignment = centroidOrd;
231257
}
232258
}
259+
233260
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
234261
spilledAssignments[i] = bestAssignment;
235262
}
@@ -250,6 +277,10 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
250277
cluster(vectors, kMeansIntermediate, false);
251278
}
252279

280+
record NeighborHood(int[] neighbors, float maxIntraDistance) {
281+
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
282+
}
283+
253284
/**
254285
* cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids
255286
* this also is used to generate the neighborhood aware additional (SOAR) assignments
@@ -266,8 +297,9 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
266297
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
267298
float[][] centroids = kMeansIntermediate.centroids();
268299

269-
List<int[]> neighborhoods = null;
270-
if (neighborAware) {
300+
List<NeighborHood> neighborhoods = null;
301+
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
302+
if (neighborAware && centroids.length > clustersPerNeighborhood) {
271303
int k = centroids.length;
272304
neighborhoods = new ArrayList<>(k);
273305
for (int i = 0; i < k; ++i) {
@@ -284,7 +316,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
284316
}
285317
}
286318

287-
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
319+
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<NeighborHood> neighborhoods)
320+
throws IOException {
288321
float[][] centroids = kMeansIntermediate.centroids();
289322
int k = centroids.length;
290323
int n = vectors.size();

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -121,33 +121,6 @@ public int pop() {
121121
return decodeNodeId(heap.pop());
122122
}
123123

124-
public void consumeNodes(int[] dest) {
125-
if (dest.length < size()) {
126-
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
127-
}
128-
for (int i = 0; i < size(); i++) {
129-
dest[i] = decodeNodeId(heap.get(i + 1));
130-
}
131-
}
132-
133-
public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
134-
if (dest.length < size() || scores.length < size()) {
135-
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
136-
}
137-
float bestScore = Float.POSITIVE_INFINITY;
138-
int bestIdx = 0;
139-
for (int i = 0; i < size(); i++) {
140-
long heapValue = heap.get(i + 1);
141-
scores[i] = decodeScore(heapValue);
142-
dest[i] = decodeNodeId(heapValue);
143-
if (scores[i] < bestScore) {
144-
bestScore = scores[i];
145-
bestIdx = i;
146-
}
147-
}
148-
return bestIdx;
149-
}
150-
151124
public void clear() {
152125
heap.clear();
153126
}

0 commit comments

Comments
 (0)