diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index ab93bcdeae11c..2987b3849e663 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -464,6 +464,7 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat, + org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java index 9f3fa74f3b88e..0d67281bf5606 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java @@ -19,5 +19,10 @@ protected DirectIOCapableFlatVectorsFormat(String name) { super(name); } + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return fieldsReader(state, false); + } + public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/GenericFlatVectorReaders.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/GenericFlatVectorReaders.java new file mode 100644 index 0000000000000..3fd66fe757d7b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/GenericFlatVectorReaders.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Map; + +/** + * Keeps track of field-specific raw vector readers for vector reads + */ +public class GenericFlatVectorReaders { + + public interface Field { + String rawVectorFormatName(); + + boolean useDirectIOReads(); + } + + @FunctionalInterface + public interface LoadFlatVectorsReader { + FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException; + } + + private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) { + private FlatVectorsReaderKey(Field field) { + this(field.rawVectorFormatName(), field.useDirectIOReads()); + } + + @Override + public String toString() { + return formatName + (useDirectIO ? " with Direct IO" : ""); + } + } + + private final Map readers = new HashMap<>(); + private final Map readersForFields = new HashMap<>(); + + public void loadField(int fieldNumber, Field field, LoadFlatVectorsReader loadReader) throws IOException { + FlatVectorsReaderKey key = new FlatVectorsReaderKey(field); + FlatVectorsReader reader = readers.get(key); + if (reader == null) { + reader = loadReader.getReader(field.rawVectorFormatName(), field.useDirectIOReads()); + if (reader == null) { + throw new IllegalStateException("Cannot find flat vector format: " + field.rawVectorFormatName()); + } + readers.put(key, reader); + } + readersForFields.put(fieldNumber, reader); + } + + public FlatVectorsReader getReaderForField(int fieldNumber) { + FlatVectorsReader reader = readersForFields.get(fieldNumber); + if (reader == null) { + throw new IllegalArgumentException("Invalid field number [" + fieldNumber + "]"); + } + return reader; + } + + public Collection allReaders() { + return Collections.unmodifiableCollection(readers.values()); + } + + public GenericFlatVectorReaders getMergeInstance() throws IOException { + GenericFlatVectorReaders mergeReaders = new GenericFlatVectorReaders(); + + // link the original instance with the merge instance + Map mergeInstances = new IdentityHashMap<>(); + for (var reader : readers.entrySet()) { + FlatVectorsReader mergeInstance = reader.getValue().getMergeInstance(); + mergeInstances.put(reader.getValue(), mergeInstance); + mergeReaders.readers.put(reader.getKey(), mergeInstance); + } + // link up the fields to the merge readers + for (var field : readersForFields.entrySet()) { + mergeReaders.readersForFields.put(field.getKey(), mergeInstances.get(field.getValue())); + } + return mergeReaders; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index 7d8a563c99f7f..3d6f7f24a1600 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; @@ -38,7 +39,7 @@ */ public class ES920DiskBBQVectorsReader extends IVFVectorsReader { - ES920DiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException { + ES920DiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader) throws IOException { super(state, getFormatReader); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java index 81c1e144effdb..1fb0ecd1b7552 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java @@ -30,13 +30,13 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders; import org.elasticsearch.search.vectors.IVFKnnSearchStrategy; import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -49,33 +49,18 @@ */ public abstract class IVFVectorsReader extends KnnVectorsReader { - private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) { - private FlatVectorsReaderKey(FieldEntry entry) { - this(entry.rawVectorFormatName, entry.useDirectIOReads); - } - - @Override - public String toString() { - return formatName + (useDirectIO ? " with Direct IO" : ""); - } - } - private final IndexInput ivfCentroids, ivfClusters; private final SegmentReadState state; private final FieldInfos fieldInfos; protected final IntObjectHashMap fields; - private final Map rawVectorReaders; - - @FunctionalInterface - public interface GetFormatReader { - FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException; - } + private final GenericFlatVectorReaders genericReaders; @SuppressWarnings("this-escape") - protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException { + protected IVFVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader loadReader) throws IOException { this.state = state; this.fieldInfos = state.fieldInfos; this.fields = new IntObjectHashMap<>(); + this.genericReaders = new GenericFlatVectorReaders(); String meta = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, @@ -86,7 +71,6 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead boolean success = false; try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) { Throwable priorE = null; - Map readers = null; try { versionMeta = CodecUtil.checkIndexHeader( ivfMeta, @@ -96,13 +80,12 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead state.segmentInfo.getId(), state.segmentSuffix ); - readers = readFields(ivfMeta, getFormatReader, versionMeta); + readFields(ivfMeta, versionMeta, genericReaders, loadReader); } catch (Throwable exception) { priorE = exception; } finally { CodecUtil.checkFooter(ivfMeta, priorE); } - this.rawVectorReaders = readers; ivfCentroids = openDataInput( state, versionMeta, @@ -169,9 +152,12 @@ private static IndexInput openDataInput( } } - private Map readFields(ChecksumIndexInput meta, GetFormatReader loadReader, int versionMeta) - throws IOException { - Map readers = new HashMap<>(); + private void readFields( + ChecksumIndexInput meta, + int versionMeta, + GenericFlatVectorReaders genericFields, + GenericFlatVectorReaders.LoadFlatVectorsReader loadReader + ) throws IOException { for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); if (info == null) { @@ -179,20 +165,10 @@ private Map readFields(ChecksumIndexInp } FieldEntry fieldEntry = readField(meta, info, versionMeta); - FlatVectorsReaderKey key = new FlatVectorsReaderKey(fieldEntry); - - FlatVectorsReader reader = readers.get(key); - if (reader == null) { - reader = loadReader.getReader(fieldEntry.rawVectorFormatName, fieldEntry.useDirectIOReads); - if (reader == null) { - throw new IllegalStateException("Cannot find flat vector format: " + fieldEntry.rawVectorFormatName); - } - readers.put(key, reader); - } + genericFields.loadField(fieldNumber, fieldEntry, loadReader); fields.put(info.number, fieldEntry); } - return readers; } private FieldEntry readField(IndexInput input, FieldInfo info, int versionMeta) throws IOException { @@ -256,29 +232,17 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep @Override public final void checkIntegrity() throws IOException { - for (var reader : rawVectorReaders.values()) { + for (var reader : genericReaders.allReaders()) { reader.checkIntegrity(); } CodecUtil.checksumEntireFile(ivfCentroids); CodecUtil.checksumEntireFile(ivfClusters); } - private FieldEntry getFieldEntryOrThrow(String field) { - final FieldInfo info = fieldInfos.fieldInfo(field); - final FieldEntry entry; - if (info == null || (entry = fields.get(info.number)) == null) { - throw new IllegalArgumentException("field=\"" + field + "\" not found"); - } - return entry; - } - private FlatVectorsReader getReaderForField(String field) { - var readerKey = new FlatVectorsReaderKey(getFieldEntryOrThrow(field)); - FlatVectorsReader reader = rawVectorReaders.get(readerKey); - if (reader == null) throw new IllegalArgumentException( - "Could not find raw vector format [" + readerKey + "] for field [" + field + "]" - ); - return reader; + FieldInfo info = fieldInfos.fieldInfo(field); + if (info == null) throw new IllegalArgumentException("Could not find field [" + field + "]"); + return genericReaders.getReaderForField(info.number); } @Override @@ -399,7 +363,7 @@ public Map getOffHeapByteSize(FieldInfo fieldInfo) { @Override public void close() throws IOException { - List closeables = new ArrayList<>(rawVectorReaders.values()); + List closeables = new ArrayList<>(genericReaders.allReaders()); Collections.addAll(closeables, ivfCentroids, ivfClusters); IOUtils.close(closeables); } @@ -416,7 +380,7 @@ protected record FieldEntry( long postingListLength, float[] globalCentroid, float globalCentroidDp - ) { + ) implements GenericFlatVectorReaders.Field { IndexInput centroidSlice(IndexInput centroidFile) throws IOException { return centroidFile.slice("centroids", centroidOffset, centroidLength); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java index 13e343e5273b9..2d71070dad9ec 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue; import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter; @@ -40,7 +41,8 @@ */ public class ESNextDiskBBQVectorsReader extends IVFVectorsReader { - public ESNextDiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException { + public ESNextDiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader) + throws IOException { super(state, getFormatReader); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index 119f2c0f77535..69f4f96a4e829 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -32,6 +32,7 @@ import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; import java.io.IOException; +import java.util.Map; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 @@ -86,19 +87,33 @@ *
  • The sparse vector information, if required, mapping vector ordinal to doc ID * */ -public class ES93BinaryQuantizedVectorsFormat extends DirectIOCapableFlatVectorsFormat { +public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat { public static final String NAME = "ES93BinaryQuantizedVectorsFormat"; - private final DirectIOCapableLucene99FlatVectorsFormat rawVectorFormat; + private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final Map supportedFormats = Map.of( + rawVectorFormat.getName(), + rawVectorFormat + ); private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + private final boolean useDirectIO; + public ES93BinaryQuantizedVectorsFormat() { super(NAME); - rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + this.useDirectIO = false; + } + + public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO) { + super(NAME); + this.useDirectIO = useDirectIO; } @Override @@ -107,17 +122,27 @@ protected FlatVectorsScorer flatVectorsScorer() { } @Override - public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + protected boolean useDirectIOReads() { + return useDirectIO; } @Override - public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() { + return rawVectorFormat; } @Override - public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException { - return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state, useDirectIO), scorer); + protected Map supportedReadFlatVectorsFormats() { + return supportedFormats; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java new file mode 100644 index 0000000000000..526a4241ed89e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; + +import java.io.IOException; +import java.util.Map; + +public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { + + static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; + static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + private static final GenericFormatMetaInformation META = new GenericFormatMetaInformation( + VECTOR_FORMAT_INFO_EXTENSION, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT + ); + + public ES93GenericFlatVectorsFormat(String name) { + super(name); + } + + protected abstract DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat(); + + protected abstract boolean useDirectIOReads(); + + protected abstract Map supportedReadFlatVectorsFormats(); + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + var flatFormat = writeFlatVectorsFormat(); + boolean directIO = useDirectIOReads(); + return new ES93GenericFlatVectorsWriter(META, flatFormat.getName(), directIO, state, flatFormat.fieldsWriter(state)); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + var readFormats = supportedReadFlatVectorsFormats(); + return new ES93GenericFlatVectorsReader(META, state, (f, dio) -> { + var format = readFormats.get(f); + if (format == null) return null; + return format.fieldsReader(state, dio); + }); + } + + @Override + public String toString() { + return getName() + + "(name=" + + getName() + + ", writeFlatVectorFormat=" + + writeFlatVectorsFormat() + + ", readFlatVectorsFormats=" + + supportedReadFlatVectorsFormats().values() + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsReader.java index 02811d31a8024..96c5dea8ba4d1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsReader.java @@ -13,131 +13,193 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders; import java.io.IOException; import java.util.Map; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.META_CODEC_NAME; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.VECTOR_FORMAT_INFO_EXTENSION; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.VERSION_CURRENT; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.VERSION_START; - class ES93GenericFlatVectorsReader extends FlatVectorsReader { - private final FlatVectorsReader vectorsReader; + private final FieldInfos fieldInfos; + private final GenericFlatVectorReaders genericReaders; - @FunctionalInterface - interface GetFormatReader { - FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException; - } + ES93GenericFlatVectorsReader( + GenericFormatMetaInformation metaInfo, + SegmentReadState state, + GenericFlatVectorReaders.LoadFlatVectorsReader loadReader + ) throws IOException { + super(null); // this is not actually used by anything + + this.fieldInfos = state.fieldInfos; + this.genericReaders = new GenericFlatVectorReaders(); - ES93GenericFlatVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException { - super(null); // Hacks ahoy! // read in the meta information - final String metaFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - VECTOR_FORMAT_INFO_EXTENSION - ); + final String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, metaInfo.extension()); int versionMeta = -1; - FlatVectorsReader reader = null; try (var metaIn = state.directory.openChecksumInput(metaFileName)) { Throwable priorE = null; try { versionMeta = CodecUtil.checkIndexHeader( metaIn, - META_CODEC_NAME, - VERSION_START, - VERSION_CURRENT, + metaInfo.codecName(), + metaInfo.versionStart(), + metaInfo.versionCurrent(), state.segmentInfo.getId(), state.segmentSuffix ); - String innerFormatName = metaIn.readString(); - byte useDirectIO = metaIn.readByte(); - reader = getFormatReader.getReader(innerFormatName, useDirectIO == 1); - if (reader == null) { - throw new IllegalStateException( - "Cannot find knn vector format [" + innerFormatName + "]" + (useDirectIO == 1 ? " with directIO" : "") - ); - } + + readFields(metaIn, state.fieldInfos, genericReaders, loadReader); } catch (Throwable exception) { priorE = exception; } finally { CodecUtil.checkFooter(metaIn, priorE); } - vectorsReader = reader; } catch (Throwable t) { IOUtils.closeWhileHandlingException(this); throw t; } } + private ES93GenericFlatVectorsReader(FieldInfos fieldInfos, GenericFlatVectorReaders genericReaders) { + super(null); + this.fieldInfos = fieldInfos; + this.genericReaders = genericReaders; + } + + private static void readFields( + IndexInput meta, + FieldInfos fieldInfos, + GenericFlatVectorReaders fieldHelper, + GenericFlatVectorReaders.LoadFlatVectorsReader loadReader + ) throws IOException { + record FieldEntry(String rawVectorFormatName, boolean useDirectIOReads) implements GenericFlatVectorReaders.Field {} + + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + + FieldEntry entry = new FieldEntry(meta.readString(), meta.readByte() == 1); + fieldHelper.loadField(fieldNumber, entry, loadReader); + } + } + @Override public FlatVectorsScorer getFlatVectorScorer() { - return vectorsReader.getFlatVectorScorer(); + // this should not actually be used at all + return new FlatVectorsScorer() { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + throw new UnsupportedOperationException("Scorer should not be used"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + throw new UnsupportedOperationException("Scorer should not be used"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + throw new UnsupportedOperationException("Scorer should not be used"); + } + }; } @Override public FlatVectorsReader getMergeInstance() throws IOException { - // we know what the reader is, so we can return it directly - return vectorsReader.getMergeInstance(); + return new ES93GenericFlatVectorsReader(fieldInfos, genericReaders.getMergeInstance()); + } + + @Override + public void finishMerge() throws IOException { + for (var reader : genericReaders.allReaders()) { + reader.finishMerge(); + } } @Override public void checkIntegrity() throws IOException { - vectorsReader.checkIntegrity(); + for (var reader : genericReaders.allReaders()) { + reader.checkIntegrity(); + } + } + + private int findField(String field) { + FieldInfo info = fieldInfos.fieldInfo(field); + if (info == null) { + throw new IllegalArgumentException("Could not find field [" + field + "]"); + } + return info.number; } @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - return vectorsReader.getFloatVectorValues(field); + return genericReaders.getReaderForField(findField(field)).getFloatVectorValues(field); } @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - return vectorsReader.getByteVectorValues(field); + return genericReaders.getReaderForField(findField(field)).getByteVectorValues(field); } @Override public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { - return vectorsReader.getRandomVectorScorer(field, target); + return genericReaders.getReaderForField(findField(field)).getRandomVectorScorer(field, target); } @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { - return vectorsReader.getRandomVectorScorer(field, target); + return genericReaders.getReaderForField(findField(field)).getRandomVectorScorer(field, target); } @Override public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { - vectorsReader.search(field, target, knnCollector, acceptDocs); + genericReaders.getReaderForField(findField(field)).search(field, target, knnCollector, acceptDocs); } @Override public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { - vectorsReader.search(field, target, knnCollector, acceptDocs); + genericReaders.getReaderForField(findField(field)).search(field, target, knnCollector, acceptDocs); } @Override public long ramBytesUsed() { - return vectorsReader.ramBytesUsed(); + return genericReaders.allReaders().stream().mapToLong(FlatVectorsReader::ramBytesUsed).sum(); } @Override public Map getOffHeapByteSize(FieldInfo fieldInfo) { - return vectorsReader.getOffHeapByteSize(fieldInfo); + return genericReaders.getReaderForField(fieldInfo.number).getOffHeapByteSize(fieldInfo); } @Override public void close() throws IOException { - IOUtils.close(vectorsReader); + IOUtils.close(genericReaders.allReaders()); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsWriter.java index b8b208f205b11..81c55d12f073a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsWriter.java @@ -22,32 +22,40 @@ import org.elasticsearch.core.IOUtils; import java.io.IOException; - -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.META_CODEC_NAME; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.VECTOR_FORMAT_INFO_EXTENSION; -import static org.elasticsearch.index.codec.vectors.es93.ES93GenericHnswVectorsFormat.VERSION_CURRENT; +import java.util.ArrayList; +import java.util.List; class ES93GenericFlatVectorsWriter extends FlatVectorsWriter { - private final IndexOutput metaOut; + private final String rawVectorFormatName; + private final boolean useDirectIOReads; private final FlatVectorsWriter rawVectorWriter; + private final IndexOutput metaOut; + private final List fieldNumbers = new ArrayList<>(); @SuppressWarnings("this-escape") - ES93GenericFlatVectorsWriter(String knnFormatName, boolean useDirectIOReads, SegmentWriteState state, FlatVectorsWriter rawWriter) - throws IOException { + ES93GenericFlatVectorsWriter( + GenericFormatMetaInformation metaInfo, + String rawVectorsFormatName, + boolean useDirectIOReads, + SegmentWriteState state, + FlatVectorsWriter rawWriter + ) throws IOException { super(rawWriter.getFlatVectorScorer()); + this.rawVectorFormatName = rawVectorsFormatName; + this.useDirectIOReads = useDirectIOReads; this.rawVectorWriter = rawWriter; - final String metaFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - VECTOR_FORMAT_INFO_EXTENSION - ); + + final String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, metaInfo.extension()); try { this.metaOut = state.directory.createOutput(metaFileName, state.context); - CodecUtil.writeIndexHeader(metaOut, META_CODEC_NAME, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); - // write the format name used for this segment - metaOut.writeString(knnFormatName); - metaOut.writeByte(useDirectIOReads ? (byte) 1 : 0); + CodecUtil.writeIndexHeader( + metaOut, + metaInfo.codecName(), + metaInfo.versionCurrent(), + state.segmentInfo.getId(), + state.segmentSuffix + ); } catch (Throwable t) { IOUtils.closeWhileHandlingException(this); throw t; @@ -56,22 +64,37 @@ class ES93GenericFlatVectorsWriter extends FlatVectorsWriter { @Override public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - return rawVectorWriter.addField(fieldInfo); + var writer = rawVectorWriter.addField(fieldInfo); + fieldNumbers.add(fieldInfo.number); + return writer; } @Override public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { rawVectorWriter.mergeOneField(fieldInfo, mergeState); + writeMeta(fieldInfo.number); } @Override public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - return rawVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState); + var supplier = rawVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState); + writeMeta(fieldInfo.number); + return supplier; } @Override public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorWriter.flush(maxDoc, sortMap); + + for (Integer field : fieldNumbers) { + writeMeta(field); + } + } + + private void writeMeta(int field) throws IOException { + metaOut.writeInt(field); + metaOut.writeString(rawVectorFormatName); + metaOut.writeByte(useDirectIOReads ? (byte) 1 : 0); } @Override @@ -79,6 +102,7 @@ public void finish() throws IOException { rawVectorWriter.finish(); if (metaOut != null) { + metaOut.writeInt(-1); // no more fields CodecUtil.writeFooter(metaOut); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericHnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericHnswVectorsFormat.java deleted file mode 100644 index aa4520772ca73..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericHnswVectorsFormat.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2024 Elasticsearch B.V. - */ - -package org.elasticsearch.index.codec.vectors.es93; - -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SegmentWriteState; -import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; -import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; - -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ExecutorService; - -public abstract class ES93GenericHnswVectorsFormat extends AbstractHnswVectorsFormat { - - static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; - static final String META_CODEC_NAME = "ES93GenericVectorsFormatMeta"; - - public static final int VERSION_START = 0; - public static final int VERSION_GROUPVARINT = 1; - public static final int VERSION_CURRENT = VERSION_GROUPVARINT; - - public ES93GenericHnswVectorsFormat(String name) { - super(name); - } - - public ES93GenericHnswVectorsFormat(String name, int maxConn, int beamWidth) { - super(name, maxConn, beamWidth); - } - - public ES93GenericHnswVectorsFormat(String name, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - super(name, maxConn, beamWidth, numMergeWorkers, mergeExec); - } - - @Override - protected final FlatVectorsFormat flatVectorsFormat() { - return writeFlatVectorsFormat(); - } - - protected abstract FlatVectorsFormat writeFlatVectorsFormat(); - - protected abstract boolean useDirectIOReads(); - - protected abstract Map supportedReadFlatVectorsFormats(); - - @Override - public final KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - var flatFormat = writeFlatVectorsFormat(); - boolean directIO = useDirectIOReads(); - return new Lucene99HnswVectorsWriter( - state, - maxConn, - beamWidth, - new ES93GenericFlatVectorsWriter(flatFormat.getName(), directIO, state, flatFormat.fieldsWriter(state)), - numMergeWorkers, - mergeExec - ); - } - - @Override - public final KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - var readFormats = supportedReadFlatVectorsFormats(); - return new Lucene99HnswVectorsReader(state, new ES93GenericFlatVectorsReader(state, (f, dio) -> { - var format = readFormats.get(f); - if (format == null) return null; - - if (format instanceof DirectIOCapableFlatVectorsFormat diof) { - return diof.fieldsReader(state, dio); - } else { - assert dio == false : format + " is not DirectIO capable"; - return format.fieldsReader(state); - } - })); - } - - @Override - public String toString() { - return getName() - + "(name=" - + getName() - + ", maxConn=" - + maxConn - + ", beamWidth=" - + beamWidth - + ", writeFlatVectorFormat=" - + writeFlatVectorsFormat() - + ", readFlatVectorsFormats=" - + supportedReadFlatVectorsFormats().values() - + ")"; - } -} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 66c5406749264..b1ade1524e250 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -19,36 +19,29 @@ */ package org.elasticsearch.index.codec.vectors.es93; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; -import java.util.Map; +import java.io.IOException; import java.util.concurrent.ExecutorService; -public class ES93HnswBinaryQuantizedVectorsFormat extends ES93GenericHnswVectorsFormat { +public class ES93HnswBinaryQuantizedVectorsFormat extends AbstractHnswVectorsFormat { public static final String NAME = "ES93HnswBinaryQuantizedVectorsFormat"; - private static final FlatVectorsFormat flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(); - - private static final Map supportedFormats = Map.of(flatVectorsFormat.getName(), flatVectorsFormat); - - private final boolean useDirectIO; + /** The format for storing, reading, merging vectors on disk */ + private final FlatVectorsFormat flatVectorsFormat; /** Constructs a format using default graph construction parameters */ public ES93HnswBinaryQuantizedVectorsFormat() { super(NAME); - useDirectIO = false; - } - - /** - * Constructs a format using the given graph construction parameters. - * - * @param maxConn the maximum number of connections to a node in the HNSW graph - * @param beamWidth the size of the queue maintained during graph construction. - */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { - super(NAME, maxConn, beamWidth); - useDirectIO = false; + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(); } /** @@ -56,11 +49,11 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { * * @param maxConn the maximum number of connections to a node in the HNSW graph * @param beamWidth the size of the queue maintained during graph construction. - * @param useDirectIO whether direct IO should be used for reads for data written using this format + * @param useDirectIO whether to use direct IO when reading raw vectors */ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO) { super(NAME, maxConn, beamWidth); - this.useDirectIO = useDirectIO; + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); } /** @@ -68,28 +61,35 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean * * @param maxConn the maximum number of connections to a node in the HNSW graph * @param beamWidth the size of the queue maintained during graph construction. + * @param useDirectIO whether to use direct IO when reading raw vectors * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are * generated by this format to do the merge */ - public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + public ES93HnswBinaryQuantizedVectorsFormat( + int maxConn, + int beamWidth, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec + ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - useDirectIO = false; + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useDirectIO); } @Override - protected FlatVectorsFormat writeFlatVectorsFormat() { + protected FlatVectorsFormat flatVectorsFormat() { return flatVectorsFormat; } @Override - protected boolean useDirectIOReads() { - return useDirectIO; + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); } @Override - protected Map supportedReadFlatVectorsFormats() { - return supportedFormats; + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/GenericFormatMetaInformation.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/GenericFormatMetaInformation.java new file mode 100644 index 0000000000000..b43dd1d481b33 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/GenericFormatMetaInformation.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +record GenericFormatMetaInformation(String extension, String codecName, int versionStart, int versionCurrent) {} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 1853364738822..e256453ac6d31 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -50,9 +50,9 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature SEARCH_LOAD_PER_SHARD = new NodeFeature("mapper.search_load_per_shard"); static final NodeFeature PATTERN_TEXT = new NodeFeature("mapper.patterned_text"); static final NodeFeature IGNORED_SOURCE_FIELDS_PER_ENTRY = new NodeFeature("mapper.ignored_source_fields_per_entry"); - public static final NodeFeature MULTI_FIELD_UNICODE_OPTIMISATION_FIX = new NodeFeature("mapper.multi_field.unicode_optimisation_fix"); + static final NodeFeature MULTI_FIELD_UNICODE_OPTIMISATION_FIX = new NodeFeature("mapper.multi_field.unicode_optimisation_fix"); static final NodeFeature PATTERN_TEXT_RENAME = new NodeFeature("mapper.pattern_text_rename"); - public static final NodeFeature DISKBBQ_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.diskbbq_on_disk_rescoring"); + static final NodeFeature DISKBBQ_ON_DISK_RESCORING = new NodeFeature("mapper.vectors.diskbbq_on_disk_rescoring"); static final NodeFeature PROVIDE_INDEX_SORT_SETTING_DEFAULTS = new NodeFeature("mapper.provide_index_sort_setting_defaults"); @Override diff --git a/server/src/main/java/org/elasticsearch/index/store/LuceneFilesExtensions.java b/server/src/main/java/org/elasticsearch/index/store/LuceneFilesExtensions.java index a129aa1a5b99b..16ea550095b68 100644 --- a/server/src/main/java/org/elasticsearch/index/store/LuceneFilesExtensions.java +++ b/server/src/main/java/org/elasticsearch/index/store/LuceneFilesExtensions.java @@ -85,7 +85,7 @@ public enum LuceneFilesExtensions { VEQ("veq", "Scalar Quantized Vector Data", false, true), VEMB("vemb", "Binarized Vector Metadata", true, false), VEB("veb", "Binarized Vector Data", false, true), - VFI("vfi", "Vector format information", true, false), + VFI("vfi", "Vector Format Information", true, false), // ivf vectors format MIVF("mivf", "IVF Metadata", true, false), CENIVF("cenivf", "IVF Centroid Data", false, true), diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 8e875ac131f34..6c21437d71d28 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -9,4 +9,5 @@ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat +org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 2fc7e59e087ad..f3cd4f92a6a87 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -68,13 +68,15 @@ public void setUp() throws Exception { if (rarely()) { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), - random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER) + random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), + random().nextBoolean() ); } else { // run with low numbers to force many clusters with parents format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), - random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8) + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), + random().nextBoolean() ); } super.setUp(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..96538fd8dfb74 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,256 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es93; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SoftDeletesRetentionMergePolicy; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.CheckJoinIndex; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.startsWith; + +public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + private KnnVectorsFormat format; + + @Override + public void setUp() throws Exception { + format = new ES93BinaryQuantizedVectorsFormat(random().nextBoolean()); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return RandomPicks.randomFrom( + random(), + List.of( + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT + ) + ); + } + + static String encodeInts(int[] i) { + return Arrays.toString(i); + } + + static BitSetProducer parentFilter(IndexReader r) throws IOException { + // Create a filter that defines "parent" documents in the index + BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent"))); + CheckJoinIndex.check(r, parentsFilter); + return parentsFilter; + } + + Document makeParent(int[] children) { + Document parent = new Document(); + parent.add(newStringField("docType", "_parent", Field.Store.NO)); + parent.add(newStringField("id", encodeInts(children), Field.Store.YES)); + return parent; + } + + public void testEmptyDiversifiedChildSearch() throws Exception { + String fieldName = "field"; + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + try (Directory d = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig().setCodec(getCodec()); + iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy())); + try (IndexWriter w = new IndexWriter(d, iwc)) { + List toAdd = new ArrayList<>(); + for (int j = 1; j <= 5; j++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); + doc.add(newStringField("id", Integer.toString(j), Field.Store.YES)); + toAdd.add(doc); + } + toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 })); + w.addDocuments(toAdd); + w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 }))); + w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 })))); + w.flush(); + w.commit(); + w.forceMerge(1); + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + BitSetProducer parentFilter = parentFilter(searcher.getIndexReader()); + Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parentFilter); + assertTrue(searcher.search(query, 1).scoreDocs.length == 0); + } + } + + } + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES93BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," + + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," + + " flatVectorScorer=%s())"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testSimpleOffHeapSize() throws IOException { + try (Directory dir = newDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeMMapDir() throws IOException { + try (Directory dir = newMMapDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(expectVecOffHeap ? 2 : 1, offHeap.size()); + assertTrue(offHeap.get("veb") > 0L); + if (expectVecOffHeap) { + assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); + } + } + } + } + } + + static Directory newMMapDirectory() throws IOException { + Directory dir = new MMapDirectory(createTempDir("ES93BinaryQuantizedVectorsFormatTests")); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 19cb63bad4dbb..809436f139573 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -52,6 +52,8 @@ import java.util.Locale; import static java.lang.String.format; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.either; @@ -64,23 +66,30 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFor LogConfigurator.configureESLogging(); // native access requires logging to be initialized } - static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES93HnswBinaryQuantizedVectorsFormat()); + private KnnVectorsFormat format; + + @Override + public void setUp() throws Exception { + format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean()); + super.setUp(); + } @Override protected Codec getCodec() { - return codec; + return TestUtil.alwaysKnnVectorsFormat(format); } public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, 1, null); } }; String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," - + " writeFlatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s"; + + " flatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," + + " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat," + + " flatVectorScorer=%s())"; var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); @@ -128,15 +137,15 @@ public void testSingleVectorCase() throws Exception { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false)); expectThrows( IllegalArgumentException.class, - () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, 1, new SameThreadExecutorService()) ); }