diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index f7d804c428830..72f7f620d211a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -14,12 +14,16 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.IOException; +import java.util.Collections; +import java.util.Map; /** * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space @@ -59,6 +63,7 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + private static final Map supportedFormats = Map.of(rawVectorFormat.getName(), rawVectorFormat); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. // useful when searching with 'efSearch' type parameters instead of requiring a specific ratio. @@ -106,12 +111,23 @@ public ES920DiskBBQVectorsFormat() { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES920DiskBBQVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster, centroidsPerParentCluster); + return new ES920DiskBBQVectorsWriter( + rawVectorFormat.getName(), + state, + rawVectorFormat.fieldsWriter(state), + vectorPerCluster, + centroidsPerParentCluster + ); } @Override public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new ES920DiskBBQVectorsReader(state, rawVectorFormat.fieldsReader(state)); + Map readers = Maps.newHashMapWithExpectedSize(supportedFormats.size()); + for (var fe : supportedFormats.entrySet()) { + readers.put(fe.getKey(), fe.getValue().fieldsReader(state)); + } + + return new ES920DiskBBQVectorsReader(state, Collections.unmodifiableMap(readers)); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index 2c76573731329..45a8c281028ae 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ -40,7 +40,7 @@ */ public class ES920DiskBBQVectorsReader extends IVFVectorsReader implements OffHeapStats { - public ES920DiskBBQVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + public ES920DiskBBQVectorsReader(SegmentReadState state, Map rawVectorsReader) throws IOException { super(state, rawVectorsReader); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java index 3e71dbe0d4416..d4dc3c377ee2f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java @@ -51,12 +51,13 @@ public class ES920DiskBBQVectorsWriter extends IVFVectorsWriter { private final int centroidsPerParentCluster; public ES920DiskBBQVectorsWriter( + String rawVectorFormatName, SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster, int centroidsPerParentCluster ) throws IOException { - super(state, rawVectorDelegate); + super(state, rawVectorFormatName, rawVectorDelegate); this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java index dc531c7ca8e56..8d7ac579451ad 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java @@ -32,7 +32,12 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.search.vectors.IVFKnnSearchStrategy; +import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DYNAMIC_VISIT_RATIO; @@ -46,14 +51,14 @@ public abstract class IVFVectorsReader extends KnnVectorsReader { private final SegmentReadState state; private final FieldInfos fieldInfos; protected final IntObjectHashMap fields; - private final FlatVectorsReader rawVectorsReader; + private final Map rawVectorReaders; @SuppressWarnings("this-escape") - protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + protected IVFVectorsReader(SegmentReadState state, Map rawVectorReaders) throws IOException { this.state = state; this.fieldInfos = state.fieldInfos; - this.rawVectorsReader = rawVectorsReader; this.fields = new IntObjectHashMap<>(); + this.rawVectorReaders = rawVectorReaders; String meta = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, @@ -156,6 +161,7 @@ private void readFields(ChecksumIndexInput meta) throws IOException { } private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + final String rawVectorFormat = input.readString(); final VectorEncoding vectorEncoding = readVectorEncoding(input); final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); if (similarityFunction != info.getVectorSimilarityFunction()) { @@ -182,6 +188,7 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio globalCentroidDp = Float.intBitsToFloat(input.readInt()); } return new FieldEntry( + rawVectorFormat, similarityFunction, vectorEncoding, numCentroids, @@ -212,26 +219,46 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep @Override public final void checkIntegrity() throws IOException { - rawVectorsReader.checkIntegrity(); + for (var reader : rawVectorReaders.values()) { + reader.checkIntegrity(); + } CodecUtil.checksumEntireFile(ivfCentroids); CodecUtil.checksumEntireFile(ivfClusters); } + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + } + + private FlatVectorsReader getReaderForField(String field) { + var formatName = getFieldEntryOrThrow(field).rawVectorFormatName; + FlatVectorsReader reader = rawVectorReaders.get(formatName); + if (reader == null) throw new IllegalArgumentException( + "Could not find raw vector format [" + formatName + "] for field [" + field + "]" + ); + return reader; + } + @Override public final FloatVectorValues getFloatVectorValues(String field) throws IOException { - return rawVectorsReader.getFloatVectorValues(field); + return getReaderForField(field).getFloatVectorValues(field); } @Override public final ByteVectorValues getByteVectorValues(String field) throws IOException { - return rawVectorsReader.getByteVectorValues(field); + return getReaderForField(field).getByteVectorValues(field); } @Override public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { - rawVectorsReader.search(field, target, knnCollector, acceptDocs); + getReaderForField(field).search(field, target, knnCollector, acceptDocs); return; } if (fieldInfo.getVectorDimension() != target.length) { @@ -243,7 +270,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector if (acceptDocs instanceof BitSet bitSet) { percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); } - int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); + int numVectors = getReaderForField(field).getFloatVectorValues(field).size(); float visitRatio = DYNAMIC_VISIT_RATIO; // Search strategy may be null if this is being called from checkIndex (e.g. from a test) if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) { @@ -309,7 +336,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector @Override public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); - final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field); + final ByteVectorValues values = getReaderForField(field).getByteVectorValues(field); for (int i = 0; i < values.size(); i++) { final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i)); knnCollector.collect(values.ordToDoc(i), score); @@ -321,10 +348,13 @@ public final void search(String field, byte[] target, KnnCollector knnCollector, @Override public void close() throws IOException { - IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters); + List closeables = new ArrayList<>(rawVectorReaders.values()); + Collections.addAll(closeables, ivfCentroids, ivfClusters); + IOUtils.close(closeables); } protected record FieldEntry( + String rawVectorFormatName, VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, int numCentroids, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java index 26b1e12991bd5..0f5988b2cd48c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java @@ -51,10 +51,13 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { private final List fieldWriters = new ArrayList<>(); private final IndexOutput ivfCentroids, ivfClusters; private final IndexOutput ivfMeta; + private final String rawVectorFormatName; private final FlatVectorsWriter rawVectorDelegate; @SuppressWarnings("this-escape") - protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { + protected IVFVectorsWriter(SegmentWriteState state, String rawVectorFormatName, FlatVectorsWriter rawVectorDelegate) + throws IOException { + this.rawVectorFormatName = rawVectorFormatName; this.rawVectorDelegate = rawVectorDelegate; final String metaFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, @@ -116,6 +119,9 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc @SuppressWarnings("unchecked") final FlatFieldVectorsWriter floatWriter = (FlatFieldVectorsWriter) rawVectorDelegate; fieldWriters.add(new FieldWriter(fieldInfo, floatWriter)); + } else { + // we simply write information that the field is present but we don't do anything with it. + fieldWriters.add(new FieldWriter(fieldInfo, null)); } return rawVectorDelegate; } @@ -165,6 +171,11 @@ abstract CentroidSupplier createCentroidSupplier( public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorDelegate.flush(maxDoc, sortMap); for (FieldWriter fieldWriter : fieldWriters) { + if (fieldWriter.delegate == null) { + // field is not float, we just write meta information + writeMeta(fieldWriter.fieldInfo, 0, 0, 0, 0, 0, null); + continue; + } final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; // build a float vector values with random access final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); @@ -248,6 +259,9 @@ public int ordToDoc(int ord) { public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { mergeOneFieldIVF(fieldInfo, mergeState); + } else { + // we simply write information that the field is present but we don't do anything with it. + writeMeta(fieldInfo, 0, 0, 0, 0, 0, null); } // we merge the vectors at the end so we only have two copies of the vectors on disk at the same time. rawVectorDelegate.mergeOneField(fieldInfo, mergeState); @@ -476,6 +490,7 @@ private void writeMeta( float[] globalCentroid ) throws IOException { ivfMeta.writeInt(field.number); + ivfMeta.writeString(rawVectorFormatName); ivfMeta.writeInt(field.getVectorEncoding().ordinal()); ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); ivfMeta.writeInt(numCentroids);