diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index c42701f1e5d6f..7e653b4b4cbe3 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -47,8 +47,18 @@ public ES93HnswBinaryQuantizedVectorsFormat() { /** * Constructs a format using the given graph construction parameters. * - * @param maxConn the maximum number of connections to a node in the HNSW graph - * @param beamWidth the size of the queue maintained during graph construction. + * @param useDirectIO whether to use direct IO when reading raw vectors + */ + public ES93HnswBinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) { + super(NAME); + flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. * @param useDirectIO whether to use direct IO when reading raw vectors */ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) { @@ -70,8 +80,8 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, - boolean useDirectIO, boolean useBFloat16, + boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec ) { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java new file mode 100644 index 0000000000000..097cd70b7344e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswVectorsFormatTestCase.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.hamcrest.Matcher; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; + +public abstract class BaseHnswVectorsFormatTestCase extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + protected abstract KnnVectorsFormat createFormat(); + + protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth); + + protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service); + + private KnnVectorsFormat format; + + @Override + public void setUp() throws Exception { + format = createFormat(); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> createFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> createFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> createFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> createFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> createFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> createFormat(20, 3201)); + expectThrows(IllegalArgumentException.class, () -> createFormat(20, 100, 1, new SameThreadExecutorService())); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + if (similarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + } + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assertThat(vectorValues.size(), equalTo(1)); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + if (similarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(randomVector); + } + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(r.getLiveDocs(), r.maxDoc()), + Integer.MAX_VALUE + ); + assertEquals(1, td.totalHits.value()); + assertThat(td.scoreDocs[0].score, greaterThanOrEqualTo(0f)); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + protected static void testSimpleOffHeapSize( + Directory dir, + IndexWriterConfig config, + float[] vector, + Matcher> matchesMap + ) throws IOException { + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertThat(offHeap, matchesMap); + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java index fdbf4679e6ab5..720d10489a5c2 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java @@ -36,8 +36,12 @@ import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; -// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 50) // tests.directory sys property? public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { static { @@ -178,7 +182,7 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exc AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()), 100 ); - assertEquals(hits.scoreDocs.length, 3); + assertThat(hits.scoreDocs, arrayWithSize(3)); assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id")); @@ -202,10 +206,11 @@ public void testSimpleOffHeapSize() throws IOException { } var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - assertEquals(3, offHeap.size()); - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); - assertEquals(1L, (long) offHeap.get("vex")); - assertTrue(offHeap.get("veq") > 0L); + + assertThat(offHeap, aMapWithSize(3)); + assertThat(offHeap, hasEntry("vex", 1L)); + assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L))); + assertThat(offHeap, hasEntry("vec", (long) vector.length * Float.BYTES)); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java index 03de893958ab0..85391bf9aeba3 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java @@ -26,6 +26,11 @@ import java.io.IOException; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; + public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase { static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES815HnswBitVectorsFormat()); @@ -56,9 +61,10 @@ public void testSimpleOffHeapSize() throws IOException { } var fieldInfo = r.getFieldInfos().fieldInfo("f"); var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - assertEquals(2, offHeap.size()); - assertTrue(offHeap.get("vec") > 0L); - assertEquals(1L, (long) offHeap.get("vex")); + + assertThat(offHeap, aMapWithSize(2)); + assertThat(offHeap, hasEntry("vex", 1L)); + assertThat(offHeap, hasEntry(equalTo("vec"), greaterThan(0L))); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java index c10fa9428bc13..1771aa23fbfec 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -19,15 +19,9 @@ */ package org.elasticsearch.index.codec.vectors.es816; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; @@ -38,50 +32,55 @@ import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; -import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; -import org.apache.lucene.tests.util.TestUtil; -import org.apache.lucene.util.SameThreadExecutorService; -import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; import java.io.IOException; -import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.ExecutorService; import static java.lang.String.format; -import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.oneOf; -public class ES816HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { +public class ES816HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { - static { - LogConfigurator.loadLog4jPlugins(); - LogConfigurator.configureESLogging(); // native access requires logging to be initialized + @Override + protected KnnVectorsFormat createFormat() { + return new ES816HnswBinaryQuantizedRWVectorsFormat(); } - static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES816HnswBinaryQuantizedRWVectorsFormat()); + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + return new ES816HnswBinaryQuantizedRWVectorsFormat(maxConn, beamWidth); + } @Override - protected Codec getCodec() { - return codec; + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + return new ES816HnswBinaryQuantizedRWVectorsFormat(maxConn, beamWidth, numMergeWorkers, service); } public void testToString() { - FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES816HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); - } - }; - String expectedPattern = - "ES816HnswBinaryQuantizedVectorsFormat(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," - + " flatVectorFormat=ES816BinaryQuantizedVectorsFormat(name=ES816BinaryQuantizedVectorsFormat," - + " flatVectorScorer=ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + String expected = "ES816HnswBinaryQuantizedVectorsFormat" + + "(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; + expected = format( + Locale.ROOT, + expected, + "ES816BinaryQuantizedVectorsFormat(name=ES816BinaryQuantizedVectorsFormat, flatVectorScorer=%s)" + ); + expected = format(Locale.ROOT, expected, "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())"); + + String defaultScorer = format(Locale.ROOT, expected, "DefaultFlatVectorScorer"); + String memSegScorer = format(Locale.ROOT, expected, "Lucene99MemorySegmentFlatVectorsScorer"); - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + KnnVectorsFormat format = createFormat(10, 20, 1, null); + assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } public void testSingleVectorCase() throws Exception { @@ -96,7 +95,7 @@ public void testSingleVectorCase() throws Exception { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues("f"); KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); - assert (vectorValues.size() == 1); + assertThat(vectorValues.size(), equalTo(1)); while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); } @@ -108,55 +107,26 @@ public void testSingleVectorCase() throws Exception { Integer.MAX_VALUE ); assertEquals(1, td.totalHits.value()); - assertTrue(td.scoreDocs[0].score >= 0); + assertThat(td.scoreDocs[0].score, greaterThanOrEqualTo(0f)); } } } } - public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(-1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(0, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(20, 0)); - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(20, -1)); - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES816HnswBinaryQuantizedVectorsFormat(20, 3201)); - expectThrows( - IllegalArgumentException.class, - () -> new ES816HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) - ); - } - - // Ensures that all expected vector similarity functions are translatable in the format. - public void testVectorSimilarityFuncs() { - // This does not necessarily have to be all similarity functions, but - // differences should be considered carefully. - var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); - assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); - } - public void testSimpleOffHeapSize() throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); - try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { - Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); - w.addDocument(doc); - w.commit(); - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader r = getOnlyLeafReader(reader); - if (r instanceof CodecReader codecReader) { - KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } - var fieldInfo = r.getFieldInfos().fieldInfo("f"); - var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - assertEquals(3, offHeap.size()); - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); - assertEquals(1L, (long) offHeap.get("vex")); - assertTrue(offHeap.get("veb") > 0L); - } - } + try (Directory dir = newDirectory()) { + testSimpleOffHeapSize( + dir, + newIndexWriterConfig(), + vector, + allOf( + aMapWithSize(3), + hasEntry("vex", 1L), + hasEntry(equalTo("veb"), greaterThan(0L)), + hasEntry("vec", (long) vector.length * Float.BYTES) + ) + ); } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java index 268e359597039..a1765e9af060d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java @@ -19,15 +19,9 @@ */ package org.elasticsearch.index.codec.vectors.es818; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; @@ -40,52 +34,57 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.MMapDirectory; -import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.store.MockDirectoryWrapper; -import org.apache.lucene.tests.util.TestUtil; -import org.apache.lucene.util.SameThreadExecutorService; import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; import java.io.IOException; -import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.ExecutorService; import static java.lang.String.format; -import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.oneOf; -public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { +public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { - static { - LogConfigurator.loadLog4jPlugins(); - LogConfigurator.configureESLogging(); // native access requires logging to be initialized + @Override + protected KnnVectorsFormat createFormat() { + return new ES818HnswBinaryQuantizedVectorsFormat(); } - static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES818HnswBinaryQuantizedVectorsFormat()); + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + return new ES818HnswBinaryQuantizedVectorsFormat(maxConn, beamWidth); + } @Override - protected Codec getCodec() { - return codec; + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + return new ES818HnswBinaryQuantizedVectorsFormat(maxConn, beamWidth, numMergeWorkers, service); } public void testToString() { - FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES818HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); - } - }; - String expectedPattern = - "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," - + " flatVectorFormat=ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat," - + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; - - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + String expected = "ES818HnswBinaryQuantizedVectorsFormat" + + "(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; + expected = format( + Locale.ROOT, + expected, + "ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat, flatVectorScorer=%s)" + ); + expected = format(Locale.ROOT, expected, "ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())"); + + String defaultScorer = format(Locale.ROOT, expected, "DefaultFlatVectorScorer"); + String memSegScorer = format(Locale.ROOT, expected, "Lucene99MemorySegmentFlatVectorsScorer"); + + KnnVectorsFormat format = createFormat(10, 20, 1, null); + assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } public void testSingleVectorCase() throws Exception { @@ -103,7 +102,7 @@ public void testSingleVectorCase() throws Exception { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues("f"); KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); - assert (vectorValues.size() == 1); + assertThat(vectorValues.size(), equalTo(1)); while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); } @@ -120,7 +119,7 @@ public void testSingleVectorCase() throws Exception { Integer.MAX_VALUE ); assertEquals(1, td.totalHits.value()); - assertTrue(td.scoreDocs[0].score >= 0); + assertThat(td.scoreDocs[0].score, greaterThanOrEqualTo(0f)); // When it's the only vector in a segment, the score should be very close to the true score assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); } @@ -128,27 +127,6 @@ public void testSingleVectorCase() throws Exception { } } - public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(-1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(0, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 0)); - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, -1)); - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); - expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 3201)); - expectThrows( - IllegalArgumentException.class, - () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) - ); - } - - // Ensures that all expected vector similarity functions are translatable in the format. - public void testVectorSimilarityFuncs() { - // This does not necessarily have to be all similarity functions, but - // differences should be considered carefully. - var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); - assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); - } - public void testSimpleOffHeapSize() throws IOException { try (Directory dir = newDirectory()) { testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); @@ -163,29 +141,16 @@ public void testSimpleOffHeapSizeMMapDir() throws IOException { public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); - try (IndexWriter w = new IndexWriter(dir, config)) { - Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); - w.addDocument(doc); - w.commit(); - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader r = getOnlyLeafReader(reader); - if (r instanceof CodecReader codecReader) { - KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } - var fieldInfo = r.getFieldInfos().fieldInfo("f"); - var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - assertEquals(expectVecOffHeap ? 3 : 2, offHeap.size()); - assertEquals(1L, (long) offHeap.get("vex")); - assertTrue(offHeap.get("veb") > 0L); - if (expectVecOffHeap) { - assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec")); - } - } - } - } + var matcher = expectVecOffHeap + ? allOf( + aMapWithSize(3), + hasEntry("vex", 1L), + hasEntry(equalTo("veb"), greaterThan(0L)), + hasEntry("vec", (long) vector.length * Float.BYTES) + ) + : allOf(aMapWithSize(2), hasEntry("vex", 1L), hasEntry(equalTo("veb"), greaterThan(0L))); + + testSimpleOffHeapSize(dir, config, vector, matcher); } static Directory newMMapDirectory() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 45e489662f3bf..4590c1c0fba18 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -19,147 +19,80 @@ */ package org.elasticsearch.index.codec.vectors.es93; -import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; -import org.apache.lucene.document.Document; -import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.CodecReader; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.AcceptDocs; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.MMapDirectory; -import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.store.MockDirectoryWrapper; -import org.apache.lucene.tests.util.TestUtil; -import org.apache.lucene.util.SameThreadExecutorService; -import org.apache.lucene.util.VectorUtil; -import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; +import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; import java.io.IOException; -import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.ExecutorService; import static java.lang.String.format; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; -import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.hamcrest.Matchers.oneOf; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.startsWith; -public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { - - static { - LogConfigurator.loadLog4jPlugins(); - LogConfigurator.configureESLogging(); // native access requires logging to be initialized - } - - private KnnVectorsFormat format; +public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFormatTestCase { boolean useBFloat16() { return false; } @Override - public void setUp() throws Exception { - format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); - super.setUp(); + protected KnnVectorsFormat createFormat() { + return new ES93HnswBinaryQuantizedVectorsFormat(useBFloat16(), random().nextBoolean()); } @Override - protected Codec getCodec() { - return TestUtil.alwaysKnnVectorsFormat(format); - } - - public void testToString() { - FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null); - } - }; - String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat," - + " maxConn=10, beamWidth=20," - + " flatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat," - + " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," - + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}()))," - + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}())))"; - - var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer"); - var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer)); + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { + return new ES93HnswBinaryQuantizedVectorsFormat(maxConn, beamWidth, useBFloat16(), random().nextBoolean()); } - public void testSingleVectorCase() throws Exception { - float[] vector = randomVector(random().nextInt(12, 500)); - for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { - try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { - Document doc = new Document(); - if (similarityFunction == VectorSimilarityFunction.COSINE) { - VectorUtil.l2normalize(vector); - } - doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); - w.addDocument(doc); - w.commit(); - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader r = getOnlyLeafReader(reader); - FloatVectorValues vectorValues = r.getFloatVectorValues("f"); - KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); - assert (vectorValues.size() == 1); - while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { - assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); - } - float[] randomVector = randomVector(vector.length); - if (similarityFunction == VectorSimilarityFunction.COSINE) { - VectorUtil.l2normalize(randomVector); - } - float trueScore = similarityFunction.compare(vector, randomVector); - TopDocs td = r.searchNearestVectors( - "f", - randomVector, - 1, - AcceptDocs.fromLiveDocs(r.getLiveDocs(), r.maxDoc()), - Integer.MAX_VALUE - ); - assertEquals(1, td.totalHits.value()); - assertTrue(td.scoreDocs[0].score >= 0); - // When it's the only vector in a segment, the score should be very close to the true score - assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); - } - } - } + @Override + protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { + return new ES93HnswBinaryQuantizedVectorsFormat( + maxConn, + beamWidth, + useBFloat16(), + random().nextBoolean(), + numMergeWorkers, + service + ); } - public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false, false)); - expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false, false)); - expectThrows( - IllegalArgumentException.class, - () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) + public void testToString() { + String expected = "ES93HnswBinaryQuantizedVectorsFormat(" + + "name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)"; + expected = format( + Locale.ROOT, + expected, + "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat, rawVectorFormat=%s," + + " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}()))" ); - } + expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)"); + if (useBFloat16()) { + expected = format( + Locale.ROOT, + expected, + "ES93BFloat16FlatVectorsFormat(name=ES93BFloat16FlatVectorsFormat, flatVectorScorer={}())" + ); + } else { + expected = format(Locale.ROOT, expected, "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}())"); + } + String defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer"); + String memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - // Ensures that all expected vector similarity functions are translatable in the format. - public void testVectorSimilarityFuncs() { - // This does not necessarily have to be all similarity functions, but - // differences should be considered carefully. - var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); - assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + KnnVectorsFormat format = createFormat(10, 20, 1, null); + assertThat(format, hasToString(either(startsWith(defaultScorer)).or(startsWith(memSegScorer)))); } public void testSimpleOffHeapSize() throws IOException { @@ -176,30 +109,16 @@ public void testSimpleOffHeapSizeMMapDir() throws IOException { public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { float[] vector = randomVector(random().nextInt(12, 500)); - try (IndexWriter w = new IndexWriter(dir, config)) { - Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); - w.addDocument(doc); - w.commit(); - try (IndexReader reader = DirectoryReader.open(w)) { - LeafReader r = getOnlyLeafReader(reader); - if (r instanceof CodecReader codecReader) { - KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); - if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { - knnVectorsReader = fieldsReader.getFieldReader("f"); - } - var fieldInfo = r.getFieldInfos().fieldInfo("f"); - var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); - assertEquals(expectVecOffHeap ? 3 : 2, offHeap.size()); - assertEquals(1L, (long) offHeap.get("vex")); - assertTrue(offHeap.get("veb") > 0L); - if (expectVecOffHeap) { - int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; - assertEquals(vector.length * bytes, (long) offHeap.get("vec")); - } - } - } - } + var matcher = expectVecOffHeap + ? allOf( + aMapWithSize(3), + hasEntry("vex", 1L), + hasEntry(equalTo("veb"), greaterThan(0L)), + hasEntry("vec", (long) vector.length * (useBFloat16() ? BFloat16.BYTES : Float.BYTES)) + ) + : allOf(aMapWithSize(2), hasEntry("vex", 1L), hasEntry(equalTo("veb"), greaterThan(0L))); + + testSimpleOffHeapSize(dir, config, vector, matcher); } static Directory newMMapDirectory() throws IOException {