Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import org.apache.lucene.util.Bits;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.function.IntUnaryOperator;

class SampleReader extends FloatVectorValues implements HasIndexSlice {
public class SampleReader extends FloatVectorValues implements HasIndexSlice {
private final FloatVectorValues origin;
private final int sampleSize;
private final IntUnaryOperator sampleFunction;
Expand Down Expand Up @@ -71,21 +72,24 @@ public int getVectorByteLength() {

@Override
public int ordToDoc(int ord) {
throw new IllegalStateException("Not supported");
// get the original ordinal from the sample ordinal
return sampleFunction.applyAsInt(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
throw new IllegalStateException("Not supported");
}

static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
// TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
if (k >= origin.size()) {
new SampleReader(origin, origin.size(), i -> i);
}
// TODO maybe use bigArrays?
int[] samples = reservoirSample(origin.size(), k, seed);
// sort to prevent random backwards access weirdness
Arrays.sort(samples);
return new SampleReader(origin, samples.length, i -> samples[i]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.index.FloatVectorValues;

import java.io.IOException;
import java.util.Arrays;

/**
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
Expand Down Expand Up @@ -84,6 +85,8 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta

// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
int[] assignments = new int[vectors.size()];
// ensure we don't over assign to cluster 0 without adjusting it
Arrays.fill(assignments, -1);
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.index.codec.vectors.SampleReader;
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;
Expand Down Expand Up @@ -74,12 +76,12 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
return centroids;
}

private boolean stepLloyd(
private static boolean stepLloyd(
FloatVectorValues vectors,
IntToIntFunction translateOrd,
float[][] centroids,
float[][] nextCentroids,
int[] assignments,
int sampleSize,
List<int[]> neighborhoods
) throws IOException {
boolean changed = false;
Expand All @@ -90,17 +92,18 @@ private boolean stepLloyd(
Arrays.fill(nextCentroid, 0.0f);
}

for (int i = 0; i < sampleSize; i++) {
float[] vector = vectors.vectorValue(i);
final int assignment = assignments[i];
for (int idx = 0; idx < vectors.size(); idx++) {
float[] vector = vectors.vectorValue(idx);
int vectorOrd = translateOrd.apply(idx);
final int assignment = assignments[vectorOrd];
final int bestCentroidOffset;
if (neighborhoods != null) {
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
} else {
bestCentroidOffset = getBestCentroid(centroids, vector);
}
if (assignment != bestCentroidOffset) {
assignments[i] = bestCentroidOffset;
assignments[vectorOrd] = bestCentroidOffset;
changed = true;
}
centroidCounts[bestCentroidOffset]++;
Expand All @@ -121,7 +124,7 @@ private boolean stepLloyd(
return changed;
}

int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int bestCentroidOffset = centroidIdx;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
Expand All @@ -135,7 +138,7 @@ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centr
return bestCentroidOffset;
}

int getBestCentroid(float[][] centroids, float[] vector) {
private static int getBestCentroid(float[][] centroids, float[] vector) {
int bestCentroidOffset = 0;
float minDsq = Float.MAX_VALUE;
for (int i = 0; i < centroids.length; i++) {
Expand Down Expand Up @@ -281,24 +284,34 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
}
}

void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
int k = centroids.length;
int n = vectors.size();

if (k == 1 || k >= n) {
return;
}

IntToIntFunction translateOrd = i -> i;
FloatVectorValues sampledVectors = vectors;
if (sampleSize < n) {
sampledVectors = SampleReader.createSampleReader(vectors, sampleSize, 42L);
translateOrd = sampledVectors::ordToDoc;
}
int[] assignments = kMeansIntermediate.assignments();
assert assignments.length == n;
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
for (int i = 0; i < maxIterations; i++) {
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
// This is potentially sampled, so we need to translate ordinals
if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
break;
}
}
stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods);
// If we were sampled, do a once over the full set of vectors to finalize the centroids
if (sampleSize < n) {
// No ordinal translation needed here, we are using the full set of vectors
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
}
}

/**
Expand Down