Skip to content

Commit 5026756

Browse files
committed
iter
1 parent 2dd98f1 commit 5026756

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,26 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
179179
}
180180
}
181181

182+
float[] scores = new float[clustersPerNeighborhood];
182183
for (int i = 0; i < k; i++) {
183184
NeighborQueue queue = neighborQueues.get(i);
184185
int neighborCount = queue.size();
185186
int[] neighbors = new int[neighborCount];
186-
float[] scores = new float[clustersPerNeighborhood];
187-
float maxIntraDistance = queue.consumeNodesWithWorstScore(neighbors);
187+
float maxIntraDistance = queue.consumeNodesWithWorstScore(neighbors, scores);
188+
// Sort neighbors (and their parallel scores) by ascending score
189+
for (int j = 0; j < neighborCount; j++) {
190+
for (int l = j + 1; l < neighborCount; l++) {
191+
if (scores[j] > scores[l]) {
192+
// swap
193+
int tmp = neighbors[j];
194+
neighbors[j] = neighbors[l];
195+
neighbors[l] = tmp;
196+
float tmpScore = scores[j];
197+
scores[j] = scores[l];
198+
scores[l] = tmpScore;
199+
}
200+
}
201+
}
188202
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
189203
neighborhoods.set(i, neighborHood);
190204
}
@@ -211,7 +225,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
211225
float[] currentCentroid = centroids[currAssignment];
212226

213227
// TODO: cache these?
214-
// float vectorCentroidDist = assignmentDistances[i];
215228
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);
216229

217230
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
@@ -223,24 +236,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
223236

224237
int bestAssignment = -1;
225238
float minSoar = Float.MAX_VALUE;
226-
assert neighborhoods.get(currAssignment) != null;
227-
for (int neighbor : neighborhoods.get(currAssignment).neighbors()) {
228-
if (neighbor == currAssignment) {
229-
continue;
239+
int centroidCount = centroids.length;
240+
IntToIntFunction centroidOrds = c -> c;
241+
if (neighborhoods != null) {
242+
assert neighborhoods.get(currAssignment) != null;
243+
NeighborHood neighborhood = neighborhoods.get(currAssignment);
244+
centroidCount = neighborhood.neighbors.length;
245+
centroidOrds = c -> neighborhood.neighbors[c];
246+
}
247+
for (int j = 0; j < centroidCount; j++) {
248+
int centroidOrd = centroidOrds.apply(j);
249+
if (centroidOrd == currAssignment) {
250+
continue; // skip the current assignment
230251
}
231-
float[] neighborCentroid = centroids[neighbor];
232-
final float soar;
252+
float[] centroid = centroids[centroidOrd];
253+
float soar;
233254
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
234-
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
255+
soar = ESVectorUtil.soarDistance(vector, centroid, diffs, soarLambda, vectorCentroidDist);
235256
} else {
236257
// if the vector is very close to the centroid, we look for the second-nearest centroid
237-
soar = VectorUtil.squareDistance(vector, neighborCentroid);
258+
soar = VectorUtil.squareDistance(vector, centroid);
238259
}
239260
if (soar < minSoar) {
240-
bestAssignment = neighbor;
241261
minSoar = soar;
262+
bestAssignment = centroidOrd;
242263
}
243264
}
265+
244266
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
245267
spilledAssignments[i] = bestAssignment;
246268
}
@@ -280,7 +302,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
280302
float[][] centroids = kMeansIntermediate.centroids();
281303

282304
List<NeighborHood> neighborhoods = null;
283-
if (neighborAware) {
305+
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
306+
if (neighborAware && centroids.length > clustersPerNeighborhood) {
284307
int k = centroids.length;
285308
neighborhoods = new ArrayList<>(k);
286309
for (int i = 0; i < k; ++i) {

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public int pop() {
121121
return decodeNodeId(heap.pop());
122122
}
123123

124-
public float consumeNodesWithWorstScore(int[] dest) {
124+
public float consumeNodesWithWorstScore(int[] dest, float[] scores) {
125125
if (dest.length < size()) {
126126
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
127127
}
@@ -130,6 +130,7 @@ public float consumeNodesWithWorstScore(int[] dest) {
130130
long heapValue = heap.get(i + 1);
131131
float score = decodeScore(heapValue);
132132
dest[i] = decodeNodeId(heapValue);
133+
scores[i] = score;
133134
if (score > worstScore) {
134135
worstScore = score;
135136
}

0 commit comments

Comments
 (0)