diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 2987b3849e663..dacdf678ddd78 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -465,6 +465,7 @@ org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat, org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java index 3c143d94fd6b5..86894377bacde 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java @@ -23,7 +23,6 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -250,7 +249,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI final IndexInput finalVectorDataInput = vectorDataInput; final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java new file mode 100644 index 0000000000000..ad151147d87ea --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat { + + static final String NAME = "ES93HnswVectorsFormat"; + + private final FlatVectorsFormat flatVectorsFormat; + + public ES93HnswVectorsFormat() { + super(NAME); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(); + } + + public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) { + super(NAME, maxConn, beamWidth); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + } + + public ES93HnswVectorsFormat( + int maxConn, + int beamWidth, + boolean bfloat16, + boolean useDirectIO, + int numMergeWorkers, + ExecutorService mergeExec + ) { + super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO); + } + + @Override + protected FlatVectorsFormat flatVectorsFormat() { + return flatVectorsFormat; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } +} diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 6c21437d71d28..5370d7244df9b 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -10,4 +10,5 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsForma org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..f6b3a2aedca57 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.index.VectorEncoding; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.closeTo; + +public class ES93HnswBFloat16VectorsFormatTests extends ES93HnswVectorsFormatTests { + + @Override + protected boolean useBFloat16() { + return true; + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testEmptyByteVectorData() throws Exception { + // no bytes + } + + @Override + public void testMergingWithDifferentByteKnnFields() throws Exception { + // no bytes + } + + @Override + public void testByteVectorScorerIteration() throws Exception { + // no bytes + } + + @Override + public void testSortedIndexBytes() throws Exception { + // no bytes + } + + @Override + public void testMismatchedFields() throws Exception { + // no bytes + } + + @Override + public void testRandomBytes() throws Exception { + // no bytes + } + + @Override + public void testRandom() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandom); + assertFloatsWithinBounds(err); + } + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); + assertFloatsWithinBounds(err); + } + + @Override + public void testSparseVectors() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); + assertFloatsWithinBounds(err); + } + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception { + AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); + assertFloatsWithinBounds(err); + } + + private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); + + private static void assertFloatsWithinBounds(AssertionError error) { + Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); + if (m.matches() == false) { + throw error; // nothing to do with us, just rethrow + } + + // numbers just need to be in the same vicinity + double expected = Double.parseDouble(m.group(1)); + double actual = Double.parseDouble(m.group(2)); + double allowedError = expected * 0.01; // within 1% + assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java new file mode 100644 index 0000000000000..bd7e4f7f653bf --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es93; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BFloat16; + +import java.io.IOException; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES93HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + private KnnVectorsFormat format; + + protected boolean useBFloat16() { + return false; + } + + @Override + public void setUp() throws Exception { + format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES93HnswVectorsFormat(10, 20, false, false); + } + }; + String expectedPattern = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, false, false)); + expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false)); + expectThrows( + IllegalArgumentException.class, + () -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) + ); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; + assertEquals(vector.length * bytes, (long) offHeap.get("vec")); + assertEquals(1L, (long) offHeap.get("vex")); + assertEquals(2, offHeap.size()); + } + } + } + } +}