diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index 5e696b74530a8..2c0f70ef9b670 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -25,6 +25,7 @@
 import org.apache.lucene.util.packed.PackedLongValues;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
+import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
@@ -63,7 +64,7 @@ public DefaultIVFVectorsWriter(
     LongValues buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
-        FloatVectorValues floatVectorValues,
+        PrefetchingFloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
         long fileOffset,
         int[] assignments,
@@ -155,7 +156,7 @@ LongValues buildAndWritePostingsLists(
     LongValues buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
-        FloatVectorValues floatVectorValues,
+        PrefetchingFloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
         long fileOffset,
         MergeState mergeState,
@@ -426,7 +427,7 @@ private void writeCentroidsWithoutParents(
     private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {}
 
     private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException {
-        final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() {
+        final PrefetchingFloatVectorValues floatVectorValues = PrefetchingFloatVectorValues.floats(new AbstractList<>() {
             @Override
             public float[] get(int index) {
                 try {
@@ -479,7 +480,7 @@ public int size() {
      * @throws IOException if an I/O error occurs
      */
     @Override
-    CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
+    CentroidAssignments calculateCentroids(FieldInfo fieldInfo, PrefetchingFloatVectorValues floatVectorValues, float[] globalCentroid)
         throws IOException {
 
         long nanoTime = System.nanoTime();
@@ -506,7 +507,8 @@ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues fl
         return centroidAssignments;
     }
 
-    static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
+    static CentroidAssignments buildCentroidAssignments(PrefetchingFloatVectorValues floatVectorValues, int vectorPerCluster)
+        throws IOException {
         KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
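Reviewer sketch, not part of the patch (class and method names are illustrative): the AbstractList-backed wrapping used by buildCentroidGroups above, shown standalone. An existing float[][] of centroids is exposed through the new PrefetchingFloatVectorValues type with a no-op prefetch so it can be fed back into the k-means utilities.

// Illustrative sketch only; assumes an already-computed float[][] of centroids.
import java.util.AbstractList;

import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class CentroidListSketch {
    static PrefetchingFloatVectorValues asVectorValues(float[][] centroids, int dimension) {
        // The factory's prefetch implementation is a no-op, which is fine for on-heap data.
        return PrefetchingFloatVectorValues.floats(new AbstractList<>() {
            @Override
            public float[] get(int index) {
                return centroids[index];
            }

            @Override
            public int size() {
                return centroids.length;
            }
        }, dimension);
    }
}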
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
index 308ee391b5f4a..1f4d81c741cf0 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
@@ -32,6 +32,7 @@
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -120,8 +121,11 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc
         return rawVectorDelegate;
     }
 
-    abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
-        throws IOException;
+    abstract CentroidAssignments calculateCentroids(
+        FieldInfo fieldInfo,
+        PrefetchingFloatVectorValues floatVectorValues,
+        float[] globalCentroid
+    ) throws IOException;
 
     abstract void writeCentroids(
         FieldInfo fieldInfo,
@@ -134,7 +138,7 @@ abstract void writeCentroids(
     abstract LongValues buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
-        FloatVectorValues floatVectorValues,
+        PrefetchingFloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
         long fileOffset,
         int[] assignments,
@@ -144,7 +148,7 @@ abstract LongValues buildAndWritePostingsLists(
     abstract LongValues buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
-        FloatVectorValues floatVectorValues,
+        PrefetchingFloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
         long fileOffset,
         MergeState mergeState,
@@ -165,7 +169,11 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
         for (FieldWriter fieldWriter : fieldWriters) {
             final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
             // build a float vector values with random access
-            final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
+            final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(
+                fieldWriter.fieldInfo,
+                fieldWriter.delegate,
+                maxDoc
+            );
             // build centroids
             final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
             // wrap centroids with a supplier
@@ -199,14 +207,14 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
         }
     }
 
-    private static FloatVectorValues getFloatVectorValues(
+    private static PrefetchingFloatVectorValues getFloatVectorValues(
         FieldInfo fieldInfo,
         FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
         int maxDoc
     ) throws IOException {
         List<float[]> vectors = fieldVectorsWriter.getVectors();
         if (vectors.size() == maxDoc) {
-            return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
+            return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension());
         }
         final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
         final int[] docIds = new int[vectors.size()];
@@ -214,32 +222,7 @@ private static FloatVectorValues getFloatVectorValues(
             docIds[i] = iterator.nextDoc();
         }
         assert iterator.nextDoc() == NO_MORE_DOCS;
-        return new FloatVectorValues() {
-            @Override
-            public float[] vectorValue(int ord) {
-                return vectors.get(ord);
-            }
-
-            @Override
-            public FloatVectorValues copy() {
-                return this;
-            }
-
-            @Override
-            public int dimension() {
-                return fieldInfo.getVectorDimension();
-            }
-
-            @Override
-            public int size() {
-                return vectors.size();
-            }
-
-            @Override
-            public int ordToDoc(int ord) {
-                return docIds[ord];
-            }
-        };
+        return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension(), docIds);
     }
 
     @Override
@@ -297,7 +280,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT);
             IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
         ) {
-            final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
+            final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
             final long centroidOffset;
             final long centroidLength;
 
@@ -396,15 +379,26 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         }
     }
 
-    private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors)
-        throws IOException {
+    private static PrefetchingFloatVectorValues getFloatVectorValues(
+        FieldInfo fieldInfo,
+        IndexInput docs,
+        IndexInput vectors,
+        int numVectors
+    ) throws IOException {
         if (numVectors == 0) {
-            return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
+            return PrefetchingFloatVectorValues.floats(List.of(), fieldInfo.getVectorDimension());
         }
         final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension();
         final float[] vector = new float[fieldInfo.getVectorDimension()];
         final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length());
-        return new FloatVectorValues() {
+        return new PrefetchingFloatVectorValues() {
+            @Override
+            public void prefetch(int... ord) throws IOException {
+                for (int o : ord) {
+                    vectors.prefetch(o * vectorLength, vectorLength);
+                }
+            }
+
             @Override
             public float[] vectorValue(int ord) throws IOException {
                 vectors.seek(ord * vectorLength);
@@ -413,7 +407,8 @@ public float[] vectorValue(int ord) throws IOException {
             }
 
             @Override
-            public FloatVectorValues copy() {
+            public PrefetchingFloatVectorValues copy() {
+                assert false;
                 return this;
             }
 
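Reviewer sketch, not part of the patch (the class name and its use outside IVFVectorsWriter are assumptions): a named equivalent of the anonymous class in the merge path above, assuming the same raw little-endian float layout in the temporary vectors file. It shows how prefetch(...) maps ordinals to byte ranges and forwards them to IndexInput#prefetch as an advisory read-ahead hint.

// Illustrative sketch only; mirrors the anonymous PrefetchingFloatVectorValues in the merge path.
import java.io.IOException;

import org.apache.lucene.store.IndexInput;
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class OffHeapPrefetchingFloatVectorValues extends PrefetchingFloatVectorValues {
    private final IndexInput vectors;
    private final int dimension;
    private final int size;
    private final long vectorLength;
    private final float[] scratch;

    OffHeapPrefetchingFloatVectorValues(IndexInput vectors, int dimension, int size) {
        this.vectors = vectors;
        this.dimension = dimension;
        this.size = size;
        this.vectorLength = (long) Float.BYTES * dimension;
        this.scratch = new float[dimension];
    }

    @Override
    public void prefetch(int... ords) throws IOException {
        for (int ord : ords) {
            // Advisory hint that this byte range is about to be read.
            vectors.prefetch(ord * vectorLength, vectorLength);
        }
    }

    @Override
    public float[] vectorValue(int ord) throws IOException {
        vectors.seek(ord * vectorLength);
        vectors.readFloats(scratch, 0, scratch.length);
        return scratch;
    }

    @Override
    public PrefetchingFloatVectorValues copy() {
        throw new UnsupportedOperationException("copy is not supported by this sketch");
    }

    @Override
    public int dimension() {
        return dimension;
    }

    @Override
    public int size() {
        return size;
    }
}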
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java
index f2d7944f1088c..3351580a46aa1 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/SampleReader.java
@@ -20,21 +20,21 @@
 package org.elasticsearch.index.codec.vectors;
 
 import org.apache.lucene.codecs.lucene95.HasIndexSlice;
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.Bits;
+import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
 
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Random;
 import java.util.function.IntUnaryOperator;
 
-public class SampleReader extends FloatVectorValues implements HasIndexSlice {
-    private final FloatVectorValues origin;
+public class SampleReader extends PrefetchingFloatVectorValues implements HasIndexSlice {
+    private final PrefetchingFloatVectorValues origin;
     private final int sampleSize;
     private final IntUnaryOperator sampleFunction;
 
-    SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
+    SampleReader(PrefetchingFloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
         this.origin = origin;
         this.sampleSize = sampleSize;
         this.sampleFunction = sampleFunction;
@@ -51,7 +51,14 @@ public int dimension() {
     }
 
     @Override
-    public FloatVectorValues copy() throws IOException {
+    public void prefetch(int... ord) throws IOException {
+        for (int o : ord) {
+            origin.prefetch(sampleFunction.applyAsInt(o));
+        }
+    }
+
+    @Override
+    public PrefetchingFloatVectorValues copy() throws IOException {
         throw new IllegalStateException("Not supported");
     }
 
@@ -81,7 +88,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {
         throw new IllegalStateException("Not supported");
     }
 
-    public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
+    public static SampleReader createSampleReader(PrefetchingFloatVectorValues origin, int k, long seed) {
        // TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
         if (k >= origin.size()) {
             new SampleReader(origin, origin.size(), i -> i);
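Reviewer sketch, not part of the patch (the helper name is an assumption): how a caller might scan a sampled subset while prefetching one vector ahead. SampleReader forwards both reads and prefetch hints through its sample function to the underlying ordinals.

// Illustrative sketch only; assumes any PrefetchingFloatVectorValues as the origin.
import java.io.IOException;

import org.elasticsearch.index.codec.vectors.SampleReader;
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class SampledScanSketch {
    static void scanSample(PrefetchingFloatVectorValues origin, int sampleSize, long seed) throws IOException {
        SampleReader sample = SampleReader.createSampleReader(origin, sampleSize, seed);
        sample.prefetch(0);
        for (int ord = 0; ord < sample.size(); ord++) {
            if (ord < sample.size() - 1) {
                sample.prefetch(ord + 1); // hint the next sampled vector before using the current one
            }
            float[] vector = sample.vectorValue(ord);
            if (vector.length != origin.dimension()) {
                throw new IllegalStateException("unexpected dimension");
            }
        }
    }
}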
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java
index 6da6ff196e93e..6fc3186a9ff53 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java
@@ -9,16 +9,14 @@
 
 package org.elasticsearch.index.codec.vectors.cluster;
 
-import org.apache.lucene.index.FloatVectorValues;
-
 import java.io.IOException;
 
-class FloatVectorValuesSlice extends FloatVectorValues {
+class FloatVectorValuesSlice extends PrefetchingFloatVectorValues {
 
-    private final FloatVectorValues allValues;
+    private final PrefetchingFloatVectorValues allValues;
     private final int[] slice;
 
-    FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) {
+    FloatVectorValuesSlice(PrefetchingFloatVectorValues allValues, int[] slice) {
         assert slice != null;
         assert slice.length <= allValues.size();
         this.allValues = allValues;
@@ -46,7 +44,14 @@ public int ordToDoc(int ord) {
     }
 
     @Override
-    public FloatVectorValues copy() throws IOException {
+    public PrefetchingFloatVectorValues copy() throws IOException {
         return new FloatVectorValuesSlice(this.allValues.copy(), this.slice);
     }
+
+    @Override
+    public void prefetch(int... ord) throws IOException {
+        for (int o : ord) {
+            this.allValues.prefetch(slice[o]);
+        }
+    }
 }
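Reviewer sketch, not part of the patch: the ordinal-translation idea shared by FloatVectorValuesSlice and HierarchicalKMeans#createClusterSlice — member ordinals of one cluster are collected into an int[], and every read or prefetch on the slice is then forwarded to the parent ordinal slice[ord].

// Illustrative sketch only; mirrors the slice construction shown in createClusterSlice.
final class ClusterSliceSketch {
    // Collect the parent ordinals assigned to `cluster`.
    static int[] memberOrdinals(int clusterSize, int cluster, int[] assignments) {
        int[] slice = new int[clusterSize];
        int idx = 0;
        for (int i = 0; i < assignments.length; i++) {
            if (assignments[i] == cluster) {
                slice[idx++] = i;
            }
        }
        return slice;
    }
}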
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java
index de654fb851554..a823263b10e96 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java
@@ -9,8 +9,6 @@
 
 package org.elasticsearch.index.codec.vectors.cluster;
 
-import org.apache.lucene.index.FloatVectorValues;
-
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Objects;
@@ -52,7 +50,7 @@ public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluste
      * @return the centroids and the vectors assignments and SOAR (spilled from nearby neighborhoods) assignments
      * @throws IOException is thrown if vectors is inaccessible
      */
-    public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
+    public KMeansResult cluster(PrefetchingFloatVectorValues vectors, int targetSize) throws IOException {
         if (vectors.size() == 0) {
             return new KMeansIntermediate();
         }
@@ -86,7 +84,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
         return kMeansIntermediate;
     }
 
-    KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
+    KMeansIntermediate clusterAndSplit(final PrefetchingFloatVectorValues vectors, final int targetSize) throws IOException {
         if (vectors.size() <= targetSize) {
             return new KMeansIntermediate();
         }
@@ -132,7 +130,7 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
             final int count = centroidVectorCount[c];
             final int adjustedCentroid = c - removedElements;
             if (100 * count > 134 * targetSize) {
-                final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
+                final PrefetchingFloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
                 // TODO: consider iterative here instead of recursive
                 // recursive call to build out the sub partitions around this centroid c
                 // subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
@@ -163,7 +161,12 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
         return kMeansIntermediate;
     }
 
-    static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
+    static PrefetchingFloatVectorValues createClusterSlice(
+        int clusterSize,
+        int cluster,
+        PrefetchingFloatVectorValues vectors,
+        int[] assignments
+    ) {
         int[] slice = new int[clusterSize];
         int idx = 0;
         for (int i = 0; i < assignments.length; i++) {
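Reviewer sketch, not part of the patch (the helper name is an assumption): consuming the clustering output. For each vector ordinal, assignments()[ord] indexes into centroids(), so per-cluster sizes fall out of a single pass over the assignments.

// Illustrative sketch only; uses the public cluster entry point shown above.
import java.io.IOException;

import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class ClusterResultSketch {
    static int[] clusterSizes(PrefetchingFloatVectorValues vectors, int targetSize) throws IOException {
        KMeansResult result = new HierarchicalKMeans(vectors.dimension()).cluster(vectors, targetSize);
        float[][] centroids = result.centroids();
        int[] assignments = result.assignments();
        int[] sizes = new int[centroids.length];
        for (int assignment : assignments) {
            sizes[assignment]++; // count how many vectors landed on each centroid
        }
        return sizes;
    }
}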
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
index 9e83accef1268..1028a814def7a 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
@@ -9,7 +9,6 @@
 
 package org.elasticsearch.index.codec.vectors.cluster;
 
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
@@ -50,7 +49,7 @@ class KMeansLocal {
      * @return randomly selected centroids that are the min of centroidCount and sampleSize
      * @throws IOException is thrown if vectors is inaccessible
      */
-    static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException {
+    static float[][] pickInitialCentroids(PrefetchingFloatVectorValues vectors, int centroidCount) throws IOException {
         Random random = new Random(42L);
         int centroidsSize = Math.min(vectors.size(), centroidCount);
         float[][] centroids = new float[centroidsSize][vectors.dimension()];
@@ -69,7 +68,7 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
     }
 
     private static boolean stepLloyd(
-        FloatVectorValues vectors,
+        PrefetchingFloatVectorValues vectors,
         IntToIntFunction translateOrd,
         float[][] centroids,
         FixedBitSet centroidChanged,
@@ -81,7 +80,11 @@ private static boolean stepLloyd(
         int dim = vectors.dimension();
         centroidChanged.clear();
         final float[] distances = new float[4];
+        vectors.prefetch(0);
         for (int idx = 0; idx < vectors.size(); idx++) {
+            if (idx < vectors.size() - 1) {
+                vectors.prefetch(idx + 1);
+            }
             float[] vector = vectors.vectorValue(idx);
             int vectorOrd = translateOrd.apply(idx);
             final int assignment = assignments[vectorOrd];
@@ -105,6 +108,7 @@ private static boolean stepLloyd(
         for (int idx = 0; idx < vectors.size(); idx++) {
             final int assignment = assignments[translateOrd.apply(idx)];
             if (centroidChanged.get(assignment)) {
+                vectors.prefetch(idx);
                 float[] centroid = centroids[assignment];
                 if (centroidCounts[assignment]++ == 0) {
                     Arrays.fill(centroid, 0.0f);
@@ -257,7 +261,7 @@ private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNe
     }
 
     private void assignSpilled(
-        FloatVectorValues vectors,
+        PrefetchingFloatVectorValues vectors,
         KMeansIntermediate kmeansIntermediate,
         NeighborHood[] neighborhoods,
         float soarLambda
@@ -280,7 +284,11 @@ private void assignSpilled(
 
         float[] diffs = new float[vectors.dimension()];
         final float[] distances = new float[4];
+        vectors.prefetch(0);
         for (int i = 0; i < vectors.size(); i++) {
+            if (i < vectors.size() - 1) {
+                vectors.prefetch(i + 1);
+            }
             float[] vector = vectors.vectorValue(i);
             int currAssignment = assignments[i];
             float[] currentCentroid = centroids[currAssignment];
@@ -357,7 +365,7 @@ record NeighborHood(int[] neighbors, float maxIntraDistance) {
      * passing in a valid output object with a centroids array that is the size of centroids expected
      * @throws IOException is thrown if vectors is inaccessible
      */
-    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
+    void cluster(PrefetchingFloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
         doCluster(vectors, kMeansIntermediate, -1, -1);
     }
 
@@ -375,7 +383,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
      *
      * @throws IOException is thrown if vectors is inaccessible or if the clustersPerNeighborhood is less than 2
      */
-    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
+    void cluster(PrefetchingFloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
         throws IOException {
         if (clustersPerNeighborhood < 2) {
             throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]");
@@ -383,8 +391,12 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, i
         doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
     }
 
-    private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
-        throws IOException {
+    private void doCluster(
+        PrefetchingFloatVectorValues vectors,
+        KMeansIntermediate kMeansIntermediate,
+        int clustersPerNeighborhood,
+        float soarLambda
+    ) throws IOException {
         float[][] centroids = kMeansIntermediate.centroids();
         boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;
         NeighborHood[] neighborhoods = null;
@@ -400,7 +412,7 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
         }
     }
 
-    private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods)
+    private void cluster(PrefetchingFloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods)
         throws IOException {
         float[][] centroids = kMeansIntermediate.centroids();
         int k = centroids.length;
@@ -412,7 +424,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
             return;
         }
         IntToIntFunction translateOrd = i -> i;
-        FloatVectorValues sampledVectors = vectors;
+        PrefetchingFloatVectorValues sampledVectors = vectors;
         if (sampleSize < n) {
             sampledVectors = SampleReader.createSampleReader(vectors, sampleSize, 42L);
             translateOrd = sampledVectors::ordToDoc;
@@ -435,7 +447,7 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
     }
 
     /**
-     * helper that calls {@link KMeansLocal#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids,
+     * helper that calls {@link KMeansLocal#cluster(PrefetchingFloatVectorValues, KMeansIntermediate)} given a set of initialized centroids,
      * this call is not neighbor aware
      *
      * @param vectors the vectors to cluster
@@ -443,7 +455,8 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
      * @param sampleSize the subset of vectors to use when shifting centroids
      * @param maxIterations the max iterations to shift centroids
      */
-    public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
+    public static void cluster(PrefetchingFloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations)
+        throws IOException {
         KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
         KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
         kMeans.cluster(vectors, kMeansIntermediate);
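Reviewer sketch, not part of the patch: the prefetch-one-ahead loop shape that stepLloyd and assignSpilled now follow, shown in isolation. The hint for ordinal i + 1 is issued before the potentially blocking read of ordinal i, so the I/O can overlap with the distance math on the current vector.

// Illustrative sketch only; the consumer interface is assumed for illustration.
import java.io.IOException;

import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class PrefetchAheadSketch {
    interface VectorConsumer {
        void accept(int ord, float[] vector);
    }

    static void forEachVector(PrefetchingFloatVectorValues vectors, VectorConsumer consumer) throws IOException {
        if (vectors.size() == 0) {
            return;
        }
        vectors.prefetch(0);
        for (int i = 0; i < vectors.size(); i++) {
            if (i < vectors.size() - 1) {
                vectors.prefetch(i + 1); // hint the next vector before working on the current one
            }
            consumer.accept(i, vectors.vectorValue(i));
        }
    }
}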
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/PrefetchingFloatVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/PrefetchingFloatVectorValues.java
new file mode 100644
index 0000000000000..a08b2ce88ae83
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/PrefetchingFloatVectorValues.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors.cluster;
+
+import org.apache.lucene.index.FloatVectorValues;
+
+import java.io.IOException;
+import java.util.List;
+
+public abstract class PrefetchingFloatVectorValues extends FloatVectorValues {
+    public abstract void prefetch(int... ord) throws IOException;
+
+    @Override
+    public abstract PrefetchingFloatVectorValues copy() throws IOException;
+
+    public static PrefetchingFloatVectorValues floats(List<float[]> vectors, int dimension, int[] ordToDoc) {
+        return new PrefetchingFloatVectorValues() {
+            @Override
+            public void prefetch(int... ord) {
+                // no-op
+            }
+
+            @Override
+            public float[] vectorValue(int ord) {
+                return vectors.get(ord);
+            }
+
+            @Override
+            public PrefetchingFloatVectorValues copy() {
+                return this;
+            }
+
+            @Override
+            public int dimension() {
+                return dimension;
+            }
+
+            @Override
+            public int size() {
+                return vectors.size();
+            }
+
+            @Override
+            public int ordToDoc(int ord) {
+                return ordToDoc[ord];
+            }
+        };
+    }
+
+    public static PrefetchingFloatVectorValues floats(List<float[]> vectors, int dimension) {
+        return new PrefetchingFloatVectorValues() {
+            @Override
+            public void prefetch(int... ord) {
+                // no-op
+            }
+
+            @Override
+            public float[] vectorValue(int ord) {
+                return vectors.get(ord);
+            }
+
+            @Override
+            public PrefetchingFloatVectorValues copy() {
+                return this;
+            }
+
+            @Override
+            public int dimension() {
+                return dimension;
+            }
+
+            @Override
+            public int size() {
+                return vectors.size();
+            }
+
+            @Override
+            public int ordToDoc(int ord) {
+                return ord;
+            }
+        };
+    }
+}
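Reviewer sketch, not part of the patch (the sample data is made up): exercising the two factories above. The ordToDoc overload is what the flush path uses when some documents carry no vector, and prefetch is varargs and a no-op for these on-heap implementations.

// Illustrative usage only.
import java.io.IOException;
import java.util.List;

import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class FactoryUsageSketch {
    static void demo() throws IOException {
        List<float[]> vectors = List.of(new float[] { 1f, 0f }, new float[] { 0f, 1f });

        // Dense case: ordinal == docId.
        PrefetchingFloatVectorValues dense = PrefetchingFloatVectorValues.floats(vectors, 2);
        assert dense.ordToDoc(1) == 1;

        // Sparse case: ordinals map to the docIds that actually carry a vector.
        PrefetchingFloatVectorValues sparse = PrefetchingFloatVectorValues.floats(vectors, 2, new int[] { 3, 7 });
        assert sparse.ordToDoc(1) == 7;

        // prefetch accepts several ordinals at once; here it is a no-op.
        sparse.prefetch(0, 1);
        float[] first = sparse.vectorValue(0);
        assert first.length == sparse.dimension();
    }
}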
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java
index a2d34d28f3784..6170005220190 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java
@@ -9,7 +9,6 @@
 
 package org.elasticsearch.index.codec.vectors.cluster;
 
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.test.ESTestCase;
 
@@ -27,7 +26,7 @@ public void testIllegalClustersPerNeighborhood() {
         IllegalArgumentException ex = expectThrows(
             IllegalArgumentException.class,
             () -> kMeansLocal.cluster(
-                FloatVectorValues.fromFloats(List.of(), randomInt(1024)),
+                PrefetchingFloatVectorValues.floats(List.of(), randomInt(1024)),
                 kMeansIntermediate,
                 randomIntBetween(Integer.MIN_VALUE, 1),
                 randomFloat()
@@ -44,7 +43,7 @@ public void testKMeansNeighbors() throws IOException {
         int maxIterations = random().nextInt(0, 100);
         int clustersPerNeighborhood = random().nextInt(2, 512);
         float soarLambda = random().nextFloat(0.5f, 1.5f);
-        FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
+        PrefetchingFloatVectorValues vectors = generateData(nVectors, dims, nClusters);
 
         float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, nClusters);
         KMeansLocal.cluster(vectors, centroids, sampleSize, maxIterations);
@@ -85,7 +84,7 @@ public void testKMeansNeighborsAllZero() throws IOException {
             vectors.add(vector);
         }
         int sampleSize = vectors.size();
-        FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5);
+        PrefetchingFloatVectorValues fvv = PrefetchingFloatVectorValues.floats(vectors, 5);
 
         float[][] centroids = KMeansLocal.pickInitialCentroids(fvv, nClusters);
         KMeansLocal.cluster(fvv, centroids, sampleSize, maxIterations);
@@ -121,7 +120,7 @@ public void testKMeansNeighborsAllZero() throws IOException {
         }
     }
 
-    private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
+    private static PrefetchingFloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
         List<float[]> vectors = new ArrayList<>(nSamples);
         float[][] centroids = new float[nClusters][nDims];
         // Generate random centroids
@@ -139,6 +138,6 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
             }
             vectors.add(vector);
         }
-        return FloatVectorValues.fromFloats(vectors, nDims);
+        return PrefetchingFloatVectorValues.floats(vectors, nDims);
     }
 }
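Reviewer sketch, not part of the patch (the class is hypothetical and not present in this change): a delegating wrapper a unit test could use to check that the clustering loops actually issue prefetch hints ahead of reads.

// Illustrative test helper sketch only.
import java.io.IOException;

import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class CountingPrefetchingFloatVectorValues extends PrefetchingFloatVectorValues {
    private final PrefetchingFloatVectorValues delegate;
    int prefetchCalls;

    CountingPrefetchingFloatVectorValues(PrefetchingFloatVectorValues delegate) {
        this.delegate = delegate;
    }

    @Override
    public void prefetch(int... ord) throws IOException {
        prefetchCalls++;
        delegate.prefetch(ord);
    }

    @Override
    public float[] vectorValue(int ord) throws IOException {
        return delegate.vectorValue(ord);
    }

    @Override
    public PrefetchingFloatVectorValues copy() throws IOException {
        return new CountingPrefetchingFloatVectorValues(delegate.copy());
    }

    @Override
    public int dimension() {
        return delegate.dimension();
    }

    @Override
    public int size() {
        return delegate.size();
    }

    @Override
    public int ordToDoc(int ord) {
        return delegate.ordToDoc(ord);
    }
}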