From dae118f2a6f62ce9886265d36cdc81a77838893d Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 7 Jul 2025 09:49:17 +0100 Subject: [PATCH] [IVF] Remove unnecessary loop over centroids and some clean up --- .../codec/vectors/cluster/KMeansLocal.java | 74 ++++++++----------- .../cluster/HierarchicalKMeansTests.java | 7 +- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 25291d2bac7d1..a3be558128577 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -16,9 +16,7 @@ import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Random; /** @@ -74,7 +72,7 @@ private static boolean stepLloyd( float[][] centroids, float[][] nextCentroids, int[] assignments, - List neighborhoods + NeighborHood[] neighborhoods ) throws IOException { boolean changed = false; int dim = vectors.dimension(); @@ -90,7 +88,7 @@ private static boolean stepLloyd( final int assignment = assignments[vectorOrd]; final int bestCentroidOffset; if (neighborhoods != null) { - bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment)); + bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods[assignment]); } else { bestCentroidOffset = getBestCentroid(centroids, vector); } @@ -152,30 +150,27 @@ private static int getBestCentroid(float[][] centroids, float[] vector) { return bestCentroidOffset; } - private void computeNeighborhoods(float[][] centers, List neighborhoods, int clustersPerNeighborhood) { - int k = neighborhoods.size(); - - if (k == 0 || clustersPerNeighborhood <= 0) { - return; - } - - List neighborQueues = new ArrayList<>(k); + private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) { + int k = centers.length; + assert k > clustersPerNeighborhood; + NeighborQueue[] neighborQueues = new NeighborQueue[k]; for (int i = 0; i < k; i++) { - neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true)); + neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); } for (int i = 0; i < k - 1; i++) { for (int j = i + 1; j < k; j++) { float dsq = VectorUtil.squareDistance(centers[i], centers[j]); - neighborQueues.get(j).insertWithOverflow(i, dsq); - neighborQueues.get(i).insertWithOverflow(j, dsq); + neighborQueues[j].insertWithOverflow(i, dsq); + neighborQueues[i].insertWithOverflow(j, dsq); } } + NeighborHood[] neighborhoods = new NeighborHood[k]; for (int i = 0; i < k; i++) { - NeighborQueue queue = neighborQueues.get(i); + NeighborQueue queue = neighborQueues[i]; if (queue.size() == 0) { // no neighbors, skip - neighborhoods.set(i, NeighborHood.EMPTY); + neighborhoods[i] = NeighborHood.EMPTY; continue; } // consume the queue into the neighbors array and get the maximum intra-cluster distance @@ -185,16 +180,15 @@ private void computeNeighborhoods(float[][] centers, List neighbor while (queue.size() > 0) { neighbors[neighbors.length - ++iter] = queue.pop(); } - NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance); - neighborhoods.set(i, neighborHood); + neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); } + return neighborhoods; } - private int[] assignSpilled( + private void assignSpilled( FloatVectorValues vectors, - List neighborhoods, - float[][] centroids, - int[] assignments, + KMeansIntermediate kmeansIntermediate, + NeighborHood[] neighborhoods, float soarLambda ) throws IOException { // SOAR uses an adjusted distance for assigning spilled documents which is @@ -205,8 +199,13 @@ private int[] assignSpilled( // Here, x is the document, c is the nearest centroid, and c_1 is the first // centroid the document was assigned to. The document is assigned to the // cluster with the smallest soar(x, c). - - int[] spilledAssignments = new int[assignments.length]; + int[] assignments = kmeansIntermediate.assignments(); + assert assignments != null; + assert assignments.length == vectors.size(); + int[] spilledAssignments = kmeansIntermediate.soarAssignments(); + assert spilledAssignments != null; + assert spilledAssignments.length == vectors.size(); + float[][] centroids = kmeansIntermediate.centroids(); float[] diffs = new float[vectors.dimension()]; for (int i = 0; i < vectors.size(); i++) { @@ -230,8 +229,8 @@ private int[] assignSpilled( int centroidCount = centroids.length; IntToIntFunction centroidOrds = c -> c; if (neighborhoods != null) { - assert neighborhoods.get(currAssignment) != null; - NeighborHood neighborhood = neighborhoods.get(currAssignment); + assert neighborhoods[currAssignment] != null; + NeighborHood neighborhood = neighborhoods[currAssignment]; centroidCount = neighborhood.neighbors.length; centroidOrds = c -> neighborhood.neighbors[c]; } @@ -257,8 +256,6 @@ private int[] assignSpilled( assert bestAssignment != -1 : "Failed to assign soar vector to centroid"; spilledAssignments[i] = bestAssignment; } - - return spilledAssignments; } record NeighborHood(int[] neighbors, float maxIntraDistance) { @@ -304,27 +301,20 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter throws IOException { float[][] centroids = kMeansIntermediate.centroids(); boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1; - - List neighborhoods = null; + NeighborHood[] neighborhoods = null; // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering if (neighborAware && centroids.length > clustersPerNeighborhood) { - int k = centroids.length; - neighborhoods = new ArrayList<>(k); - for (int i = 0; i < k; ++i) { - neighborhoods.add(null); - } - computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood); + neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood); } cluster(vectors, kMeansIntermediate, neighborhoods); if (neighborAware) { - int[] assignments = kMeansIntermediate.assignments(); - assert assignments != null; - assert assignments.length == vectors.size(); - kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda)); + assert kMeansIntermediate.soarAssignments().length == 0; + kMeansIntermediate.setSoarAssignments(new int[vectors.size()]); + assignSpilled(vectors, kMeansIntermediate, neighborhoods, soarLambda); } } - private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) + private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods) throws IOException { float[][] centroids = kMeansIntermediate.centroids(); int k = centroids.length; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java index 8b4291fb6b3ec..d057063739bfb 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java @@ -38,11 +38,16 @@ public void testHKmeans() throws IOException { assertEquals(Math.min(nClusters, nVectors), centroids.length, 8); assertEquals(nVectors, assignments.length); + + for (int assignment : assignments) { + assertTrue(assignment >= 0 && assignment < centroids.length); + } if (centroids.length > 1 && centroids.length < nVectors) { assertEquals(nVectors, soarAssignments.length); // verify no duplicates exist for (int i = 0; i < assignments.length; i++) { - assert assignments[i] != soarAssignments[i]; + assertTrue(soarAssignments[i] >= 0 && soarAssignments[i] < centroids.length); + assertNotEquals(assignments[i], soarAssignments[i]); } } else { assertEquals(0, soarAssignments.length);