Skip to content

Commit 2dd98f1

Browse files
committed
Improve ivf index time during fixup phase
1 parent 136442d commit 2dd98f1

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ private static boolean stepLloyd(
8282
float[][] centroids,
8383
float[][] nextCentroids,
8484
int[] assignments,
85-
List<int[]> neighborhoods
85+
List<NeighborHood> neighborhoods
8686
) throws IOException {
8787
boolean changed = false;
8888
int dim = vectors.dimension();
@@ -124,11 +124,20 @@ private static boolean stepLloyd(
124124
return changed;
125125
}
126126

127-
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
127+
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood) {
128128
int bestCentroidOffset = centroidIdx;
129129
assert centroidIdx >= 0 && centroidIdx < centroids.length;
130130
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
131-
for (int offset : centroidOffsets) {
131+
for (int i = 0; i < neighborhood.neighbors.length; i++) {
132+
int offset = neighborhood.neighbors[i];
133+
// float score = neighborhood.scores[i];
134+
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
135+
if (minDsq < neighborhood.maxIntraDistance) {
136+
// if the best squared distance found so far is already below this neighborhood's maxIntraDistance
137+
// we stop checking the remaining neighbors and keep the current best centroid
138+
return bestCentroidOffset;
139+
}
140+
// compute the distance to the centroid
132141
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
133142
if (dsq < minDsq) {
134143
minDsq = dsq;
@@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
151160
return bestCentroidOffset;
152161
}
153162

154-
private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
163+
private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighborhoods, int clustersPerNeighborhood) {
155164
int k = neighborhoods.size();
156165

157166
if (k == 0 || clustersPerNeighborhood <= 0) {
@@ -174,12 +183,14 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
174183
NeighborQueue queue = neighborQueues.get(i);
175184
int neighborCount = queue.size();
176185
int[] neighbors = new int[neighborCount];
177-
queue.consumeNodes(neighbors);
178-
neighborhoods.set(i, neighbors);
186+
float[] scores = new float[clustersPerNeighborhood];
187+
float maxIntraDistance = queue.consumeNodesWithWorstScore(neighbors);
188+
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
189+
neighborhoods.set(i, neighborHood);
179190
}
180191
}
181192

182-
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
193+
private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighborhoods, float[][] centroids, int[] assignments)
183194
throws IOException {
184195
// SOAR uses an adjusted distance for assigning spilled documents which is
185196
// given by:
@@ -213,7 +224,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
213224
int bestAssignment = -1;
214225
float minSoar = Float.MAX_VALUE;
215226
assert neighborhoods.get(currAssignment) != null;
216-
for (int neighbor : neighborhoods.get(currAssignment)) {
227+
for (int neighbor : neighborhoods.get(currAssignment).neighbors()) {
217228
if (neighbor == currAssignment) {
218229
continue;
219230
}
@@ -250,6 +261,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
250261
cluster(vectors, kMeansIntermediate, false);
251262
}
252263

264+
/**
 * Neighbor list for a single centroid, as built by computeNeighborhoods.
 *
 * @param neighbors offsets of the neighboring centroids; each is used as an index into the centroids array
 * @param maxIntraDistance the value returned by NeighborQueue#consumeNodesWithWorstScore when the neighbors
 *                         were read out of that centroid's queue (the largest score in the queue).
 *                         NOTE(review): getBestCentroidFromNeighbours compares this against a squared distance;
 *                         confirm the queue scores are on the same scale.
 */
record NeighborHood(int[] neighbors, float maxIntraDistance) {}
265+
253266
/**
254267
* cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids
255268
* this also is used to generate the neighborhood aware additional (SOAR) assignments
@@ -266,7 +279,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
266279
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
267280
float[][] centroids = kMeansIntermediate.centroids();
268281

269-
List<int[]> neighborhoods = null;
282+
List<NeighborHood> neighborhoods = null;
270283
if (neighborAware) {
271284
int k = centroids.length;
272285
neighborhoods = new ArrayList<>(k);
@@ -284,7 +297,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
284297
}
285298
}
286299

287-
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
300+
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<NeighborHood> neighborhoods)
301+
throws IOException {
288302
float[][] centroids = kMeansIntermediate.centroids();
289303
int k = centroids.length;
290304
int n = vectors.size();

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,31 +121,20 @@ public int pop() {
121121
return decodeNodeId(heap.pop());
122122
}
123123

124-
public void consumeNodes(int[] dest) {
124+
/**
 * Copies the node id of every queue entry into {@code dest} and reports the largest score seen.
 * Entries are read with {@code heap.get(i + 1)} for {@code i} in {@code [0, size())}; this method
 * itself does not pop anything from the heap.
 *
 * @param dest receives the decoded node ids in positions {@code 0..size()-1}; must hold at least {@code size()} elements
 * @return the maximum decoded score across the entries, or {@code Float.NEGATIVE_INFINITY} if the queue is empty
 * @throws IllegalArgumentException if {@code dest.length < size()}
 */
public float consumeNodesWithWorstScore(int[] dest) {
125125
if (dest.length < size()) {
126126
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
127127
}
128-
for (int i = 0; i < size(); i++) {
129-
dest[i] = decodeNodeId(heap.get(i + 1));
130-
}
131-
}
132-
133-
public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
134-
if (dest.length < size() || scores.length < size()) {
135-
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
136-
}
137-
float bestScore = Float.POSITIVE_INFINITY;
138-
int bestIdx = 0;
128+
float worstScore = Float.NEGATIVE_INFINITY;
139129
for (int i = 0; i < size(); i++) {
140130
long heapValue = heap.get(i + 1);
141-
scores[i] = decodeScore(heapValue);
131+
float score = decodeScore(heapValue);
142132
dest[i] = decodeNodeId(heapValue);
143-
if (scores[i] < bestScore) {
144-
bestScore = scores[i];
145-
bestIdx = i;
133+
if (score > worstScore) {
134+
worstScore = score;
146135
}
147136
}
148-
return bestIdx;
137+
return worstScore;
149138
}
150139

151140
public void clear() {

0 commit comments

Comments
 (0)