diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 85ca13e6e8754..c98532d8dd8f5 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -126,7 +126,10 @@ public void scoreFromArray(Blackhole bh) throws IOException { in.readFloats(corrections, 0, corrections.length); int addition = Short.toUnsignedInt(in.readShort()); float score = scorer.score( - result, + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), VectorSimilarityFunction.EUCLIDEAN, centroidDp, corrections[0], @@ -150,7 +153,10 @@ public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { in.readFloats(corrections, 0, corrections.length); int addition = Short.toUnsignedInt(in.readShort()); float score = scorer.score( - result, + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), VectorSimilarityFunction.EUCLIDEAN, centroidDp, corrections[0], @@ -175,7 +181,10 @@ public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOExceptio in.readFloats(corrections, 0, corrections.length); int addition = Short.toUnsignedInt(in.readShort()); float score = scorer.score( - result, + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), VectorSimilarityFunction.EUCLIDEAN, centroidDp, corrections[0], @@ -196,7 +205,16 @@ public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException { for (int j = 0; j < numQueries; j++) { in.seek(0); for (int i = 0; i < numVectors; i += 16) { - scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores); + scorer.scoreBulk( + binaryQueries[j], + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + scratchScores + ); bh.consume(scratchScores); } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java index be55c48dbe441..58df8bb03e0cb 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java @@ -95,7 +95,10 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce * Computes the score by applying the necessary corrections to the provided quantized distance. */ public float score( - OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float lowerInterval, @@ -107,19 +110,19 @@ public float score( float ax = lowerInterval; // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary float lx = upperInterval - ax; - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; // For euclidean, we need to invert the score and apply the additional correction, which is // assumed to be the squared l2norm of the centroid centered vectors. if (similarityFunction == EUCLIDEAN) { - score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score; + score = queryAdditionalCorrection + additionalCorrection - 2 * score; return Math.max(1 / (1f + score), 0); } else { // For cosine and max inner product, we need to apply the additional correction, which is // assumed to be the non-centered dot-product between the vector and the centroid - score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp; + score += queryAdditionalCorrection + additionalCorrection - centroidDp; if (similarityFunction == MAXIMUM_INNER_PRODUCT) { return VectorUtil.scaleMaxInnerProductScore(score); } @@ -140,7 +143,10 @@ public float score( */ public void scoreBulk( byte[] q, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores @@ -154,7 +160,10 @@ public void scoreBulk( in.readFloats(additionalCorrections, 0, BULK_SIZE); for (int i = 0; i < BULK_SIZE; i++) { scores[i] = score( - queryCorrections, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, similarityFunction, centroidDp, lowerIntervals[i], diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 46daa074c5e5e..4be6ede34530a 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -19,7 +19,6 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; @@ -298,7 +297,10 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO @Override public void scoreBulk( byte[] q, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores @@ -307,19 +309,49 @@ public void scoreBulk( // 128 / 8 == 16 if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores); + score256Bulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); return; } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores); + score128Bulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); return; } } - super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores); + super.scoreBulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); } private void score128Bulk( byte[] q, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores @@ -328,9 +360,9 @@ private void score128Bulk( int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE); int i = 0; long offset = in.getFilePointer(); - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; for (; i < limit; i += FLOAT_SPECIES_128.length()) { var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( @@ -362,13 +394,13 @@ private void score128Bulk( // For euclidean, we need to invert the score and apply the additional correction, which is // assumed to be the squared l2norm of the centroid centered vectors. if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f); + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); res.intoArray(scores, i); } else { // For cosine and max inner product, we need to apply the additional correction, which is // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp); + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); if (similarityFunction == MAXIMUM_INNER_PRODUCT) { res.intoArray(scores, i); // not sure how to do it better @@ -386,7 +418,10 @@ private void score128Bulk( private void score256Bulk( byte[] q, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores @@ -395,9 +430,9 @@ private void score256Bulk( int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE); int i = 0; long offset = in.getFilePointer(); - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; for (; i < limit; i += FLOAT_SPECIES_256.length()) { var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( @@ -429,13 +464,13 @@ private void score256Bulk( // For euclidean, we need to invert the score and apply the additional correction, which is // assumed to be the squared l2norm of the centroid centered vectors. if (similarityFunction == EUCLIDEAN) { - res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f); + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); res.intoArray(scores, i); } else { // For cosine and max inner product, we need to apply the additional correction, which is // assumed to be the non-centered dot-product between the vector and the centroid - res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp); + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); if (similarityFunction == MAXIMUM_INNER_PRODUCT) { res.intoArray(scores, i); // not sure how to do it better diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index 5544c0686fa5f..5712b452bfb5a 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -104,8 +104,26 @@ public void testScore() throws Exception { ); final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions); final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions); - defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1); - panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2); + defaultScorer.scoreBulk( + query, + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), + similarityFunction, + centroidDp, + scores1 + ); + panamaScorer.scoreBulk( + query, + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), + similarityFunction, + centroidDp, + scores2 + ); for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { if (scores1[j] > (maxDims * Short.MAX_VALUE)) { int diff = (int) (scores1[j] - scores2[j]); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index ab8ad21674177..2a2bef3dfcf19 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -18,7 +18,6 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; @@ -31,8 +30,8 @@ import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; +import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte; +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; /** @@ -47,13 +46,8 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect } @Override - CentroidQueryScorer getCentroidScorer( - FieldInfo fieldInfo, - int numCentroids, - IndexInput centroids, - float[] targetQuery, - IndexInput clusters - ) throws IOException { + CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery) + throws IOException { FieldEntry fieldEntry = fields.get(fieldInfo.number); float[] globalCentroid = fieldEntry.globalCentroid(); float globalCentroidDp = fieldEntry.globalCentroidDp(); @@ -259,7 +253,10 @@ void scoreIndividually(int offset) throws IOException { int doc = docIdsScratch[offset + j]; if (doc != -1) { scores[j] = osqVectorsScorer.score( - queryCorrections, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), fieldInfo.getVectorSimilarityFunction(), centroidDp, correctionsLower[j], @@ -297,7 +294,10 @@ public int visit(KnnCollector knnCollector) throws IOException { } else { osqVectorsScorer.scoreBulk( quantizedQueryScratch, - queryCorrections, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), fieldInfo.getVectorSimilarityFunction(), centroidDp, scores @@ -321,7 +321,10 @@ public int visit(KnnCollector knnCollector) throws IOException { indexInput.readFloats(correctiveValues, 0, 3); final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); float score = osqVectorsScorer.score( - queryCorrections, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), fieldInfo.getVectorSimilarityFunction(), centroidDp, correctiveValues[0], diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index cac8cc7cd2483..2488cf46cb122 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -18,7 +18,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans; import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.logging.LogManager; @@ -30,8 +29,8 @@ import java.nio.ByteOrder; import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary; /** * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index d2495dfc89acd..7a18558703423 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -15,7 +15,6 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -109,13 +108,4 @@ public String toString() { return "IVFVectorsFormat(" + "vectorPerCluster=" + vectorPerCluster + ')'; } - static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } - if (vectorsReader instanceof IVFVectorsReader reader) { - return reader; - } - return null; - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index d5086cf2d479e..4537804664789 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -89,13 +89,8 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR } } - abstract CentroidQueryScorer getCentroidScorer( - FieldInfo fieldInfo, - int numCentroids, - IndexInput centroids, - float[] target, - IndexInput clusters - ) throws IOException; + abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target) + throws IOException; private static IndexInput openDataInput( SegmentReadState state, @@ -249,8 +244,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector fieldInfo, entry.postingListOffsets.length, entry.centroidSlice(ivfCentroids), - target, - ivfClusters + target ); if (nProbe == DYNAMIC_NPROBE) { // empirically based, and a good dynamic to get decent recall while scaling a la "efSearch" diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index a29c21a158707..f0fba8c8e136f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -51,11 +51,9 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { private final IndexOutput ivfCentroids, ivfClusters; private final IndexOutput ivfMeta; private final FlatVectorsWriter rawVectorDelegate; - private final SegmentWriteState segmentWriteState; @SuppressWarnings("this-escape") protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { - this.segmentWriteState = state; this.rawVectorDelegate = rawVectorDelegate; final String metaFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index fdd02a0cf752a..6d50e5c473d06 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -125,8 +125,8 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta } int effectiveK = 0; - for (int i = 0; i < centroidVectorCount.length; i++) { - if (centroidVectorCount[i] > 0) { + for (int j : centroidVectorCount) { + if (j > 0) { effectiveK++; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 415a082c5a2b1..1551a1cfd0b6e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -81,8 +81,8 @@ private boolean stepLloyd( int dim = vectors.dimension(); int[] centroidCounts = new int[centroids.length]; - for (int i = 0; i < nextCentroids.length; i++) { - Arrays.fill(nextCentroids[i], 0.0f); + for (float[] nextCentroid : nextCentroids) { + Arrays.fill(nextCentroid, 0.0f); } for (int i = 0; i < sampleSize; i++) { @@ -99,7 +99,7 @@ private boolean stepLloyd( } assignments[i] = bestCentroidOffset; centroidCounts[bestCentroidOffset]++; - for (short d = 0; d < dim; d++) { + for (int d = 0; d < dim; d++) { nextCentroids[bestCentroidOffset][d] += vector[d]; } } @@ -107,7 +107,7 @@ private boolean stepLloyd( for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { if (centroidCounts[clusterIdx] > 0) { float countF = (float) centroidCounts[clusterIdx]; - for (short d = 0; d < dim; d++) { + for (int d = 0; d < dim; d++) { centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF; } } @@ -185,7 +185,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods int currAssignment = assignments[i]; float[] currentCentroid = centroids[currAssignment]; - for (short j = 0; j < vectors.dimension(); j++) { + for (int j = 0; j < vectors.dimension(); j++) { float diff = vector[j] - currentCentroid[j]; diffs[j] = diff; }