diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 470ca69fb0d68..ae8584d716c80 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -458,11 +458,13 @@ org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat, org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat, + org.elasticsearch.index.codec.vectors.ES815BitFlatVectorsFormat, org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, - ES920DiskBBQVectorsFormat; + ES920DiskBBQVectorsFormat, + org.elasticsearch.index.codec.vectors.ES920HnswComposableKnnVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java index ae372ea8194bc..cfc67dbee1366 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java @@ -18,8 +18,8 @@ import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; /** @@ -47,7 +47,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { }; 
private static final KnnVectorsFormat defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat(); - private final KnnVectorsFormat knnVectorsFormat = new PerFieldKnnVectorsFormat() { + private final KnnVectorsFormat knnVectorsFormat = new ComposablePerFieldKnnVectorsFormat() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return Elasticsearch814Codec.this.getKnnVectorsFormatForField(field); diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch816Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch816Codec.java index d58c4e2cdc34a..a5eba3ceac0cd 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch816Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch816Codec.java @@ -18,8 +18,8 @@ import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; /** @@ -49,7 +49,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { } }; - private final KnnVectorsFormat knnVectorsFormat = new PerFieldKnnVectorsFormat() { + private final KnnVectorsFormat knnVectorsFormat = new ComposablePerFieldKnnVectorsFormat() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return Elasticsearch816Codec.this.getKnnVectorsFormatForField(field); diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java index 04428d5b37fba..f1ab339d59879 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Codec.java @@ -18,8 +18,8 @@ import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; /** @@ -47,7 +47,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { }; private final KnnVectorsFormat defaultKnnVectorsFormat; - private final KnnVectorsFormat knnVectorsFormat = new PerFieldKnnVectorsFormat() { + private final KnnVectorsFormat knnVectorsFormat = new ComposablePerFieldKnnVectorsFormat() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return Elasticsearch900Codec.this.getKnnVectorsFormatForField(field); diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java index 3edd55d8f8de7..9145b93bfe821 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch900Lucene101Codec.java @@ -17,9 +17,9 @@ import org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; import org.elasticsearch.index.codec.perfield.XPerFieldDocValuesFormat; +import 
org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; /** @@ -49,7 +49,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { }; private final KnnVectorsFormat defaultKnnVectorsFormat; - private final KnnVectorsFormat knnVectorsFormat = new PerFieldKnnVectorsFormat() { + private final KnnVectorsFormat knnVectorsFormat = new ComposablePerFieldKnnVectorsFormat() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return Elasticsearch900Lucene101Codec.this.getKnnVectorsFormatForField(field); diff --git a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java index 11362c6e68cd7..58975517bb052 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java +++ b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java @@ -124,7 +124,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { if (mapperService != null) { Mapper mapper = mapperService.mappingLookup().getMapper(field); if (mapper instanceof DenseVectorFieldMapper vectorMapper) { - return vectorMapper.getKnnVectorsFormatForField(knnVectorsFormat); + return vectorMapper.getKnnVectorsFormatForField(); } } return knnVectorsFormat; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ComposablePerFieldKnnVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ComposablePerFieldKnnVectorsFormat.java new file mode 100644 index 0000000000000..81000c29a3f75 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ComposablePerFieldKnnVectorsFormat.java @@ -0,0 +1,465 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.internal.hppc.ObjectCursor; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +// copied and modified from Lucene. 
+/**
+ * A per-field {@link KnnVectorsFormat}: every vector field is written with the format returned by
+ * {@link #getKnnVectorsFormatForField(String)}, and the chosen format name and segment suffix are recorded as
+ * {@link FieldInfo} attributes so that the matching reader can be opened again.
+ *
+ * <p>For a {@link ComposableKnnVectorsFormat} the recorded name is the composed name (base format, optional
+ * {@link DirectoryModifier}, inner formats, joined with {@code +}); on read that name is split again and the parts
+ * are applied to the looked-up base format.
+ */
+public abstract class ComposablePerFieldKnnVectorsFormat extends KnnVectorsFormat {
+
+    private static final Logger logger = LogManager.getLogger(ComposablePerFieldKnnVectorsFormat.class);
+
+    /** Name of this {@link KnnVectorsFormat}. */
+    public static final String PER_FIELD_NAME = "ComposablePerFieldVectorsES920";
+
+    /** {@link FieldInfo} attribute name used to store the format name for each field. */
+    public static final String PER_FIELD_FORMAT_KEY = ComposablePerFieldKnnVectorsFormat.class.getSimpleName() + ".format";
+
+    /**
+     * {@link FieldInfo} attribute name used to store the composed format name. It is only written for fields whose
+     * format is a {@link ComposableKnnVectorsFormat}; its presence tells the reader to parse the format name.
+     */
+    public static final String PER_FIELD_COMPOSED_FORMAT_KEY = ComposablePerFieldKnnVectorsFormat.class.getSimpleName()
+        + ".composed_format";
+
+    /** {@link FieldInfo} attribute name used to store the segment suffix name for each field. */
+    public static final String PER_FIELD_SUFFIX_KEY = ComposablePerFieldKnnVectorsFormat.class.getSimpleName() + ".suffix";
+
+    /** A delegate writer together with the numeric suffix that was assigned to it. */
+    private record WriterAndSuffix(KnnVectorsWriter writer, int suffix) implements Closeable {
+
+        @Override
+        public void close() throws IOException {
+            writer.close();
+        }
+    }
+
+    /** Builds the per-format part of a segment suffix: {@code formatName_suffix}. */
+    static String getSuffix(String formatName, String suffix) {
+        return formatName + "_" + suffix;
+    }
+
+    /** Prepends the outer segment suffix (when there is one) to {@code segmentSuffix}, separated by {@code _}. */
+    static String getFullSegmentSuffix(String outerSegmentSuffix, String segmentSuffix) {
+        return outerSegmentSuffix.isEmpty() ? segmentSuffix : outerSegmentSuffix + "_" + segmentSuffix;
+    }
+
+    public ComposablePerFieldKnnVectorsFormat() {
+        super(PER_FIELD_NAME);
+    }
+
+    @Override
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return new FieldsWriter(state);
+    }
+
+    @Override
+    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+        return new FieldsReader(state);
+    }
+
+    @Override
+    public int getMaxDimensions(String fieldName) {
+        return DenseVectorFieldMapper.MAX_DIMS_COUNT;
+    }
+
+    /**
+     * Returns the numeric vector format that should be used for writing new segments of <code>field</code>.
+     *
+     * <p>The field to format mapping is written to the index, so this method is only invoked when
+     * writing, not when reading.
+     */
+    public abstract KnnVectorsFormat getKnnVectorsFormatForField(String field);
+
+    /** Writer that lazily creates one delegate writer per distinct format instance and routes fields to it. */
+    private class FieldsWriter extends KnnVectorsWriter {
+        /** Delegate writers, keyed by the format instance that created them. */
+        private final Map<KnnVectorsFormat, WriterAndSuffix> formats = new HashMap<>();
+        /** Highest suffix handed out so far for each format name. */
+        private final Map<String, Integer> suffixes = new HashMap<>();
+        private final SegmentWriteState segmentWriteState;
+
+        FieldsWriter(SegmentWriteState segmentWriteState) {
+            this.segmentWriteState = segmentWriteState;
+        }
+
+        @Override
+        public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
+            return getInstance(fieldInfo).addField(fieldInfo);
+        }
+
+        @Override
+        public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+            for (WriterAndSuffix was : formats.values()) {
+                was.writer().flush(maxDoc, sortMap);
+            }
+        }
+
+        @Override
+        public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+            getInstance(fieldInfo).mergeOneField(fieldInfo, mergeState);
+        }
+
+        @Override
+        public void finish() throws IOException {
+            for (WriterAndSuffix was : formats.values()) {
+                was.writer().finish();
+            }
+        }
+
+        @Override
+        public void close() throws IOException {
+            IOUtils.close(formats.values());
+        }
+
+        /**
+         * Resolves the delegate writer for {@code field}, creating it on first use, and records the format name and
+         * suffix as attributes on the field.
+         *
+         * @throws IllegalStateException if {@link #getKnnVectorsFormatForField(String)} returns {@code null}
+         */
+        private KnnVectorsWriter getInstance(FieldInfo field) throws IOException {
+            KnnVectorsFormat format = getKnnVectorsFormatForField(field.name);
+            if (format == null) {
+                throw new IllegalStateException("invalid null KnnVectorsFormat for field=\"" + field.name + "\"");
+            }
+            // Look through the max-dimension wrapper so that a wrapped composable format is still recognised.
+            KnnVectorsFormat possibleComposable = format;
+            if (possibleComposable instanceof MaxDimOverridingKnnVectorsFormat wrapped) {
+                possibleComposable = wrapped.getDelegate();
+            }
+            // We expect the provided format to be the fully composed writer, with all values set.
+            // Logged at debug level: this runs for every vector field on every addField and mergeOneField call.
+            final String formatName;
+            if (possibleComposable instanceof ComposableKnnVectorsFormat composableFormat) {
+                formatName = composableFormat.getComposedFormatName();
+                logger.debug(
+                    "composable format for field [{}] [{}] [{}]",
+                    field.name,
+                    formatName,
+                    composableFormat.getClass().getCanonicalName()
+                );
+                field.putAttribute(PER_FIELD_COMPOSED_FORMAT_KEY, formatName);
+            } else {
+                formatName = format.getName();
+                logger.debug("regular format for field [{}] [{}] [{}]", field.name, formatName, format.getClass().getCanonicalName());
+            }
+            // NOTE This is the composed name, maybe we instead put the Lucene acceptable name here and the composed name in another place?
+            field.putAttribute(PER_FIELD_FORMAT_KEY, formatName);
+
+            WriterAndSuffix writerAndSuffix = formats.get(format);
+            if (writerAndSuffix == null) {
+                // First time we are seeing this format instance; give it the next free suffix for its name.
+                Integer previousSuffix = suffixes.get(formatName);
+                int suffix = previousSuffix == null ? 0 : previousSuffix + 1;
+                suffixes.put(formatName, suffix);
+
+                String segmentSuffix = getFullSegmentSuffix(
+                    segmentWriteState.segmentSuffix,
+                    getSuffix(formatName, Integer.toString(suffix))
+                );
+                writerAndSuffix = new WriterAndSuffix(format.fieldsWriter(new SegmentWriteState(segmentWriteState, segmentSuffix)), suffix);
+                formats.put(format, writerAndSuffix);
+            } else {
+                // we've already seen this format, so its suffix is the one stored with the writer
+                assert suffixes.containsKey(formatName);
+            }
+            field.putAttribute(PER_FIELD_SUFFIX_KEY, Integer.toString(writerAndSuffix.suffix()));
+            return writerAndSuffix.writer();
+        }
+
+        @Override
+        public long ramBytesUsed() {
+            long total = 0;
+            for (WriterAndSuffix was : formats.values()) {
+                total += was.writer().ramBytesUsed();
+            }
+            return total;
+        }
+    }
+
+    /** VectorReader that can wrap multiple delegate readers, selected by field. */
+    public static class FieldsReader extends KnnVectorsReader {
+
+        /** Delegate reader for each vector field, keyed by field number. */
+        private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
+        private final FieldInfos fieldInfos;
+
+        /**
+         * Create a FieldsReader over a segment, opening VectorReaders for each KnnVectorsFormat
+         * specified by the indexed numeric vector fields.
+         *
+         * @param readState defines the fields
+         * @throws IOException if one of the delegate readers throws
+         * @throws IllegalStateException if a field has a format name but no suffix attribute, or if its composed
+         *     format name cannot be resolved
+         */
+        public FieldsReader(final SegmentReadState readState) throws IOException {
+            this.fieldInfos = readState.fieldInfos;
+            // Init each unique format; readers are shared between fields that use the same segment suffix.
+            boolean success = false;
+            Map<String, KnnVectorsReader> formats = new HashMap<>();
+            try {
+                // Read field name -> format name
+                for (FieldInfo fi : readState.fieldInfos) {
+                    if (fi.hasVectorValues() == false) {
+                        continue;
+                    }
+                    final String fieldName = fi.name;
+                    final String formatName = fi.getAttribute(PER_FIELD_FORMAT_KEY);
+                    if (formatName == null) {
+                        // null formatName means the field is in fieldInfos, but has no vectors!
+                        continue;
+                    }
+                    final String suffix = fi.getAttribute(PER_FIELD_SUFFIX_KEY);
+                    if (suffix == null) {
+                        throw new IllegalStateException("missing attribute: " + PER_FIELD_SUFFIX_KEY + " for field: " + fieldName);
+                    }
+                    final KnnVectorsFormat format;
+                    // check if its a "composable" name
+                    final String composedFormatName = fi.getAttribute(PER_FIELD_COMPOSED_FORMAT_KEY);
+                    if (composedFormatName != null) {
+                        logger.debug("got composed format format for field: {} is: {}", fi.name, composedFormatName);
+                        format = resolveComposedFormat(fieldName, formatName, composedFormatName);
+                    } else {
+                        // standard format
+                        format = KnnVectorsFormat.forName(formatName);
+                    }
+                    // TODO does this need to be the composed name?
+                    String segmentSuffix = getFullSegmentSuffix(readState.segmentSuffix, getSuffix(formatName, suffix));
+                    KnnVectorsReader reader = formats.get(segmentSuffix);
+                    if (reader == null) {
+                        reader = format.fieldsReader(new SegmentReadState(readState, segmentSuffix));
+                        formats.put(segmentSuffix, reader);
+                    }
+                    fields.put(fi.number, reader);
+                }
+                success = true;
+            } finally {
+                if (success == false) {
+                    IOUtils.closeWhileHandlingException(formats.values());
+                }
+            }
+        }
+
+        /**
+         * Rebuilds a composed format from its stored name. The name is split on {@code +}: the first part names the
+         * base format, every following part is either a {@link DirectoryModifier} or the name of the next inner format.
+         *
+         * @param fieldName the field being opened, used in error messages only
+         * @param formatName the stored format name that is parsed
+         * @param composedFormatName the stored composed format name, used in error messages only
+         * @return the base format, after the remaining parts have been applied to it
+         * @throws IllegalStateException if the base format is not composable, or a non-composable format is followed
+         *     by further parts
+         */
+        private static KnnVectorsFormat resolveComposedFormat(String fieldName, String formatName, String composedFormatName) {
+            String[] parts = formatName.split("\\+");
+            String baseFormatName = parts[0];
+            KnnVectorsFormat baseFormat = KnnVectorsFormat.forName(baseFormatName);
+            if (baseFormat instanceof ComposableKnnVectorsFormat == false) {
+                throw new IllegalStateException(
+                    "expected ComposableKnnVectorsFormat for field: "
+                        + fieldName
+                        + " got: "
+                        + baseFormat.getClass().getSimpleName()
+                        + " for format name: "
+                        + baseFormatName
+                        + " full format name: "
+                        + formatName
+                        + " composed format name: "
+                        + composedFormatName
+                );
+            }
+            // NOTE(review): the setters below are applied to whatever instance forName returned. If that instance is
+            // shared between lookups, fields composed differently on the same base format would conflict - confirm.
+            ComposableKnnVectorsFormat outerFormat = (ComposableKnnVectorsFormat) baseFormat;
+            for (int i = 1; i < parts.length; i++) {
+                // a part is either a directory modifier...
+                DirectoryModifier dm = DirectoryModifier.fromString(parts[i]);
+                if (dm != null) {
+                    outerFormat.setDirectoryModifier(dm);
+                    continue;
+                }
+                // ...or another knn format, which could be composable or not
+                KnnVectorsFormat innerFormat = KnnVectorsFormat.forName(parts[i]);
+                outerFormat.setInnerVectorsFormat(innerFormat);
+                if (innerFormat instanceof ComposableKnnVectorsFormat innerComposable) {
+                    outerFormat = innerComposable;
+                } else if (i != parts.length - 1) {
+                    // a non-composable format ends the chain, so there must be no more parts
+                    throw new IllegalStateException(
+                        "found non-composable format: "
+                            + innerFormat.getName()
+                            + " in the middle of the composed format name: "
+                            + formatName
+                            + " for field: "
+                            + fieldName
+                    );
+                }
+            }
+            return baseFormat;
+        }
+
+        /** Copy constructor used by {@link #getMergeInstance()}: wraps the merge instance of every delegate reader. */
+        private FieldsReader(final FieldsReader fieldsReader) {
+            this.fieldInfos = fieldsReader.fieldInfos;
+            for (FieldInfo fi : this.fieldInfos) {
+                if (fi.hasVectorValues() && fieldsReader.fields.containsKey(fi.number)) {
+                    this.fields.put(fi.number, fieldsReader.fields.get(fi.number).getMergeInstance());
+                }
+            }
+        }
+
+        @Override
+        public KnnVectorsReader getMergeInstance() {
+            return new FieldsReader(this);
+        }
+
+        @Override
+        public void finishMerge() throws IOException {
+            for (ObjectCursor<KnnVectorsReader> knnVectorReader : fields.values()) {
+                knnVectorReader.value.finishMerge();
+            }
+        }
+
+        /**
+         * Return the underlying VectorReader for the given field
+         *
+         * @param field the name of a numeric vector field
+         * @return the delegate reader, or {@code null} if the field is unknown or has no reader
+         */
+        public KnnVectorsReader getFieldReader(String field) {
+            final FieldInfo info = fieldInfos.fieldInfo(field);
+            if (info == null) {
+                return null;
+            }
+            return fields.get(info.number);
+        }
+
+        @Override
+        public void checkIntegrity() throws IOException {
+            for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
+                cursor.value.checkIntegrity();
+            }
+        }
+
+        @Override
+        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+            final FieldInfo info = fieldInfos.fieldInfo(field);
+            final KnnVectorsReader reader;
+            if (info == null || (reader = fields.get(info.number)) == null) {
+                return null;
+            }
+            return reader.getFloatVectorValues(field);
+        }
+
+        @Override
+        public ByteVectorValues getByteVectorValues(String field) throws IOException {
+            final FieldInfo info = fieldInfos.fieldInfo(field);
+            final KnnVectorsReader reader;
+            if (info == null || (reader = fields.get(info.number)) == null) {
+                return null;
+            }
+            return reader.getByteVectorValues(field);
+        }
+
+        @Override
+        public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+            final FieldInfo info = fieldInfos.fieldInfo(field);
+            final KnnVectorsReader reader;
+            if (info == null || (reader = fields.get(info.number)) == null) {
+                return;
+            }
+            reader.search(field, target, knnCollector, acceptDocs);
+        }
+
+        @Override
+        public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+            final FieldInfo info = fieldInfos.fieldInfo(field);
+            final KnnVectorsReader reader;
+            if (info == null || (reader = fields.get(info.number)) == null) {
+                return;
+            }
+            reader.search(field, target, knnCollector, acceptDocs);
+        }
+
+        @Override
+        public void close() throws IOException {
+            List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
+            for (ObjectCursor<KnnVectorsReader> cursor : fields.values()) {
+                readers.add(cursor.value);
+            }
+            IOUtils.close(readers);
+        }
+    }
+
+    /** Optional element of a composed format name, set on a format through {@code setDirectoryModifier}. */
+    public enum DirectoryModifier {
+        DirectIO,
+        None;
+
+        /**
+         * Case-insensitive lookup by constant name.
+         *
+         * @return the matching modifier, or {@code null} if {@code name} matches no constant
+         */
+        static DirectoryModifier fromString(String name) {
+            for (DirectoryModifier dm : values()) {
+                if (dm.name().equalsIgnoreCase(name)) {
+                    return dm;
+                }
+            }
+            return null;
+        }
+    }
+
+    /**
+     * A {@link KnnVectorsFormat} that can be combined with a {@link DirectoryModifier} and an inner format, and that
+     * can describe the resulting composition as a single name.
+     */
+    public abstract static class ComposableKnnVectorsFormat extends KnnVectorsFormat {
+        /**
+         * Sole constructor
+         *
+         * @param name the format name, passed on to the {@link KnnVectorsFormat} constructor
+         */
+        protected ComposableKnnVectorsFormat(String name) {
+            super(name);
+        }
+
+        abstract DirectoryModifier getDirectoryModifier();
+
+        /**
+         * Builds the composed name: this format's name, then the directory modifier unless it is {@code null} or
+         * {@link DirectoryModifier#None}, then the inner format (its composed name when it is itself composable),
+         * joined with {@code +}.
+         */
+        final String getComposedFormatName() {
+            StringBuilder sb = new StringBuilder(getName());
+            DirectoryModifier modifier = getDirectoryModifier();
+            if (modifier != null && modifier != DirectoryModifier.None) {
+                sb.append("+").append(modifier.name());
+            }
+            KnnVectorsFormat inner = getInnerVectorsFormat();
+            if (inner instanceof ComposableKnnVectorsFormat composable) {
+                sb.append("+").append(composable.getComposedFormatName());
+            } else if (inner != null) {
+                sb.append("+").append(inner.getName());
+            }
+            return sb.toString();
+        }
+
+        abstract KnnVectorsFormat getInnerVectorsFormat();
+
+        abstract void setInnerVectorsFormat(KnnVectorsFormat format);
+
+        abstract void setDirectoryModifier(DirectoryModifier modifier);
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index e3242ee411e7d..4ad0e86e82f4f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -29,11 +29,11 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; -class ES815BitFlatVectorsFormat extends FlatVectorsFormat { +public class ES815BitFlatVectorsFormat extends FlatVectorsFormat { private static final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE); - protected ES815BitFlatVectorsFormat() { + public ES815BitFlatVectorsFormat() { super("ES815BitFlatVectorsFormat"); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES920HnswComposableKnnVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES920HnswComposableKnnVectorsFormat.java new file mode 100644 index 0000000000000..e63866da92d9e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES920HnswComposableKnnVectorsFormat.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.util.SetOnce;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Composable HNSW format: it writes through {@link Lucene99HnswVectorsWriter} and reads through
+ * {@link Lucene99HnswVectorsReader}, handing both the writer/reader of an inner {@link FlatVectorsFormat}.
+ *
+ * <p>The inner format is either passed to the constructor or supplied later through
+ * {@code setInnerVectorsFormat}; it must be present before {@link #fieldsWriter} or {@link #fieldsReader} is called.
+ */
+public class ES920HnswComposableKnnVectorsFormat extends ComposablePerFieldKnnVectorsFormat.ComposableKnnVectorsFormat {
+
+    /** Values used by the no-argument constructor. */
+    private static final int DEFAULT_MAX_CONN = 16;
+    private static final int DEFAULT_BEAM_WIDTH = 100;
+
+    /** The inner flat format; empty until it is provided by the constructor or by {@code setInnerVectorsFormat}. */
+    private final SetOnce<FlatVectorsFormat> flatVectorsFormat = new SetOnce<>();
+    private final int maxConn;
+    private final int beamWidth;
+
+    /** Creates the format without an inner flat format, using 16 connections and a beam width of 100. */
+    public ES920HnswComposableKnnVectorsFormat() {
+        this(null, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
+    }
+
+    /**
+     * @param flatVectorsFormat the inner flat format, or {@code null} to supply it later
+     * @param maxConn passed to the {@link Lucene99HnswVectorsWriter} as its second argument
+     * @param beamWidth passed to the {@link Lucene99HnswVectorsWriter} as its third argument
+     */
+    public ES920HnswComposableKnnVectorsFormat(FlatVectorsFormat flatVectorsFormat, int maxConn, int beamWidth) {
+        super("ES920HnswComposableKnnVectorsFormat");
+        if (flatVectorsFormat != null) {
+            this.flatVectorsFormat.set(flatVectorsFormat);
+        }
+        this.maxConn = maxConn;
+        this.beamWidth = beamWidth;
+    }
+
+    /** Always {@code null}: this format carries no directory modifier. */
+    @Override
+    ComposablePerFieldKnnVectorsFormat.DirectoryModifier getDirectoryModifier() {
+        return null;
+    }
+
+    @Override
+    KnnVectorsFormat getInnerVectorsFormat() {
+        return flatVectorsFormat.get();
+    }
+
+    /**
+     * Supplies the inner flat format.
+     *
+     * @throws IllegalStateException if a different inner format has already been set
+     * @throws IllegalArgumentException if {@code format} is {@code null} or not a {@link FlatVectorsFormat}
+     */
+    @Override
+    void setInnerVectorsFormat(KnnVectorsFormat format) {
+        FlatVectorsFormat current = flatVectorsFormat.get();
+        if (current != null) {
+            // allow it to be "set" again if it's the exact same format
+            if (Objects.equals(current, format)) {
+                return;
+            }
+            throw new IllegalStateException("Inner format already set");
+        }
+        if (format instanceof FlatVectorsFormat fvf) {
+            flatVectorsFormat.set(fvf);
+            return;
+        }
+        // format may be null here; report that instead of failing on format.getClass()
+        throw new IllegalArgumentException(
+            "Inner format must be a FlatVectorsFormat, received: " + (format == null ? "null" : format.getClass())
+        );
+    }
+
+    @Override
+    void setDirectoryModifier(ComposablePerFieldKnnVectorsFormat.DirectoryModifier modifier) {
+        throw new UnsupportedOperationException("DirectoryModifier is not supported on HNSW format");
+    }
+
+    @Override
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        FlatVectorsFormat flat = requireFlatVectorsFormat();
+        // NOTE(review): the trailing "1, null" presumably mean one merge worker and no merge executor - confirm
+        // against the Lucene99HnswVectorsWriter constructor.
+        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flat.fieldsWriter(state), 1, null);
+    }
+
+    @Override
+    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+        return new Lucene99HnswVectorsReader(state, requireFlatVectorsFormat().fieldsReader(state));
+    }
+
+    /**
+     * Returns the inner flat format.
+     *
+     * @throws IllegalStateException if no inner format has been provided yet
+     */
+    private FlatVectorsFormat requireFlatVectorsFormat() {
+        FlatVectorsFormat flat = flatVectorsFormat.get();
+        if (flat == null) {
+            throw new IllegalStateException("flatVectorsFormat must be set");
+        }
+        return flat;
+    }
+
+    @Override
+    public int getMaxDimensions(String fieldName) {
+        return 4096;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/MaxDimOverridingKnnVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/MaxDimOverridingKnnVectorsFormat.java new file mode 100644 index 0000000000000..eea5d16740840 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/MaxDimOverridingKnnVectorsFormat.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+
+/**
+ * Wraps another {@link KnnVectorsFormat} under the same name and replaces only the answer of
+ * {@link #getMaxDimensions(String)}; readers, writers and {@link #toString()} come from the wrapped format.
+ */
+public final class MaxDimOverridingKnnVectorsFormat extends KnnVectorsFormat {
+    /** The value reported by {@link #getMaxDimensions(String)} for every field. */
+    private final int dimensionLimit;
+    /** The format that does the actual reading and writing. */
+    private final KnnVectorsFormat inner;
+
+    /**
+     * @param delegate the format to wrap; its name becomes the name of this format
+     * @param maxDimension the maximum dimension count to report
+     */
+    public MaxDimOverridingKnnVectorsFormat(KnnVectorsFormat delegate, int maxDimension) {
+        super(delegate.getName());
+        this.dimensionLimit = maxDimension;
+        this.inner = delegate;
+    }
+
+    /** Returns the wrapped format. */
+    public KnnVectorsFormat getDelegate() {
+        return inner;
+    }
+
+    /** Returns the configured limit; {@code fieldName} is not consulted. */
+    @Override
+    public int getMaxDimensions(String fieldName) {
+        return dimensionLimit;
+    }
+
+    @Override
+    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+        return inner.fieldsReader(state);
+    }
+
+    @Override
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return inner.fieldsWriter(state);
+    }
+
+    @Override
+    public String toString() {
+        return inner.toString();
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 22520567f2954..e5bee0014f87f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -49,6 +49,7 @@ import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.Closeable; @@ -555,6 +556,9 @@ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { vectorsReader = candidateReader.getFieldReader(fieldName); } + if (vectorsReader instanceof ComposablePerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { return reader.getCentroid(fieldName); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index 8271a29dfd995..2499861cb4c65 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -68,6 +68,7 @@ import org.elasticsearch.index.VersionType; import org.elasticsearch.index.codec.FieldInfosWithUsages; import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils; import org.elasticsearch.index.mapper.DocumentParser; import org.elasticsearch.index.mapper.LuceneDocument; @@ -403,6 +404,8 @@ private DenseVectorStats getDenseVectorStats(final LeafReader atomicReader, List var vectorsReader = reader.getVectorReader(); if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader 
fieldsReader) { vectorsReader = fieldsReader.getFieldReader(info.name); + } else if (vectorsReader instanceof ComposablePerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + vectorsReader = fieldsReader.getFieldReader(info.name); } Map offHeap = OffHeapByteSizeUtils.getOffHeapByteSize(vectorsReader, info); offHeapStats.put(info.name, offHeap); @@ -1381,13 +1384,10 @@ private void fillSegmentInfo( knnFormats = new HashMap<>(); } String key = fieldInfo.getAttribute(PerFieldKnnVectorsFormat.PER_FIELD_FORMAT_KEY); - knnFormats.compute(key, (s, a) -> { - if (a == null) { - a = new ArrayList<>(); - } - a.add(name); - return a; - }); + if (key == null) { + key = fieldInfo.getAttribute(ComposablePerFieldKnnVectorsFormat.PER_FIELD_FORMAT_KEY); + } + knnFormats.computeIfAbsent(key, s -> new ArrayList<>()).add(name); } } } diff --git a/server/src/main/java/org/elasticsearch/index/engine/MergeMemoryEstimator.java b/server/src/main/java/org/elasticsearch/index/engine/MergeMemoryEstimator.java index 7567ee3e38284..9a9f7c8a99b0b 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/MergeMemoryEstimator.java +++ b/server/src/main/java/org/elasticsearch/index/engine/MergeMemoryEstimator.java @@ -20,6 +20,7 @@ import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.VectorEncoding; import org.elasticsearch.common.lucene.Lucene; +import org.elasticsearch.index.codec.vectors.ComposablePerFieldKnnVectorsFormat; import java.util.List; import java.util.Map; @@ -97,6 +98,9 @@ private static long estimateVectorFieldMemory(FieldInfo fieldInfo, SegmentCommit if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldKnnVectorsFormat) { vectorsReader = perFieldKnnVectorsFormat.getFieldReader(fieldInfo.getName()); } + if (vectorsReader instanceof ComposablePerFieldKnnVectorsFormat.FieldsReader perFieldKnnVectorsFormat) { + vectorsReader = perFieldKnnVectorsFormat.getFieldReader(fieldInfo.getName()); + } return getVectorFieldEstimation(fieldInfo, 
segmentCommitInfo, vectorsReader); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ba7e68cef62e7..d1cee00c2c6db 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -10,8 +10,6 @@ package org.elasticsearch.index.mapper.vectors; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Field; @@ -26,8 +24,6 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; @@ -52,9 +48,11 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; -import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; -import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; +import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.ES920HnswComposableKnnVectorsFormat; +import 
org.elasticsearch.index.codec.vectors.MaxDimOverridingKnnVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; @@ -1716,7 +1714,11 @@ public Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true); + return new ES920HnswComposableKnnVectorsFormat( + new ES814ScalarQuantizedVectorsFormat(confidenceInterval, 4, true), + m, + efConstruction + ); } @Override @@ -1864,7 +1866,11 @@ public Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false); + return new ES920HnswComposableKnnVectorsFormat( + new ES814ScalarQuantizedVectorsFormat(confidenceInterval, 7, false), + m, + efConstruction + ); } @Override @@ -1952,7 +1958,7 @@ static class HnswIndexOptions extends DenseVectorIndexOptions { @Override public KnnVectorsFormat getVectorsFormat(ElementType elementType) { if (elementType == ElementType.BIT) { - return new ES815HnswBitVectorsFormat(m, efConstruction); + return new ES920HnswComposableKnnVectorsFormat(new ES815BitFlatVectorsFormat(), m, efConstruction); } return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null); } @@ -2845,35 +2851,15 @@ private static DenseVectorIndexOptions parseIndexOptions(String fieldName, Objec * @return the custom kNN vectors format that is configured for this field or * {@code null} if the default format should be used. 
*/ - public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) { + public KnnVectorsFormat getKnnVectorsFormatForField() { final KnnVectorsFormat format; if (indexOptions == null) { - format = fieldType().element.elementType() == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat; + format = new ES920HnswComposableKnnVectorsFormat(new ES815BitFlatVectorsFormat(), 16, 100); } else { format = indexOptions.getVectorsFormat(fieldType().element.elementType()); } // It's legal to reuse the same format name as this is the same on-disk format. - return new KnnVectorsFormat(format.getName()) { - @Override - public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return format.fieldsWriter(state); - } - - @Override - public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return format.fieldsReader(state); - } - - @Override - public int getMaxDimensions(String fieldName) { - return MAX_DIMS_COUNT; - } - - @Override - public String toString() { - return format.toString(); - } - }; + return new MaxDimOverridingKnnVectorsFormat(format, MAX_DIMS_COUNT); } @Override diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index ff02849e96c46..a314626e62111 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -3,8 +3,10 @@ org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat +org.elasticsearch.index.codec.vectors.ES815BitFlatVectorsFormat 
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat +org.elasticsearch.index.codec.vectors.ES920HnswComposableKnnVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES920ComposedBitFlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES920ComposedBitFlatVectorFormatTests.java new file mode 100644 index 0000000000000..10fb56af12e93 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES920ComposedBitFlatVectorFormatTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec; +import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils; +import org.junit.Before; + +import java.io.IOException; + +public class ES920ComposedBitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Elasticsearch900Lucene101Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES920HnswComposableKnnVectorsFormat(new ES815BitFlatVectorsFormat(), 16, 100); + } + }; + } + + @Before + public void init() { + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + } + + public void testSimpleOffHeapSize() throws IOException { + byte[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnByteVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof 
PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = OffHeapByteSizeUtils.getOffHeapByteSize(knnVectorsReader, fieldInfo); + assertEquals(1, offHeap.size()); + assertTrue(offHeap.get("vec") > 0L); + } + } + } + } +}