diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java index fdb09594a1cda..f56bb8995b34e 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java @@ -19,6 +19,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; import org.elasticsearch.common.logging.LogConfigurator; @@ -76,10 +77,10 @@ public class VectorScorerBenchmark { float vec2Offset; float scoreCorrectionConstant; - RandomVectorScorer luceneDotScorer; - RandomVectorScorer luceneSqrScorer; - RandomVectorScorer nativeDotScorer; - RandomVectorScorer nativeSqrScorer; + UpdateableRandomVectorScorer luceneDotScorer; + UpdateableRandomVectorScorer luceneSqrScorer; + UpdateableRandomVectorScorer nativeDotScorer; + UpdateableRandomVectorScorer nativeSqrScorer; RandomVectorScorer luceneDotScorerQuery; RandomVectorScorer nativeDotScorerQuery; @@ -118,12 +119,16 @@ public void setup() throws IOException { in = dir.openInput("vector.data", IOContext.DEFAULT); var values = vectorValues(dims, 2, in, VectorSimilarityFunction.DOT_PRODUCT); scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier(); - luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer(0); + luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer(); + luceneDotScorer.setScoringOrdinal(0); values = vectorValues(dims, 2, in, VectorSimilarityFunction.EUCLIDEAN); - luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(0); + luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(); + luceneSqrScorer.setScoringOrdinal(0); - nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0); - nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0); + nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(); + nativeDotScorer.setScoringOrdinal(0); + nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(); + nativeSqrScorer.setScoringOrdinal(0); // setup for getInt7SQVectorScorer / query vector scoring float[] queryVec = new float[dims]; diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties index b4a8bd8085936..0df89b184badc 100644 --- a/build-tools-internal/version.properties +++ b/build-tools-internal/version.properties @@ -1,5 +1,5 @@ elasticsearch = 9.1.0 -lucene = 10.1.0 +lucene = 10.2.0 bundled_jdk_vendor = openjdk bundled_jdk = 24+36@1f9ff9062db4449d8ca828c504ffae90 @@ -8,7 +8,7 @@ spatial4j = 0.7 jts = 1.15.0 jackson = 2.15.0 snakeyaml = 2.0 -icu4j = 68.2 +icu4j = 77.1 supercsv = 2.4.0 log4j = 2.19.0 slf4j = 2.0.6 diff --git a/docs/Versions.asciidoc b/docs/Versions.asciidoc index c2e14a399b70e..58195d7313a5a 100644 --- a/docs/Versions.asciidoc +++ b/docs/Versions.asciidoc @@ -1,8 +1,8 @@ include::{docs-root}/shared/versions/stack/{source_branch}.asciidoc[] -:lucene_version: 10.1.0 -:lucene_version_path: 10_1_0 +:lucene_version: 10.2.0 +:lucene_version_path: 10_2_0 :jdk: 11.0.2 :jdk_major: 11 :build_type: tar diff --git a/docs/changelog/126594.yaml b/docs/changelog/126594.yaml new file mode 100644 index 0000000000000..59743a606d34a --- /dev/null +++ b/docs/changelog/126594.yaml @@ -0,0 +1,5 @@ +pr: 126594 +summary: Upgrade to Lucene 10.2.0 +area: Search +type: upgrade +issues: [] diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 37c3e46f4bfa8..38b174fd5c2a4 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -911,21 +911,16 @@ - - - - - - - - - - + + + + + @@ -2966,179 +2961,129 @@ - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + - - - - - + + + diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java index 198e10406056e..19f33ba1c71f7 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java @@ -10,8 +10,8 @@ package org.elasticsearch.simdvec.internal; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; @@ -55,9 +55,6 @@ protected final void checkOrdinal(int ord) { } final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { - checkOrdinal(firstOrd); - checkOrdinal(secondOrd); - final int length = dims; long firstByteOffset = (long) firstOrd * (length + Float.BYTES); long secondByteOffset = (long) secondOrd * (length + Float.BYTES); @@ -92,13 +89,21 @@ protected final float fallbackScore(long firstByteOffset, long secondByteOffset) } @Override - public RandomVectorScorer scorer(int ord) { - checkOrdinal(ord); - return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int ord = -1; + @Override public float score(int node) throws IOException { + checkOrdinal(node); return scoreFromOrds(ord, node); } + + @Override + public void setScoringOrdinal(int node) throws IOException { + checkOrdinal(node); + this.ord = node; + } }; } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java index 0f967127f6f2c..ff07e26661ae6 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java @@ -19,8 +19,8 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -50,6 +50,8 @@ // @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100) public class VectorScorerFactoryTests extends AbstractVectorTestCase { + private static final float DELTA = 1e-4f; + // bounds of the range of values that can be seen by int7 scalar quantized vectors static final byte MIN_INT7_VALUE = 0; static final byte MAX_INT7_VALUE = 127; @@ -99,10 +101,13 @@ void testSimpleImpl(long maxChunkSize) throws IOException { float scc = values.getScalarQuantizer().getConstantMultiplier(); float expected = luceneScore(sim, vec1, vec2, scc, vec1Correction, vec2Correction); - var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(0); + var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(); + luceneSupplier.setScoringOrdinal(0); assertThat(luceneSupplier.score(1), equalTo(expected)); var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, scc).get(); - assertThat(supplier.scorer(0).score(1), equalTo(expected)); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); if (Runtime.version().feature() >= 22) { var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get(); @@ -134,24 +139,32 @@ public void testNonNegativeDotProduct() throws IOException { float expected = 0f; assertThat(luceneScore(DOT_PRODUCT, vec1, vec2, 1, -5, -5), equalTo(expected)); var supplier = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, 1).get(); - assertThat(supplier.scorer(0).score(1), equalTo(expected)); - assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); // max inner product expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5); supplier = factory.getInt7SQVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values, 1).get(); - assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); - assertThat(supplier.scorer(0).score(1), equalTo(expected)); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + assertThat(scorer.score(1), equalTo(expected)); // cosine expected = 0f; assertThat(luceneScore(COSINE, vec1, vec2, 1, -5, -5), equalTo(expected)); supplier = factory.getInt7SQVectorScorerSupplier(COSINE, in, values, 1).get(); - assertThat(supplier.scorer(0).score(1), equalTo(expected)); - assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); // euclidean expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5); supplier = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, 1).get(); - assertThat(supplier.scorer(0).score(1), equalTo(expected)); - assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); } } } @@ -208,7 +221,9 @@ void testRandomSupplier(long maxChunkSize, Function byteArraySu var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); - assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertThat(scorer.score(idx1), equalTo(expected)); } } } @@ -265,7 +280,7 @@ void testRandomScorerImpl(long maxChunkSize, Function floatArr var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], correction, corrections[idx0], corrections[idx1]); var scorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); - assertThat(scorer.score(idx1), equalTo(expected)); + assertEquals(scorer.score(idx1), expected, DELTA); } } } @@ -313,7 +328,9 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Functi var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); - assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertThat(scorer.score(idx1), equalTo(expected)); } } } @@ -352,7 +369,9 @@ public void testLarge() throws IOException { var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); - assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertThat(scorer.score(idx1), equalTo(expected)); } } } @@ -391,8 +410,8 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception { var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim)); var scoreSupplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, 1f).get(); var tasks = List.>>of( - new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1), - new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2) + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) ); var executor = Executors.newFixedThreadPool(2); var results = executor.invokeAll(tasks); @@ -408,14 +427,19 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception { static class ScoreCallable implements Callable> { - final RandomVectorScorer scorer; + final UpdateableRandomVectorScorer scorer; final int ord; final float expectedScore; - ScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) { - this.scorer = scorer; - this.ord = ord; - this.expectedScore = expectedScore; + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override diff --git a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java index c0a8c1374c146..a144df94a7750 100644 --- a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java +++ b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java @@ -28,6 +28,8 @@ public abstract class AbstractCompoundWordTokenFilterFactory extends AbstractTok protected final int maxSubwordSize; protected final boolean onlyLongestMatch; protected final CharArraySet wordList; + // TODO expose this parameter? + protected final boolean reuseChars; protected AbstractCompoundWordTokenFilterFactory(IndexSettings indexSettings, Environment env, String name, Settings settings) { super(name); @@ -36,6 +38,8 @@ protected AbstractCompoundWordTokenFilterFactory(IndexSettings indexSettings, En minSubwordSize = settings.getAsInt("min_subword_size", CompoundWordTokenFilterBase.DEFAULT_MIN_SUBWORD_SIZE); maxSubwordSize = settings.getAsInt("max_subword_size", CompoundWordTokenFilterBase.DEFAULT_MAX_SUBWORD_SIZE); onlyLongestMatch = settings.getAsBoolean("only_longest_match", false); + // TODO is the default of true correct? see: https://github.com/apache/lucene/pull/14278 + reuseChars = true; wordList = Analysis.getWordSet(env, settings, "word_list"); if (wordList == null) { throw new IllegalArgumentException("word_list must be provided for [" + name + "], either as a path to a file, or directly"); diff --git a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java index 7c2bb6ba1c116..4ec328070a49a 100644 --- a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java +++ b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java @@ -28,6 +28,14 @@ public class DictionaryCompoundWordTokenFilterFactory extends AbstractCompoundWo @Override public TokenStream create(TokenStream tokenStream) { - return new DictionaryCompoundWordTokenFilter(tokenStream, wordList, minWordSize, minSubwordSize, maxSubwordSize, onlyLongestMatch); + return new DictionaryCompoundWordTokenFilter( + tokenStream, + wordList, + minWordSize, + minSubwordSize, + maxSubwordSize, + onlyLongestMatch, + reuseChars + ); } } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 531e9455d1a2e..4c07be7d9200b 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -159,6 +159,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_UNSIGNED_LONG = def(9_019_0_00, Version.LUCENE_10_1_0); public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_SCALED_FLOAT = def(9_020_0_00, Version.LUCENE_10_1_0); public static final IndexVersion USE_LUCENE101_POSTINGS_FORMAT = def(9_021_0_00, Version.LUCENE_10_1_0); + public static final IndexVersion UPGRADE_TO_LUCENE_10_2_0 = def(9_022_00_0, Version.LUCENE_10_2_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index 18668f4f304b0..e3242ee411e7d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import java.io.IOException; @@ -130,18 +131,33 @@ public float score(int i) throws IOException { } static class HammingScorerSupplier implements RandomVectorScorerSupplier { - private final ByteVectorValues byteValues, byteValues1, byteValues2; + private final ByteVectorValues byteValues, targetValues; HammingScorerSupplier(ByteVectorValues byteValues) throws IOException { this.byteValues = byteValues; - this.byteValues1 = byteValues.copy(); - this.byteValues2 = byteValues.copy(); + this.targetValues = byteValues.copy(); } @Override - public RandomVectorScorer scorer(int i) throws IOException { - byte[] query = byteValues1.vectorValue(i); - return new HammingVectorScorer(byteValues2, query); + public UpdateableRandomVectorScorer scorer() throws IOException { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetValues) { + private final byte[] query = new byte[targetValues.dimension()]; + private int currentOrd = -1; + + @Override + public void setScoringOrdinal(int i) throws IOException { + if (currentOrd == i) { + return; + } + System.arraycopy(targetValues.vectorValue(i), 0, query, 0, query.length); + this.currentOrd = i; + } + + @Override + public float score(int i) throws IOException { + return hammingScore(targetValues.vectorValue(i), query); + } + }; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index 7c7e470909eb3..87b744b4e4eec 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -26,6 +26,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.simdvec.ESVectorUtil; @@ -76,8 +77,20 @@ public RandomVectorScorer getRandomVectorScorer( byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid); BQSpaceUtils.transposeHalfByte(initial, quantized); - BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); - return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) { + @Override + public float score(int i) throws IOException { + return quantizedScore( + binarizedVectors.dimension(), + similarityFunction, + binarizedVectors.getCentroidDP(), + quantized, + queryCorrections, + binarizedVectors.vectorValue(i), + binarizedVectors.getCorrectiveTerms(i) + ); + } + }; } return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } @@ -121,68 +134,95 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu } @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vector = queryVectors.vectorValue(ord); - OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord); - BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); - return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + public BinarizedRandomVectorScorer scorer() throws IOException { + return new BinarizedRandomVectorScorer(queryVectors.copy(), targetVectors.copy(), similarityFunction); } @Override public RandomVectorScorerSupplier copy() throws IOException { - return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction); } } - /** A binarized query representing its quantized form along with factors */ - public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {} - /** Vector scorer over binarized vector values */ - public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final BinaryQueryVector queryVector; + public static class BinarizedRandomVectorScorer extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer { + private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; + private final byte[] quantizedQuery; + private OptimizedScalarQuantizer.QuantizationResult queryCorrections = null; + private int currentOrdinal = -1; - public BinarizedRandomVectorScorer( - BinaryQueryVector queryVectors, + BinarizedRandomVectorScorer( + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, BinarizedByteVectorValues targetVectors, VectorSimilarityFunction similarityFunction ) { super(targetVectors); - this.queryVector = queryVectors; + this.queryVectors = queryVectors; + this.quantizedQuery = new byte[queryVectors.quantizedDimension()]; this.targetVectors = targetVectors; this.similarityFunction = similarityFunction; } @Override public float score(int targetOrd) throws IOException { - byte[] quantizedQuery = queryVector.vector(); - byte[] binaryCode = targetVectors.vectorValue(targetOrd); - float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); - OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult(); - OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); - float x1 = indexCorrections.quantizedComponentSum(); - float ax = indexCorrections.lowerInterval(); - // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary - float lx = indexCorrections.upperInterval() - ax; - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); - float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; - // For euclidean, we need to invert the score and apply the additional correction, which is - // assumed to be the squared l2norm of the centroid centered vectors. - if (similarityFunction == EUCLIDEAN) { - score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score; - return Math.max(1 / (1f + score), 0); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP(); - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - return VectorUtil.scaleMaxInnerProductScore(score); - } - return Math.max((1f + score) / 2f, 0); + if (queryCorrections == null) { + throw new IllegalStateException("score() called before setScoringOrdinal()"); + } + return quantizedScore( + targetVectors.dimension(), + similarityFunction, + targetVectors.getCentroidDP(), + quantizedQuery, + queryCorrections, + targetVectors.vectorValue(targetOrd), + targetVectors.getCorrectiveTerms(targetOrd) + ); + } + + @Override + public void setScoringOrdinal(int i) throws IOException { + if (i == currentOrdinal) { + return; + } + System.arraycopy(queryVectors.vectorValue(i), 0, quantizedQuery, 0, quantizedQuery.length); + queryCorrections = queryVectors.getCorrectiveTerms(i); + currentOrdinal = i; + } + } + + private static float quantizedScore( + int dims, + VectorSimilarityFunction similarityFunction, + float centroidDp, + byte[] q, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + byte[] d, + OptimizedScalarQuantizer.QuantizationResult indexCorrections + ) { + float qcDist = ESVectorUtil.ipByteBinByte(q, d); + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = indexCorrections.upperInterval() - ax; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); } + return Math.max((1f + score) / 2f, 0); } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 932925ea423ba..8dba2dbee9f5b 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -44,8 +44,8 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.index.codec.vectors.BQVectorUtils; @@ -763,6 +763,10 @@ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int target ); } + int quantizedDimension() { + return byteBuffer.array().length; + } + public int size() { return size; } @@ -887,8 +891,8 @@ static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRa } @Override - public RandomVectorScorer scorer(int ord) throws IOException { - return supplier.scorer(ord); + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java index e88c9724edba1..860205ebb23bc 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java @@ -58,7 +58,7 @@ DocIdSet processLeaf(CompositeValuesCollectorQueue queue, LeafReaderContext cont } upperBucket = (Long) upperValue; } - DocIdSetBuilder builder = fillDocIdSet ? new DocIdSetBuilder(context.reader().maxDoc(), values, field) : null; + DocIdSetBuilder builder = fillDocIdSet ? new DocIdSetBuilder(context.reader().maxDoc(), values) : null; Visitor visitor = new Visitor(context, queue, builder, values.getBytesPerDimension(), lowerBucket, upperBucket); try { values.intersect(visitor); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java index bf6fb39d43c4b..ff3f93995dbae 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java @@ -328,7 +328,11 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt hasOtherBucket ); } - return new MultiFilterLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, usesCompetitiveIterator, hasOtherBucket); + if (usesCompetitiveIterator) { + return new MultiFilterCompetitiveLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, hasOtherBucket); + } else { + return new MultiFilterLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, hasOtherBucket); + } } } @@ -400,21 +404,20 @@ public DocIdSetIterator competitiveIterator() throws IOException { } } - private class MultiFilterLeafCollector extends AbstractLeafCollector { + private final class MultiFilterLeafCollector extends AbstractLeafCollector { // A DocIdSetIterator heap with one entry for each filter, ordered by doc ID - final DisiPriorityQueue filterIterators; + DisiPriorityQueue filterIterators; MultiFilterLeafCollector( LeafBucketCollector sub, List filterWrappers, int numFilters, int totalNumKeys, - boolean usesCompetitiveIterator, boolean hasOtherBucket ) { - super(sub, numFilters, totalNumKeys, usesCompetitiveIterator, hasOtherBucket); - filterIterators = filterWrappers.isEmpty() ? null : new DisiPriorityQueue(filterWrappers.size()); + super(sub, numFilters, totalNumKeys, false, hasOtherBucket); + filterIterators = filterWrappers.isEmpty() ? null : DisiPriorityQueue.ofMaxSize(filterWrappers.size()); for (FilterMatchingDisiWrapper wrapper : filterWrappers) { filterIterators.add(wrapper); } @@ -423,7 +426,7 @@ private class MultiFilterLeafCollector extends AbstractLeafCollector { public void collect(int doc, long bucket) throws IOException { boolean matched = false; if (filterIterators != null) { - // Advance filters if necessary. Filters will already be advanced if used as a competitive iterator. + // Advance filters if necessary. DisiWrapper top = filterIterators.top(); while (top.doc < doc) { top.doc = top.approximation.advance(doc); @@ -448,16 +451,51 @@ public void collect(int doc, long bucket) throws IOException { } @Override - public DocIdSetIterator competitiveIterator() throws IOException { - if (usesCompetitiveIterator) { - // A DocIdSetIterator view of the filterIterators heap - assert filterIterators != null; - return new DisjunctionDISIApproximation(filterIterators); - } + public DocIdSetIterator competitiveIterator() { return null; } } + private final class MultiFilterCompetitiveLeafCollector extends AbstractLeafCollector { + + private final DisjunctionDISIApproximation disjunctionDisi; + + MultiFilterCompetitiveLeafCollector( + LeafBucketCollector sub, + List filterWrappers, + int numFilters, + int totalNumKeys, + boolean hasOtherBucket + ) { + super(sub, numFilters, totalNumKeys, true, hasOtherBucket); + assert filterWrappers.isEmpty() == false; + disjunctionDisi = DisjunctionDISIApproximation.of(filterWrappers, Long.MAX_VALUE); + } + + public void collect(int doc, long bucket) throws IOException { + boolean matched = false; + int target = disjunctionDisi.advance(doc); + if (target == doc) { + for (DisiWrapper w = disjunctionDisi.topList(); w != null; w = w.next) { + FilterMatchingDisiWrapper topMatch = (FilterMatchingDisiWrapper) w; + if (topMatch.checkDocForMatch(doc)) { + collectBucket(sub, doc, bucketOrd(bucket, topMatch.filterOrd)); + matched = true; + } + } + } + + if (hasOtherBucket && false == matched) { + collectBucket(sub, doc, bucketOrd(bucket, numFilters)); + } + } + + @Override + public DocIdSetIterator competitiveIterator() { + return disjunctionDisi; + } + } + private static class FilterMatchingDisiWrapper extends DisiWrapper { final int filterOrd; diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java index 285cb8a9564be..fdcfe5e3720f8 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java @@ -56,7 +56,6 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.search.suggest.document.Completion101PostingsFormat; -import org.apache.lucene.search.suggest.document.CompletionPostingsFormat; import org.apache.lucene.search.suggest.document.SuggestField; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; @@ -331,7 +330,7 @@ public void testCompletionField() throws Exception { @Override public PostingsFormat getPostingsFormatForField(String field) { if (field.startsWith("suggest_")) { - return new Completion101PostingsFormat(randomFrom(CompletionPostingsFormat.FSTLoadMode.values())); + return new Completion101PostingsFormat(); } else { return super.postingsFormat(); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java index 0bebe16f468ce..b897a004a9781 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java @@ -26,6 +26,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.index.codec.vectors.BQVectorUtils; import org.elasticsearch.simdvec.ESVectorUtil; @@ -120,28 +121,48 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu } @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vector = queryVectors.vectorValue(ord); - int quantizedSum = queryVectors.sumQuantizedValues(ord); - float distanceToCentroid = queryVectors.getCentroidDistance(ord); - float lower = queryVectors.getLower(ord); - float width = queryVectors.getWidth(ord); - float normVmC = 0f; - float vDotC = 0f; - if (similarityFunction != EUCLIDEAN) { - normVmC = queryVectors.getNormVmC(ord); - vDotC = queryVectors.getVDotC(ord); - } - BinaryQueryVector binaryQueryVector = new BinaryQueryVector( - vector, - new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) - ); - return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + public UpdateableRandomVectorScorer scorer() throws IOException { + byte[] queryVector = new byte[queryVectors.quantizedDimension()]; + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetVectors) { + private int ord = -1; + private BinaryQuantizer.QueryFactors factors = null; + private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors = + BinarizedRandomVectorScorerSupplier.this.queryVectors.copy(); + private final BinarizedByteVectorValues targetVectors = BinarizedRandomVectorScorerSupplier.this.targetVectors.copy(); + + @Override + public void setScoringOrdinal(int i) throws IOException { + if (i == ord) { + return; + } + ord = i; + System.arraycopy(queryVectors.vectorValue(i), 0, queryVector, 0, queryVector.length); + int quantizedSum = queryVectors.sumQuantizedValues(ord); + float distanceToCentroid = queryVectors.getCentroidDistance(ord); + float lower = queryVectors.getLower(ord); + float width = queryVectors.getWidth(ord); + float normVmC = 0f; + float vDotC = 0f; + if (similarityFunction != EUCLIDEAN) { + normVmC = queryVectors.getNormVmC(ord); + vDotC = queryVectors.getVDotC(ord); + } + factors = new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC); + } + + @Override + public float score(int i) throws IOException { + if (factors == null) { + throw new IllegalStateException("setScoringOrdinal must be called before score"); + } + return quantizedScore(this.targetVectors, i, queryVector, factors, similarityFunction); + } + }; } @Override public RandomVectorScorerSupplier copy() throws IOException { - return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction); } } @@ -154,9 +175,6 @@ static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRand private final BinarizedByteVectorValues targetVectors; private final VectorSimilarityFunction similarityFunction; - private final float sqrtDimensions; - private final float maxX1; - BinarizedRandomVectorScorer( BinaryQueryVector queryVectors, BinarizedByteVectorValues targetVectors, @@ -166,91 +184,97 @@ static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRand this.queryVector = queryVectors; this.targetVectors = targetVectors; this.similarityFunction = similarityFunction; - // FIXME: precompute this once? - this.sqrtDimensions = targetVectors.sqrtDimensions(); - this.maxX1 = targetVectors.maxX1(); } @Override public float score(int targetOrd) throws IOException { - byte[] quantizedQuery = queryVector.vector(); - int quantizedSum = queryVector.factors().quantizedSum(); - float lower = queryVector.factors().lower(); - float width = queryVector.factors().width(); - float distanceToCentroid = queryVector.factors().distToC(); - if (similarityFunction == EUCLIDEAN) { - return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width); - } + return quantizedScore(targetVectors, targetOrd, queryVector.vector(), queryVector.factors(), similarityFunction); + } + } - float vmC = queryVector.factors().normVmC(); - float vDotC = queryVector.factors().vDotC(); - float cDotC = targetVectors.getCentroidDP(); - byte[] binaryCode = targetVectors.vectorValue(targetOrd); - float ooq = targetVectors.getOOQ(targetOrd); - float normOC = targetVectors.getNormOC(targetOrd); - float oDotC = targetVectors.getODotC(targetOrd); + private static float quantizedScore( + BinarizedByteVectorValues targetVectors, + int targetOrd, + byte[] quantizedQuery, + BinaryQuantizer.QueryFactors queryFactors, + VectorSimilarityFunction similarityFunction + ) throws IOException { + if (similarityFunction == EUCLIDEAN) { + return euclideanQuantizedScore(targetVectors, targetOrd, queryFactors, quantizedQuery); + } + return dotProductQuantizedScore(targetVectors, targetOrd, quantizedQuery, queryFactors, similarityFunction); + } - float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + private static float dotProductQuantizedScore( + BinarizedByteVectorValues targetVectors, + int targetOrd, + byte[] quantizedQuery, + BinaryQuantizer.QueryFactors queryFactors, + VectorSimilarityFunction similarityFunction + ) throws IOException { + float vmC = queryFactors.normVmC(); + float vDotC = queryFactors.vDotC(); + float cDotC = targetVectors.getCentroidDP(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float ooq = targetVectors.getOOQ(targetOrd); + float normOC = targetVectors.getNormOC(targetOrd); + float oDotC = targetVectors.getODotC(targetOrd); - // FIXME: pre-compute these only once for each target vector - // ... pull this out or use a similar cache mechanism as do in score - float xbSum = (float) BQVectorUtils.popcount(binaryCode); - final float dist; - // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away - // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector - if (normOC == 0 || ooq == 0) { - dist = oDotC + vDotC - cDotC; - } else { - // If ||o-c|| != 0, we should assume that `ooq` is finite - assert Float.isFinite(ooq); - float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions - * quantizedSum - sqrtDimensions * lower) / ooq; - dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC; - } - assert Float.isFinite(dist); + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); - float ooqSqr = (float) Math.pow(ooq, 2); - float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr))); - float score = Float.isFinite(errorBound) ? dist - errorBound : dist; - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - return VectorUtil.scaleMaxInnerProductScore(score); - } - return Math.max((1f + score) / 2f, 0); + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + final float dist; + // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away + // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector + if (normOC == 0 || ooq == 0) { + dist = oDotC + vDotC - cDotC; + } else { + // If ||o-c|| != 0, we should assume that `ooq` is finite + assert Float.isFinite(ooq); + float estimatedDot = (2 * queryFactors.width() / targetVectors.sqrtDimensions() * qcDist + 2 * queryFactors.lower() + / targetVectors.sqrtDimensions() * xbSum - queryFactors.width() / targetVectors.sqrtDimensions() * queryFactors + .quantizedSum() - targetVectors.sqrtDimensions() * queryFactors.lower()) / ooq; + dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC; + } + assert Float.isFinite(dist); + + float ooqSqr = (float) Math.pow(ooq, 2); + float errorBound = (float) (vmC * normOC * (targetVectors.maxX1() * Math.sqrt((1 - ooqSqr) / ooqSqr))); + float score = Float.isFinite(errorBound) ? dist - errorBound : dist; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); } + return Math.max((1f + score) / 2f, 0); + } - private float euclideanScore( - int targetOrd, - float sqrtDimensions, - byte[] quantizedQuery, - float distanceToCentroid, - float lower, - int quantizedSum, - float width - ) throws IOException { - byte[] binaryCode = targetVectors.vectorValue(targetOrd); + private static float euclideanQuantizedScore( + BinarizedByteVectorValues targetVectors, + int targetOrd, + BinaryQuantizer.QueryFactors factors, + byte[] quantizedQuery + ) throws IOException { + byte[] binaryCode = targetVectors.vectorValue(targetOrd); - // FIXME: pre-compute these only once for each target vector - // .. not sure how to enumerate the target ordinals but that's what we did in PoC - float targetDistToC = targetVectors.getCentroidDistance(targetOrd); - float x0 = targetVectors.getVectorMagnitude(targetOrd); - float sqrX = targetDistToC * targetDistToC; - double xX0 = targetDistToC / x0; + float targetDistToC = targetVectors.getCentroidDistance(targetOrd); + float x0 = targetVectors.getVectorMagnitude(targetOrd); + float sqrX = targetDistToC * targetDistToC; + double xX0 = targetDistToC / x0; - // TODO maybe store? - float xbSum = (float) BQVectorUtils.popcount(binaryCode); - float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension())); - float factorIP = (float) (-2.0 / sqrtDimensions * xX0); + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + float factorPPC = (float) (-2.0 / targetVectors.sqrtDimensions() * xX0 * (xbSum * 2.0 - targetVectors.dimension())); + float factorIP = (float) (-2.0 / targetVectors.sqrtDimensions() * xX0); - long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); - float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width; - float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); - float error = 2.0f * maxX1 * projectionDist; - float y = (float) Math.sqrt(distanceToCentroid); - float errorBound = y * error; - if (Float.isFinite(errorBound)) { - score = score + errorBound; - } - return Math.max(1 / (1f + score), 0); + long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + float score = sqrX + factors.distToC() + factorPPC * factors.lower() + (qcDist * 2 - factors.quantizedSum()) * factorIP * factors + .width(); + float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); + float error = 2.0f * targetVectors.maxX1() * projectionDist; + float y = (float) Math.sqrt(factors.distToC()); + float errorBound = y * error; + if (Float.isFinite(errorBound)) { + score = score + errorBound; } + return Math.max(1 / (1f + score), 0); } + } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java index 61bd5323b5b43..ae7eea79dd29e 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java @@ -45,8 +45,8 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.index.codec.vectors.BQVectorUtils; @@ -822,6 +822,10 @@ public int dimension() { return dimension; } + public int quantizedDimension() { + return byteBuffer.array().length; + } + public OffHeapBinarizedQueryVectorValues copy() throws IOException { return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size, vectorSimilarityFunction); } @@ -959,8 +963,8 @@ static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRa } @Override - public RandomVectorScorer scorer(int ord) throws IOException { - return supplier.scorer(ord); + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java index f3de4fa124c44..e5947df2b417b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java @@ -15,6 +15,8 @@ import org.apache.lucene.search.Weight; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -36,11 +38,12 @@ public QueryFeatureExtractor(List featureNames, List weights) { } this.featureNames = featureNames; this.weights = weights; - this.subScorers = new DisiPriorityQueue(weights.size()); + this.subScorers = DisiPriorityQueue.ofMaxSize(weights.size()); } @Override public void setNextReader(LeafReaderContext segmentContext) throws IOException { + Collection disiWrappers = new ArrayList<>(); subScorers.clear(); for (int i = 0; i < weights.size(); i++) { var weight = weights.get(i); @@ -51,11 +54,13 @@ public void setNextReader(LeafReaderContext segmentContext) throws IOException { if (scorerSupplier != null) { var scorer = scorerSupplier.get(0L); if (scorer != null) { - subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i))); + FeatureDisiWrapper featureDisiWrapper = new FeatureDisiWrapper(scorer, featureNames.get(i)); + subScorers.add(featureDisiWrapper); + disiWrappers.add(featureDisiWrapper); } } } - approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null; + approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(disiWrappers, Long.MAX_VALUE) : null; } @Override @@ -69,7 +74,7 @@ public void addFeatures(Map featureMap, int docId) throws IOExce if (approximation.docID() != docId) { return; } - var w = (FeatureDisiWrapper) subScorers.topList(); + var w = (FeatureDisiWrapper) approximation.topList(); for (; w != null; w = (FeatureDisiWrapper) w.next) { if (w.twoPhaseView == null || w.twoPhaseView.matches()) { featureMap.put(w.featureName, w.scorable.score()); diff --git a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java index 77e275555deb4..5f9d5029aaf7c 100644 --- a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java +++ b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java @@ -407,6 +407,9 @@ public Query regexpQuery( public static Query toApproximationQuery(RegExp r) throws IllegalArgumentException { Query result = null; switch (r.kind) { + case REGEXP_CHAR_CLASS: + result = createCharacterClassQuery(r); + break; case REGEXP_UNION: result = createUnionQuery(r); break; @@ -426,7 +429,6 @@ public static Query toApproximationQuery(RegExp r) throws IllegalArgumentExcepti // Repeat is zero or more times so zero matches = match all result = new MatchAllDocsQuery(); break; - case REGEXP_REPEAT_MIN: case REGEXP_REPEAT_MINMAX: if (r.min > 0) { @@ -458,7 +460,7 @@ public static Query toApproximationQuery(RegExp r) throws IllegalArgumentExcepti case REGEXP_INTERVAL: case REGEXP_EMPTY: case REGEXP_AUTOMATON: - case REGEXP_PRE_CLASS: + // case REGEXP_PRE_CLASS: result = new MatchAllDocsQuery(); break; } @@ -496,11 +498,35 @@ private static Query createConcatenationQuery(RegExp r) { } + private static Query createCharacterClassQuery(RegExp r) { + // TODO: consider expanding this to allow for character ranges as well (need additional tests and performance eval) + List queries = new ArrayList<>(); + if (r.from.length > MAX_CLAUSES_IN_APPROXIMATION_QUERY) { + return new MatchAllDocsQuery(); + } + for (int i = 0; i < r.from.length; i++) { + // only handle character classes for now not ranges + if (r.from[i] == r.to[i]) { + String cs = Character.toString(r.from[i]); + String normalizedChar = toLowerCase(cs); + queries.add(new TermQuery(new Term("", normalizedChar))); + } else { + // immediately exit because we can't currently optimize a combination of range and classes + return new MatchAllDocsQuery(); + } + } + return formQuery(queries); + } + private static Query createUnionQuery(RegExp r) { // Create an OR of clauses - ArrayList queries = new ArrayList<>(); + List queries = new ArrayList<>(); findLeaves(r.exp1, org.apache.lucene.util.automaton.RegExp.Kind.REGEXP_UNION, queries); findLeaves(r.exp2, org.apache.lucene.util.automaton.RegExp.Kind.REGEXP_UNION, queries); + return formQuery(queries); + } + + private static Query formQuery(List queries) { BooleanQuery.Builder bOr = new BooleanQuery.Builder(); HashSet uniqueClauses = new HashSet<>(); for (Query query : queries) {