Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
Expand Down Expand Up @@ -74,7 +72,7 @@ private static boolean stepLloyd(
float[][] centroids,
float[][] nextCentroids,
int[] assignments,
List<NeighborHood> neighborhoods
NeighborHood[] neighborhoods
) throws IOException {
boolean changed = false;
int dim = vectors.dimension();
Expand All @@ -90,7 +88,7 @@ private static boolean stepLloyd(
final int assignment = assignments[vectorOrd];
final int bestCentroidOffset;
if (neighborhoods != null) {
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods[assignment]);
} else {
bestCentroidOffset = getBestCentroid(centroids, vector);
}
Expand Down Expand Up @@ -152,30 +150,27 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
return bestCentroidOffset;
}

private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighborhoods, int clustersPerNeighborhood) {
int k = neighborhoods.size();

if (k == 0 || clustersPerNeighborhood <= 0) {
return;
}

List<NeighborQueue> neighborQueues = new ArrayList<>(k);
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
int k = centers.length;
assert k > clustersPerNeighborhood;
NeighborQueue[] neighborQueues = new NeighborQueue[k];
for (int i = 0; i < k; i++) {
neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true));
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
}
for (int i = 0; i < k - 1; i++) {
for (int j = i + 1; j < k; j++) {
float dsq = VectorUtil.squareDistance(centers[i], centers[j]);
neighborQueues.get(j).insertWithOverflow(i, dsq);
neighborQueues.get(i).insertWithOverflow(j, dsq);
neighborQueues[j].insertWithOverflow(i, dsq);
neighborQueues[i].insertWithOverflow(j, dsq);
}
}

NeighborHood[] neighborhoods = new NeighborHood[k];
for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues.get(i);
NeighborQueue queue = neighborQueues[i];
if (queue.size() == 0) {
// no neighbors, skip
neighborhoods.set(i, NeighborHood.EMPTY);
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
// consume the queue into the neighbors array and get the maximum intra-cluster distance
Expand All @@ -185,16 +180,15 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
while (queue.size() > 0) {
neighbors[neighbors.length - ++iter] = queue.pop();
}
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
neighborhoods.set(i, neighborHood);
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
}
return neighborhoods;
}

private int[] assignSpilled(
private void assignSpilled(
FloatVectorValues vectors,
List<NeighborHood> neighborhoods,
float[][] centroids,
int[] assignments,
KMeansIntermediate kmeansIntermediate,
NeighborHood[] neighborhoods,
float soarLambda
) throws IOException {
// SOAR uses an adjusted distance for assigning spilled documents which is
Expand All @@ -205,8 +199,13 @@ private int[] assignSpilled(
// Here, x is the document, c is the nearest centroid, and c_1 is the first
// centroid the document was assigned to. The document is assigned to the
// cluster with the smallest soar(x, c).

int[] spilledAssignments = new int[assignments.length];
int[] assignments = kmeansIntermediate.assignments();
assert assignments != null;
assert assignments.length == vectors.size();
int[] spilledAssignments = kmeansIntermediate.soarAssignments();
assert spilledAssignments != null;
assert spilledAssignments.length == vectors.size();
float[][] centroids = kmeansIntermediate.centroids();

float[] diffs = new float[vectors.dimension()];
for (int i = 0; i < vectors.size(); i++) {
Expand All @@ -230,8 +229,8 @@ private int[] assignSpilled(
int centroidCount = centroids.length;
IntToIntFunction centroidOrds = c -> c;
if (neighborhoods != null) {
assert neighborhoods.get(currAssignment) != null;
NeighborHood neighborhood = neighborhoods.get(currAssignment);
assert neighborhoods[currAssignment] != null;
NeighborHood neighborhood = neighborhoods[currAssignment];
centroidCount = neighborhood.neighbors.length;
centroidOrds = c -> neighborhood.neighbors[c];
}
Expand All @@ -257,8 +256,6 @@ private int[] assignSpilled(
assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
spilledAssignments[i] = bestAssignment;
}

return spilledAssignments;
}

record NeighborHood(int[] neighbors, float maxIntraDistance) {
Expand Down Expand Up @@ -304,27 +301,20 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;

List<NeighborHood> neighborhoods = null;
NeighborHood[] neighborhoods = null;
// if there are very few centroids, don't bother with neighborhoods or neighbor-aware clustering
if (neighborAware && centroids.length > clustersPerNeighborhood) {
int k = centroids.length;
neighborhoods = new ArrayList<>(k);
for (int i = 0; i < k; ++i) {
neighborhoods.add(null);
}
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood);
}
cluster(vectors, kMeansIntermediate, neighborhoods);
if (neighborAware) {
int[] assignments = kMeansIntermediate.assignments();
assert assignments != null;
assert assignments.length == vectors.size();
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda));
assert kMeansIntermediate.soarAssignments().length == 0;
kMeansIntermediate.setSoarAssignments(new int[vectors.size()]);
assignSpilled(vectors, kMeansIntermediate, neighborhoods, soarLambda);
}
}

private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<NeighborHood> neighborhoods)
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods)
throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
int k = centroids.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@ public void testHKmeans() throws IOException {

assertEquals(Math.min(nClusters, nVectors), centroids.length, 8);
assertEquals(nVectors, assignments.length);

for (int assignment : assignments) {
assertTrue(assignment >= 0 && assignment < centroids.length);
}
if (centroids.length > 1 && centroids.length < nVectors) {
assertEquals(nVectors, soarAssignments.length);
// verify each SOAR assignment is a valid centroid index and differs from the vector's primary assignment
for (int i = 0; i < assignments.length; i++) {
assert assignments[i] != soarAssignments[i];
assertTrue(soarAssignments[i] >= 0 && soarAssignments[i] < centroids.length);
assertNotEquals(assignments[i], soarAssignments[i]);
}
} else {
assertEquals(0, soarAssignments.length);
Expand Down