Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
class KMeansLocal {

// the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance.
// For vectors that are closer than this distance to the centroid, we use the squared distance to find the
// second closest centroid.
// Vectors that are closer than this distance to the centroid don't get spilled because they are well represented
// by the centroid itself. In many cases, this indicates a degenerate distribution, e.g. the cluster is composed of
// many equal vectors.
private static final float SOAR_MIN_DISTANCE = 1e-16f;

final int sampleSize;
Expand Down Expand Up @@ -281,19 +282,18 @@ private void assignSpilled(
final float[] distances = new float[4];
for (int i = 0; i < vectors.size(); i++) {
float[] vector = vectors.vectorValue(i);

int currAssignment = assignments[i];
float[] currentCentroid = centroids[currAssignment];

// TODO: cache these?
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);

if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
for (int j = 0; j < vectors.dimension(); j++) {
diffs[j] = vector[j] - currentCentroid[j];
}
if (vectorCentroidDist <= SOAR_MIN_DISTANCE) {
spilledAssignments[i] = -1; // no SOAR assignment
continue;
}

for (int j = 0; j < vectors.dimension(); j++) {
diffs[j] = vector[j] - currentCentroid[j];
}
final int centroidCount;
final IntToIntFunction centroidOrds;
if (neighborhoods != null) {
Expand All @@ -310,29 +310,17 @@ private void assignSpilled(
float minSoar = Float.MAX_VALUE;
int j = 0;
for (; j < limit; j += 4) {
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
ESVectorUtil.soarDistanceBulk(
vector,
centroids[centroidOrds.apply(j)],
centroids[centroidOrds.apply(j + 1)],
centroids[centroidOrds.apply(j + 2)],
centroids[centroidOrds.apply(j + 3)],
diffs,
soarLambda,
vectorCentroidDist,
distances
);
} else {
// if the vector is very close to the centroid, we look for the second-nearest centroid
ESVectorUtil.squareDistanceBulk(
vector,
centroids[centroidOrds.apply(j)],
centroids[centroidOrds.apply(j + 1)],
centroids[centroidOrds.apply(j + 2)],
centroids[centroidOrds.apply(j + 3)],
distances
);
}
ESVectorUtil.soarDistanceBulk(
vector,
centroids[centroidOrds.apply(j)],
centroids[centroidOrds.apply(j + 1)],
centroids[centroidOrds.apply(j + 2)],
centroids[centroidOrds.apply(j + 3)],
diffs,
soarLambda,
vectorCentroidDist,
distances
);
for (int k = 0; k < distances.length; k++) {
float soar = distances[k];
if (soar < minSoar) {
Expand All @@ -344,13 +332,7 @@ private void assignSpilled(

for (; j < centroidCount; j++) {
int centroidOrd = centroidOrds.apply(j);
float soar;
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
} else {
// if the vector is very close to the centroid, we look for the second-nearest centroid
soar = VectorUtil.squareDistance(vector, centroids[centroidOrd]);
}
float soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist);
if (soar < minSoar) {
minSoar = soar;
bestAssignment = centroidOrd;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
Expand Down Expand Up @@ -145,6 +147,66 @@ public void testSimpleOffHeapSize() throws IOException {
}
}

/**
 * Indexes many documents whose vectors are all drawn from a small pool of distinct vectors, then runs a
 * nearest-vector search against every leaf and checks that the number of returned hits is
 * {@code min(maxDoc, 10)}.
 */
public void testFewVectorManyTimes() throws IOException {
    // Build the small pool of distinct vectors that every document will pick from.
    final int poolSize = random().nextInt(1, 20);
    final float[][] pool = new float[poolSize][];
    final int dims = random().nextInt(12, 500);
    for (int slot = 0; slot < pool.length; slot++) {
        pool[slot] = randomVector(dims);
    }
    final int docCount = random().nextInt(100, 10_000);
    final int k = 10;
    try (Directory directory = newDirectory(); IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig())) {
        int remaining = docCount;
        while (remaining-- > 0) {
            // Each document reuses one of the pooled vectors, so duplicates are very frequent.
            float[] picked = pool[random().nextInt(poolSize)];
            Document document = new Document();
            document.add(new KnnFloatVectorField("f", picked, VectorSimilarityFunction.EUCLIDEAN));
            writer.addDocument(document);
        }
        writer.commit();
        // Occasionally collapse everything into a single segment.
        if (rarely()) {
            writer.forceMerge(1);
        }
        try (IndexReader indexReader = DirectoryReader.open(writer)) {
            for (LeafReaderContext leafContext : indexReader.leaves()) {
                LeafReader leaf = leafContext.reader();
                float[] query = randomVector(dims);
                TopDocs hits = leaf.searchNearestVectors("f", query, k, leaf.getLiveDocs(), Integer.MAX_VALUE);
                // A leaf with fewer than k documents can only return as many hits as it holds.
                int expectedHits = Math.min(leaf.maxDoc(), k);
                assertEquals(expectedHits, hits.scoreDocs.length);
            }
        }
    }
}

/**
 * Indexes documents where roughly one in three carries the very same vector and the rest carry fresh
 * random vectors, then runs a nearest-vector search against every leaf and checks that the number of
 * returned hits is {@code min(maxDoc, 10)}.
 */
public void testOneRepeatedVector() throws IOException {
    final int dims = random().nextInt(12, 500);
    final float[] shared = randomVector(dims);
    final int docCount = random().nextInt(100, 10_000);
    final int k = 10;
    try (Directory directory = newDirectory(); IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig())) {
        for (int docIndex = 0; docIndex < docCount; docIndex++) {
            // One time out of three reuse the shared vector; otherwise generate a new random one.
            final float[] docVector;
            if (random().nextInt(3) == 0) {
                docVector = shared;
            } else {
                docVector = randomVector(dims);
            }
            Document document = new Document();
            document.add(new KnnFloatVectorField("f", docVector, VectorSimilarityFunction.EUCLIDEAN));
            writer.addDocument(document);
        }
        writer.commit();
        // Occasionally collapse everything into a single segment.
        if (rarely()) {
            writer.forceMerge(1);
        }
        try (IndexReader indexReader = DirectoryReader.open(writer)) {
            for (LeafReaderContext leafContext : indexReader.leaves()) {
                LeafReader leaf = leafContext.reader();
                float[] query = randomVector(dims);
                TopDocs hits = leaf.searchNearestVectors("f", query, k, leaf.getLiveDocs(), Integer.MAX_VALUE);
                // A leaf with fewer than k documents can only return as many hits as it holds.
                int expectedHits = Math.min(leaf.maxDoc(), k);
                assertEquals(expectedHits, hits.scoreDocs.length);
            }
        }
    }
}

// this is a modified version of lucene's TestSearchWithThreads test case
public void testWithThreads() throws Exception {
final int numThreads = random().nextInt(2, 5);
Expand Down