DefaultIVFVectorsWriter.java
@@ -25,6 +25,7 @@
import org.apache.lucene.util.packed.PackedLongValues;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
+ import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
@@ -63,7 +64,7 @@ public DefaultIVFVectorsWriter(
LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
- FloatVectorValues floatVectorValues,
+ PrefetchingFloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
long fileOffset,
int[] assignments,
@@ -155,7 +156,7 @@ LongValues buildAndWritePostingsLists(
LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
- FloatVectorValues floatVectorValues,
+ PrefetchingFloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
long fileOffset,
MergeState mergeState,
@@ -426,7 +427,7 @@ private void writeCentroidsWithoutParents(
private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {}

private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException {
- final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() {
+ final PrefetchingFloatVectorValues floatVectorValues = PrefetchingFloatVectorValues.floats(new AbstractList<>() {
@Override
public float[] get(int index) {
try {
@@ -479,7 +480,7 @@ public int size() {
* @throws IOException if an I/O error occurs
*/
@Override
- CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
+ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, PrefetchingFloatVectorValues floatVectorValues, float[] globalCentroid)
throws IOException {

long nanoTime = System.nanoTime();
@@ -506,7 +507,8 @@ CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues fl
return centroidAssignments;
}

- static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
+ static CentroidAssignments buildCentroidAssignments(PrefetchingFloatVectorValues floatVectorValues, int vectorPerCluster)
+ throws IOException {
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
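Note: the new PrefetchingFloatVectorValues type itself is not part of this diff view. Judging from its call sites above, it is presumably a thin extension of Lucene's FloatVectorValues; the sketch below is an inference from those call sites, not the actual source, and its method bodies are assumptions.

import java.io.IOException;
import java.util.List;

import org.apache.lucene.index.FloatVectorValues;

// Sketch only: inferred from call sites in this diff; the real class lives in
// org.elasticsearch.index.codec.vectors.cluster and may differ in detail.
public abstract class PrefetchingFloatVectorValues extends FloatVectorValues {

    // Hint that the given ordinals are about to be read so storage-backed
    // implementations can issue readahead (e.g. via IndexInput#prefetch).
    public abstract void prefetch(int... ord) throws IOException;

    // Heap-backed factory mirroring FloatVectorValues#fromFloats; prefetch is a
    // no-op here. The flush path also uses an overload floats(vectors, dimension,
    // docIds) that additionally overrides ordToDoc; omitted for brevity.
    public static PrefetchingFloatVectorValues floats(List<float[]> vectors, int dimension) {
        return new PrefetchingFloatVectorValues() {
            @Override
            public void prefetch(int... ord) {
                // vectors already live on heap, nothing to warm
            }

            @Override
            public float[] vectorValue(int ord) {
                return vectors.get(ord);
            }

            @Override
            public int dimension() {
                return dimension;
            }

            @Override
            public int size() {
                return vectors.size();
            }

            @Override
            public PrefetchingFloatVectorValues copy() {
                return this;
            }
        };
    }
}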
IVFVectorsWriter.java
@@ -32,6 +32,7 @@
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;
+ import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

import java.io.IOException;
import java.io.UncheckedIOException;
@@ -120,8 +121,11 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
return rawVectorDelegate;
}

- abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
- throws IOException;
+ abstract CentroidAssignments calculateCentroids(
+ FieldInfo fieldInfo,
+ PrefetchingFloatVectorValues floatVectorValues,
+ float[] globalCentroid
+ ) throws IOException;

abstract void writeCentroids(
FieldInfo fieldInfo,
@@ -134,7 +138,7 @@ abstract void writeCentroids(
abstract LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
- FloatVectorValues floatVectorValues,
+ PrefetchingFloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
long fileOffset,
int[] assignments,
@@ -144,7 +148,7 @@ abstract LongValues buildAndWritePostingsLists(
abstract LongValues buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
- FloatVectorValues floatVectorValues,
+ PrefetchingFloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
long fileOffset,
MergeState mergeState,
@@ -165,7 +169,11 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter fieldWriter : fieldWriters) {
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// build a float vector values with random access
- final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
+ final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(
+ fieldWriter.fieldInfo,
+ fieldWriter.delegate,
+ maxDoc
+ );
// build centroids
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
// wrap centroids with a supplier
@@ -199,47 +207,22 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
}
}

- private static FloatVectorValues getFloatVectorValues(
+ private static PrefetchingFloatVectorValues getFloatVectorValues(
FieldInfo fieldInfo,
FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
int maxDoc
) throws IOException {
List<float[]> vectors = fieldVectorsWriter.getVectors();
if (vectors.size() == maxDoc) {
- return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
+ return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension());
}
final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
final int[] docIds = new int[vectors.size()];
for (int i = 0; i < docIds.length; i++) {
docIds[i] = iterator.nextDoc();
}
assert iterator.nextDoc() == NO_MORE_DOCS;
- return new FloatVectorValues() {
- @Override
- public float[] vectorValue(int ord) {
- return vectors.get(ord);
- }
-
- @Override
- public FloatVectorValues copy() {
- return this;
- }
-
- @Override
- public int dimension() {
- return fieldInfo.getVectorDimension();
- }
-
- @Override
- public int size() {
- return vectors.size();
- }
-
- @Override
- public int ordToDoc(int ord) {
- return docIds[ord];
- }
- };
+ return PrefetchingFloatVectorValues.floats(vectors, fieldInfo.getVectorDimension(), docIds);
}

@Override
@@ -297,7 +280,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT);
IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
) {
- final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
+ final PrefetchingFloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);

final long centroidOffset;
final long centroidLength;
@@ -396,15 +379,26 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
}
}

- private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors)
- throws IOException {
+ private static PrefetchingFloatVectorValues getFloatVectorValues(
+ FieldInfo fieldInfo,
+ IndexInput docs,
+ IndexInput vectors,
+ int numVectors
+ ) throws IOException {
if (numVectors == 0) {
- return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
+ return PrefetchingFloatVectorValues.floats(List.of(), fieldInfo.getVectorDimension());
}
final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension();
final float[] vector = new float[fieldInfo.getVectorDimension()];
final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length());
- return new FloatVectorValues() {
+ return new PrefetchingFloatVectorValues() {
+ @Override
+ public void prefetch(int... ord) throws IOException {
+ for (int o : ord) {
+ vectors.prefetch(o * vectorLength, vectorLength);
+ }
[Review comment from the Member Author on lines +397 to +399: "All this change for these three lines...LOL"]
+ }

@Override
public float[] vectorValue(int ord) throws IOException {
vectors.seek(ord * vectorLength);
@@ -413,7 +407,8 @@ public float[] vectorValue(int ord) throws IOException {
}

@Override
- public FloatVectorValues copy() {
+ public PrefetchingFloatVectorValues copy() {
+ assert false;
return this;
}

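The merge path above wires prefetch directly to IndexInput#prefetch on the raw vector file. A minimal sketch of the read pattern this enables, assuming a hypothetical readBatch helper (not code from this PR): hint every ordinal first, then read, so readahead overlaps with decoding.

import java.io.IOException;

import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class PrefetchExample {
    // Hypothetical helper: batch-read vectors with readahead issued up front.
    static float[][] readBatch(PrefetchingFloatVectorValues values, int... ords) throws IOException {
        values.prefetch(ords); // one readahead hint covering the whole batch
        float[][] batch = new float[ords.length][];
        for (int i = 0; i < ords.length; i++) {
            // clone: the IndexInput-backed implementation reuses one scratch buffer
            batch[i] = values.vectorValue(ords[i]).clone();
        }
        return batch;
    }
}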
SampleReader.java
@@ -20,21 +20,21 @@
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.lucene95.HasIndexSlice;
- import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
+ import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.function.IntUnaryOperator;

- public class SampleReader extends FloatVectorValues implements HasIndexSlice {
- private final FloatVectorValues origin;
+ public class SampleReader extends PrefetchingFloatVectorValues implements HasIndexSlice {
+ private final PrefetchingFloatVectorValues origin;
private final int sampleSize;
private final IntUnaryOperator sampleFunction;

- SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
+ SampleReader(PrefetchingFloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
this.origin = origin;
this.sampleSize = sampleSize;
this.sampleFunction = sampleFunction;
@@ -51,7 +51,14 @@ public int dimension() {
}

@Override
- public FloatVectorValues copy() throws IOException {
+ public void prefetch(int... ord) throws IOException {
+ for (int o : ord) {
+ origin.prefetch(sampleFunction.applyAsInt(o));
+ }
+ }
+
+ @Override
+ public PrefetchingFloatVectorValues copy() throws IOException {
throw new IllegalStateException("Not supported");
}

@@ -81,7 +88,7 @@ public Bits getAcceptOrds(Bits acceptDocs) {
throw new IllegalStateException("Not supported");
}

- public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
+ public static SampleReader createSampleReader(PrefetchingFloatVectorValues origin, int k, long seed) {
// TODO can we do something algorithmically that aligns an ordinal with a unique integer between 0 and numVectors?
if (k >= origin.size()) {
new SampleReader(origin, origin.size(), i -> i);
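Since SampleReader.prefetch maps each sampled ordinal through sampleFunction before delegating to the origin, an entire sample can be warmed with a single call. A usage sketch with a hypothetical warmAndRead helper (assumed, not code from this PR):

import java.io.IOException;

import org.elasticsearch.index.codec.vectors.SampleReader;
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class SampleWarmup {
    // Hypothetical helper: warm all sampled vectors, then read them.
    static void warmAndRead(PrefetchingFloatVectorValues origin, int k) throws IOException {
        SampleReader sample = SampleReader.createSampleReader(origin, k, 42L);
        int[] ords = new int[sample.size()];
        for (int i = 0; i < ords.length; i++) {
            ords[i] = i;
        }
        sample.prefetch(ords); // delegates to origin.prefetch(sampleFunction.applyAsInt(o))
        for (int ord : ords) {
            float[] vector = sample.vectorValue(ord); // reads should now hit warmed pages
            // ... consume vector ...
        }
    }
}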
FloatVectorValuesSlice.java
@@ -9,16 +9,14 @@

package org.elasticsearch.index.codec.vectors.cluster;

- import org.apache.lucene.index.FloatVectorValues;
-
import java.io.IOException;

- class FloatVectorValuesSlice extends FloatVectorValues {
+ class FloatVectorValuesSlice extends PrefetchingFloatVectorValues {

- private final FloatVectorValues allValues;
+ private final PrefetchingFloatVectorValues allValues;
private final int[] slice;

- FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) {
+ FloatVectorValuesSlice(PrefetchingFloatVectorValues allValues, int[] slice) {
assert slice != null;
assert slice.length <= allValues.size();
this.allValues = allValues;
@@ -46,7 +44,14 @@ public int ordToDoc(int ord) {
}

@Override
- public FloatVectorValues copy() throws IOException {
+ public PrefetchingFloatVectorValues copy() throws IOException {
return new FloatVectorValuesSlice(this.allValues.copy(), this.slice);
}

+ @Override
+ public void prefetch(int... ord) throws IOException {
+ for (int o : ord) {
+ this.allValues.prefetch(slice[o]);
+ }
+ }
}
HierarchicalKMeans.java
@@ -9,8 +9,6 @@

package org.elasticsearch.index.codec.vectors.cluster;

- import org.apache.lucene.index.FloatVectorValues;
-
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
@@ -52,7 +50,7 @@ public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluste
* @return the centroids and the vectors assignments and SOAR (spilled from nearby neighborhoods) assignments
* @throws IOException is thrown if vectors is inaccessible
*/
- public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
+ public KMeansResult cluster(PrefetchingFloatVectorValues vectors, int targetSize) throws IOException {

if (vectors.size() == 0) {
return new KMeansIntermediate();
@@ -86,7 +84,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
return kMeansIntermediate;
}

- KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException {
+ KMeansIntermediate clusterAndSplit(final PrefetchingFloatVectorValues vectors, final int targetSize) throws IOException {
if (vectors.size() <= targetSize) {
return new KMeansIntermediate();
}
@@ -132,7 +130,7 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
final int count = centroidVectorCount[c];
final int adjustedCentroid = c - removedElements;
if (100 * count > 134 * targetSize) {
- final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
+ final PrefetchingFloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
// TODO: consider iterative here instead of recursive
// recursive call to build out the sub partitions around this centroid c
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
@@ -163,7 +161,12 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
return kMeansIntermediate;
}

- static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
+ static PrefetchingFloatVectorValues createClusterSlice(
+ int clusterSize,
+ int cluster,
+ PrefetchingFloatVectorValues vectors,
+ int[] assignments
+ ) {
int[] slice = new int[clusterSize];
int idx = 0;
for (int i = 0; i < assignments.length; i++) {
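End to end, the clustering entry point now accepts the prefetching type directly. A minimal on-heap usage sketch (assumed, not code from this PR; the 384 vectors-per-cluster target is illustrative, the writer derives its own):

import java.io.IOException;
import java.util.List;

import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.index.codec.vectors.cluster.PrefetchingFloatVectorValues;

final class ClusteringExample {
    static KMeansResult clusterOnHeap(List<float[]> vectorList, int dimension) throws IOException {
        // Heap-backed values: prefetch is a no-op, but the same entry point now
        // benefits from readahead when given an IndexInput-backed implementation.
        PrefetchingFloatVectorValues vectors = PrefetchingFloatVectorValues.floats(vectorList, dimension);
        return new HierarchicalKMeans(vectors.dimension()).cluster(vectors, 384);
    }
}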