diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 744fd248b2a49..9e83accef1268 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -28,8 +28,9 @@ class KMeansLocal { // the minimum distance that is considered to be "far enough" to a centroid in order to compute the soar distance. - // For vectors that are closer than this distance to the centroid, we use the squared distance to find the - // second closest centroid. + // For vectors that are closer than this distance to the centroid don't get spilled because they are well represented + // by the centroid itself. In many cases, it indicates a degenerated distribution, e.g the cluster is composed of the + // many equal vectors. private static final float SOAR_MIN_DISTANCE = 1e-16f; final int sampleSize; @@ -281,19 +282,18 @@ private void assignSpilled( final float[] distances = new float[4]; for (int i = 0; i < vectors.size(); i++) { float[] vector = vectors.vectorValue(i); - int currAssignment = assignments[i]; float[] currentCentroid = centroids[currAssignment]; - // TODO: cache these? float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid); - - if (vectorCentroidDist > SOAR_MIN_DISTANCE) { - for (int j = 0; j < vectors.dimension(); j++) { - diffs[j] = vector[j] - currentCentroid[j]; - } + if (vectorCentroidDist <= SOAR_MIN_DISTANCE) { + spilledAssignments[i] = -1; // no SOAR assignment + continue; } + for (int j = 0; j < vectors.dimension(); j++) { + diffs[j] = vector[j] - currentCentroid[j]; + } final int centroidCount; final IntToIntFunction centroidOrds; if (neighborhoods != null) { @@ -310,29 +310,17 @@ private void assignSpilled( float minSoar = Float.MAX_VALUE; int j = 0; for (; j < limit; j += 4) { - if (vectorCentroidDist > SOAR_MIN_DISTANCE) { - ESVectorUtil.soarDistanceBulk( - vector, - centroids[centroidOrds.apply(j)], - centroids[centroidOrds.apply(j + 1)], - centroids[centroidOrds.apply(j + 2)], - centroids[centroidOrds.apply(j + 3)], - diffs, - soarLambda, - vectorCentroidDist, - distances - ); - } else { - // if the vector is very close to the centroid, we look for the second-nearest centroid - ESVectorUtil.squareDistanceBulk( - vector, - centroids[centroidOrds.apply(j)], - centroids[centroidOrds.apply(j + 1)], - centroids[centroidOrds.apply(j + 2)], - centroids[centroidOrds.apply(j + 3)], - distances - ); - } + ESVectorUtil.soarDistanceBulk( + vector, + centroids[centroidOrds.apply(j)], + centroids[centroidOrds.apply(j + 1)], + centroids[centroidOrds.apply(j + 2)], + centroids[centroidOrds.apply(j + 3)], + diffs, + soarLambda, + vectorCentroidDist, + distances + ); for (int k = 0; k < distances.length; k++) { float soar = distances[k]; if (soar < minSoar) { @@ -344,13 +332,7 @@ private void assignSpilled( for (; j < centroidCount; j++) { int centroidOrd = centroidOrds.apply(j); - float soar; - if (vectorCentroidDist > SOAR_MIN_DISTANCE) { - soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist); - } else { - // if the vector is very close to the centroid, we look for the second-nearest centroid - soar = VectorUtil.squareDistance(vector, centroids[centroidOrd]); - } + float soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist); if (soar < minSoar) { minSoar = soar; bestAssignment = centroidOrd; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java index 2c0d2f3fc7449..08693bf524691 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java @@ -22,8 +22,10 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -145,6 +147,66 @@ public void testSimpleOffHeapSize() throws IOException { } } + public void testFewVectorManyTimes() throws IOException { + int numDifferentVectors = random().nextInt(1, 20); + float[][] vectors = new float[numDifferentVectors][]; + int dimensions = random().nextInt(12, 500); + for (int i = 0; i < numDifferentVectors; i++) { + vectors[i] = randomVector(dimensions); + } + int numDocs = random().nextInt(100, 10_000); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + float[] vector = vectors[random().nextInt(numDifferentVectors)]; + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + } + w.commit(); + if (rarely()) { + w.forceMerge(1); + } + try (IndexReader reader = DirectoryReader.open(w)) { + List subReaders = reader.leaves(); + for (LeafReaderContext r : subReaders) { + LeafReader leafReader = r.reader(); + float[] vector = randomVector(dimensions); + TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE); + assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length); + } + + } + } + } + + public void testOneRepeatedVector() throws IOException { + int dimensions = random().nextInt(12, 500); + float[] repeatedVector = randomVector(dimensions); + int numDocs = random().nextInt(100, 10_000); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + float[] vector = random().nextInt(3) == 0 ? repeatedVector : randomVector(dimensions); + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + } + w.commit(); + if (rarely()) { + w.forceMerge(1); + } + try (IndexReader reader = DirectoryReader.open(w)) { + List subReaders = reader.leaves(); + for (LeafReaderContext r : subReaders) { + LeafReader leafReader = r.reader(); + float[] vector = randomVector(dimensions); + TopDocs topDocs = leafReader.searchNearestVectors("f", vector, 10, leafReader.getLiveDocs(), Integer.MAX_VALUE); + assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length); + } + + } + } + } + // this is a modified version of lucene's TestSearchWithThreads test case public void testWithThreads() throws Exception { final int numThreads = random().nextInt(2, 5);