diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java
index fdb09594a1cda..f56bb8995b34e 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java
@@ -19,6 +19,7 @@
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
@@ -76,10 +77,10 @@ public class VectorScorerBenchmark {
float vec2Offset;
float scoreCorrectionConstant;
- RandomVectorScorer luceneDotScorer;
- RandomVectorScorer luceneSqrScorer;
- RandomVectorScorer nativeDotScorer;
- RandomVectorScorer nativeSqrScorer;
+ UpdateableRandomVectorScorer luceneDotScorer;
+ UpdateableRandomVectorScorer luceneSqrScorer;
+ UpdateableRandomVectorScorer nativeDotScorer;
+ UpdateableRandomVectorScorer nativeSqrScorer;
RandomVectorScorer luceneDotScorerQuery;
RandomVectorScorer nativeDotScorerQuery;
@@ -118,12 +119,16 @@ public void setup() throws IOException {
in = dir.openInput("vector.data", IOContext.DEFAULT);
var values = vectorValues(dims, 2, in, VectorSimilarityFunction.DOT_PRODUCT);
scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier();
- luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer(0);
+ luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer();
+ luceneDotScorer.setScoringOrdinal(0);
values = vectorValues(dims, 2, in, VectorSimilarityFunction.EUCLIDEAN);
- luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(0);
+ luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer();
+ luceneSqrScorer.setScoringOrdinal(0);
- nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
- nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
+ nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer();
+ nativeDotScorer.setScoringOrdinal(0);
+ nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer();
+ nativeSqrScorer.setScoringOrdinal(0);
// setup for getInt7SQVectorScorer / query vector scoring
float[] queryVec = new float[dims];
diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties
index b4a8bd8085936..0df89b184badc 100644
--- a/build-tools-internal/version.properties
+++ b/build-tools-internal/version.properties
@@ -1,5 +1,5 @@
elasticsearch = 9.1.0
-lucene = 10.1.0
+lucene = 10.2.0
bundled_jdk_vendor = openjdk
bundled_jdk = 24+36@1f9ff9062db4449d8ca828c504ffae90
@@ -8,7 +8,7 @@ spatial4j = 0.7
jts = 1.15.0
jackson = 2.15.0
snakeyaml = 2.0
-icu4j = 68.2
+icu4j = 77.1
supercsv = 2.4.0
log4j = 2.19.0
slf4j = 2.0.6
diff --git a/docs/Versions.asciidoc b/docs/Versions.asciidoc
index c2e14a399b70e..58195d7313a5a 100644
--- a/docs/Versions.asciidoc
+++ b/docs/Versions.asciidoc
@@ -1,8 +1,8 @@
include::{docs-root}/shared/versions/stack/{source_branch}.asciidoc[]
-:lucene_version: 10.1.0
-:lucene_version_path: 10_1_0
+:lucene_version: 10.2.0
+:lucene_version_path: 10_2_0
:jdk: 11.0.2
:jdk_major: 11
:build_type: tar
diff --git a/docs/changelog/126594.yaml b/docs/changelog/126594.yaml
new file mode 100644
index 0000000000000..59743a606d34a
--- /dev/null
+++ b/docs/changelog/126594.yaml
@@ -0,0 +1,5 @@
+pr: 126594
+summary: Upgrade to Lucene 10.2.0
+area: Search
+type: upgrade
+issues: []
diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml
index 37c3e46f4bfa8..38b174fd5c2a4 100644
--- a/gradle/verification-metadata.xml
+++ b/gradle/verification-metadata.xml
@@ -911,21 +911,16 @@
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
@@ -2966,179 +2961,129 @@
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
-
-
-
-
-
+
+
+
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java
index 198e10406056e..19f33ba1c71f7 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java
@@ -10,8 +10,8 @@
package org.elasticsearch.simdvec.internal;
import org.apache.lucene.store.MemorySegmentAccessInput;
-import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
@@ -55,9 +55,6 @@ protected final void checkOrdinal(int ord) {
}
final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException {
- checkOrdinal(firstOrd);
- checkOrdinal(secondOrd);
-
final int length = dims;
long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
@@ -92,13 +89,21 @@ protected final float fallbackScore(long firstByteOffset, long secondByteOffset)
}
@Override
- public RandomVectorScorer scorer(int ord) {
- checkOrdinal(ord);
- return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
+ public UpdateableRandomVectorScorer scorer() {
+ // The returned scorer has no query ordinal until setScoringOrdinal(int) is called.
+ return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
+ // -1 marks "no scoring ordinal set yet".
+ private int ord = -1;
+
@Override
public float score(int node) throws IOException {
+ // scoreFromOrds does not validate its arguments, so reject an unset query
+ // ordinal here instead of letting -1 become a negative byte offset.
+ if (ord < 0) {
+ throw new IllegalStateException("score() called before setScoringOrdinal()");
+ }
+ checkOrdinal(node);
return scoreFromOrds(ord, node);
}
+
+ @Override
+ public void setScoringOrdinal(int node) throws IOException {
+ // Validate the query ordinal once here; score() then only checks the target ordinal.
+ checkOrdinal(node);
+ this.ord = node;
+ }
};
}
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java
index 0f967127f6f2c..ff07e26661ae6 100644
--- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java
@@ -19,8 +19,8 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
-import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
@@ -50,6 +50,8 @@
// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100)
public class VectorScorerFactoryTests extends AbstractVectorTestCase {
+ private static final float DELTA = 1e-4f;
+
// bounds of the range of values that can be seen by int7 scalar quantized vectors
static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127;
@@ -99,10 +101,13 @@ void testSimpleImpl(long maxChunkSize) throws IOException {
float scc = values.getScalarQuantizer().getConstantMultiplier();
float expected = luceneScore(sim, vec1, vec2, scc, vec1Correction, vec2Correction);
- var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(0);
+ var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer();
+ luceneSupplier.setScoringOrdinal(0);
assertThat(luceneSupplier.score(1), equalTo(expected));
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, scc).get();
- assertThat(supplier.scorer(0).score(1), equalTo(expected));
+ var scorer = supplier.scorer();
+ scorer.setScoringOrdinal(0);
+ assertThat(scorer.score(1), equalTo(expected));
if (Runtime.version().feature() >= 22) {
var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get();
@@ -134,24 +139,32 @@ public void testNonNegativeDotProduct() throws IOException {
float expected = 0f;
assertThat(luceneScore(DOT_PRODUCT, vec1, vec2, 1, -5, -5), equalTo(expected));
var supplier = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, 1).get();
- assertThat(supplier.scorer(0).score(1), equalTo(expected));
- assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
+ var scorer = supplier.scorer();
+ scorer.setScoringOrdinal(0);
+ assertThat(scorer.score(1), equalTo(expected));
+ assertThat(scorer.score(1), greaterThanOrEqualTo(0f));
// max inner product
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
supplier = factory.getInt7SQVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
- assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
- assertThat(supplier.scorer(0).score(1), equalTo(expected));
+ scorer = supplier.scorer();
+ scorer.setScoringOrdinal(0);
+ assertThat(scorer.score(1), greaterThanOrEqualTo(0f));
+ assertThat(scorer.score(1), equalTo(expected));
// cosine
expected = 0f;
assertThat(luceneScore(COSINE, vec1, vec2, 1, -5, -5), equalTo(expected));
supplier = factory.getInt7SQVectorScorerSupplier(COSINE, in, values, 1).get();
- assertThat(supplier.scorer(0).score(1), equalTo(expected));
- assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
+ scorer = supplier.scorer();
+ scorer.setScoringOrdinal(0);
+ assertThat(scorer.score(1), equalTo(expected));
+ assertThat(scorer.score(1), greaterThanOrEqualTo(0f));
// euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
supplier = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, 1).get();
- assertThat(supplier.scorer(0).score(1), equalTo(expected));
- assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
+ scorer = supplier.scorer();
+ scorer.setScoringOrdinal(0);
+ assertThat(scorer.score(1), equalTo(expected));
+ assertThat(scorer.score(1), greaterThanOrEqualTo(0f));
}
}
}
@@ -208,7 +221,9 @@ void testRandomSupplier(long maxChunkSize, Function byteArraySu
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
- assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
+ var scorer = supplier.scorer();
+ scorer.setScoringOrdinal(idx0);
+ assertThat(scorer.score(idx1), equalTo(expected));
}
}
}
@@ -265,7 +280,7 @@ void testRandomScorerImpl(long maxChunkSize, Function floatArr
var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], correction, corrections[idx0], corrections[idx1]);
var scorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get();
- assertThat(scorer.score(idx1), equalTo(expected));
+ assertEquals(expected, scorer.score(idx1), DELTA); // argument order: (expected, actual, delta)
}
}
}
@@ -313,7 +328,9 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Functi
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
- assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
+ var scorer = supplier.scorer();
+ scorer.setScoringOrdinal(idx0);
+ assertThat(scorer.score(idx1), equalTo(expected));
}
}
}
@@ -352,7 +369,9 @@ public void testLarge() throws IOException {
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
- assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
+ var scorer = supplier.scorer();
+ scorer.setScoringOrdinal(idx0);
+ assertThat(scorer.score(idx1), equalTo(expected));
}
}
}
@@ -391,8 +410,8 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim));
var scoreSupplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, 1f).get();
var tasks = List.>>of(
- new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1),
- new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2)
+ new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1),
+ new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2)
);
var executor = Executors.newFixedThreadPool(2);
var results = executor.invokeAll(tasks);
@@ -408,14 +427,19 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
static class ScoreCallable implements Callable> {
- final RandomVectorScorer scorer;
+ final UpdateableRandomVectorScorer scorer;
final int ord;
final float expectedScore;
- ScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) {
- this.scorer = scorer;
- this.ord = ord;
- this.expectedScore = expectedScore;
+ ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) {
+ // Plain field assignments cannot throw, so they stay outside the try block.
+ this.scorer = scorer;
+ this.ord = ord;
+ this.expectedScore = expectedScore;
+ try {
+ // Position the scorer on the query ordinal before the task runs.
+ scorer.setScoringOrdinal(queryOrd);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
}
@Override
diff --git a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java
index c0a8c1374c146..a144df94a7750 100644
--- a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java
+++ b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/AbstractCompoundWordTokenFilterFactory.java
@@ -28,6 +28,8 @@ public abstract class AbstractCompoundWordTokenFilterFactory extends AbstractTok
protected final int maxSubwordSize;
protected final boolean onlyLongestMatch;
protected final CharArraySet wordList;
+ // TODO expose this parameter?
+ protected final boolean reuseChars;
protected AbstractCompoundWordTokenFilterFactory(IndexSettings indexSettings, Environment env, String name, Settings settings) {
super(name);
@@ -36,6 +38,8 @@ protected AbstractCompoundWordTokenFilterFactory(IndexSettings indexSettings, En
minSubwordSize = settings.getAsInt("min_subword_size", CompoundWordTokenFilterBase.DEFAULT_MIN_SUBWORD_SIZE);
maxSubwordSize = settings.getAsInt("max_subword_size", CompoundWordTokenFilterBase.DEFAULT_MAX_SUBWORD_SIZE);
onlyLongestMatch = settings.getAsBoolean("only_longest_match", false);
+ // TODO is the default of true correct? see: https://github.com/apache/lucene/pull/14278
+ reuseChars = true;
wordList = Analysis.getWordSet(env, settings, "word_list");
if (wordList == null) {
throw new IllegalArgumentException("word_list must be provided for [" + name + "], either as a path to a file, or directly");
diff --git a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java
index 7c2bb6ba1c116..4ec328070a49a 100644
--- a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java
+++ b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/DictionaryCompoundWordTokenFilterFactory.java
@@ -28,6 +28,14 @@ public class DictionaryCompoundWordTokenFilterFactory extends AbstractCompoundWo
@Override
public TokenStream create(TokenStream tokenStream) {
- return new DictionaryCompoundWordTokenFilter(tokenStream, wordList, minWordSize, minSubwordSize, maxSubwordSize, onlyLongestMatch);
+ return new DictionaryCompoundWordTokenFilter(
+ tokenStream,
+ wordList,
+ minWordSize,
+ minSubwordSize,
+ maxSubwordSize,
+ onlyLongestMatch,
+ reuseChars // not configurable yet: AbstractCompoundWordTokenFilterFactory hard-codes it to true
+ );
}
}
diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java
index 531e9455d1a2e..4c07be7d9200b 100644
--- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java
+++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java
@@ -159,6 +159,7 @@ private static Version parseUnchecked(String version) {
public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_UNSIGNED_LONG = def(9_019_0_00, Version.LUCENE_10_1_0);
public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_SCALED_FLOAT = def(9_020_0_00, Version.LUCENE_10_1_0);
public static final IndexVersion USE_LUCENE101_POSTINGS_FORMAT = def(9_021_0_00, Version.LUCENE_10_1_0);
+ public static final IndexVersion UPGRADE_TO_LUCENE_10_2_0 = def(9_022_0_00, Version.LUCENE_10_2_0); // same value as before, digits grouped like the ids above
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java
index 18668f4f304b0..e3242ee411e7d 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java
@@ -22,6 +22,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import java.io.IOException;
@@ -130,18 +131,33 @@ public float score(int i) throws IOException {
}
static class HammingScorerSupplier implements RandomVectorScorerSupplier {
- private final ByteVectorValues byteValues, byteValues1, byteValues2;
+ private final ByteVectorValues byteValues, targetValues;
HammingScorerSupplier(ByteVectorValues byteValues) throws IOException {
this.byteValues = byteValues;
- this.byteValues1 = byteValues.copy();
- this.byteValues2 = byteValues.copy();
+ this.targetValues = byteValues.copy();
}
@Override
- public RandomVectorScorer scorer(int i) throws IOException {
- byte[] query = byteValues1.vectorValue(i);
- return new HammingVectorScorer(byteValues2, query);
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetValues) {
+ // Scorer-owned copy of the query vector: filled by setScoringOrdinal, read by score.
+ private final byte[] query = new byte[targetValues.dimension()];
+ // -1 means no query ordinal has been loaded yet.
+ private int currentOrd = -1;
+
+ @Override
+ public void setScoringOrdinal(int i) throws IOException {
+ if (currentOrd == i) {
+ // query already holds this ordinal's vector
+ return;
+ }
+ System.arraycopy(targetValues.vectorValue(i), 0, query, 0, query.length);
+ this.currentOrd = i;
+ }
+
+ @Override
+ public float score(int i) throws IOException {
+ // Without this guard the all-zero initial query would be scored silently.
+ if (currentOrd < 0) {
+ throw new IllegalStateException("score() called before setScoringOrdinal()");
+ }
+ return hammingScore(targetValues.vectorValue(i), query);
+ }
+ };
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
index 7c7e470909eb3..87b744b4e4eec 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
@@ -26,6 +26,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -76,8 +77,20 @@ public RandomVectorScorer getRandomVectorScorer(
byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
BQSpaceUtils.transposeHalfByte(initial, quantized);
- BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections);
- return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
+ return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) {
+ @Override
+ public float score(int i) throws IOException {
+ return quantizedScore(
+ binarizedVectors.dimension(),
+ similarityFunction,
+ binarizedVectors.getCentroidDP(),
+ quantized,
+ queryCorrections,
+ binarizedVectors.vectorValue(i),
+ binarizedVectors.getCorrectiveTerms(i)
+ );
+ }
+ };
}
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
@@ -121,68 +134,95 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu
}
@Override
- public RandomVectorScorer scorer(int ord) throws IOException {
- byte[] vector = queryVectors.vectorValue(ord);
- OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord);
- BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms);
- return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+ public BinarizedRandomVectorScorer scorer() throws IOException {
+ return new BinarizedRandomVectorScorer(queryVectors.copy(), targetVectors.copy(), similarityFunction); // every scorer gets its own copies of both value sets
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
- return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction);
+ return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction); // value sets are shared here; scorer() copies them for each scorer it creates
}
}
- /** A binarized query representing its quantized form along with factors */
- public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {}
-
/** Vector scorer over binarized vector values */
- public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
- private final BinaryQueryVector queryVector;
+ public static class BinarizedRandomVectorScorer extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
+ private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
private final BinarizedByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;
+ private final byte[] quantizedQuery;
+ private OptimizedScalarQuantizer.QuantizationResult queryCorrections = null;
+ private int currentOrdinal = -1;
- public BinarizedRandomVectorScorer(
- BinaryQueryVector queryVectors,
+ BinarizedRandomVectorScorer(
+ ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
BinarizedByteVectorValues targetVectors,
VectorSimilarityFunction similarityFunction
) {
super(targetVectors);
- this.queryVector = queryVectors;
+ this.queryVectors = queryVectors;
+ this.quantizedQuery = new byte[queryVectors.quantizedDimension()];
this.targetVectors = targetVectors;
this.similarityFunction = similarityFunction;
}
@Override
public float score(int targetOrd) throws IOException {
- byte[] quantizedQuery = queryVector.vector();
- byte[] binaryCode = targetVectors.vectorValue(targetOrd);
- float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
- OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult();
- OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd);
- float x1 = indexCorrections.quantizedComponentSum();
- float ax = indexCorrections.lowerInterval();
- // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
- float lx = indexCorrections.upperInterval() - ax;
- float ay = queryCorrections.lowerInterval();
- float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
- float y1 = queryCorrections.quantizedComponentSum();
- float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
- // For euclidean, we need to invert the score and apply the additional correction, which is
- // assumed to be the squared l2norm of the centroid centered vectors.
- if (similarityFunction == EUCLIDEAN) {
- score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score;
- return Math.max(1 / (1f + score), 0);
- } else {
- // For cosine and max inner product, we need to apply the additional correction, which is
- // assumed to be the non-centered dot-product between the vector and the centroid
- score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP();
- if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
- return VectorUtil.scaleMaxInnerProductScore(score);
- }
- return Math.max((1f + score) / 2f, 0);
+ if (queryCorrections == null) {
+ throw new IllegalStateException("score() called before setScoringOrdinal()");
+ }
+ return quantizedScore(
+ targetVectors.dimension(),
+ similarityFunction,
+ targetVectors.getCentroidDP(),
+ quantizedQuery,
+ queryCorrections,
+ targetVectors.vectorValue(targetOrd),
+ targetVectors.getCorrectiveTerms(targetOrd)
+ );
+ }
+
+ @Override
+ public void setScoringOrdinal(int i) throws IOException {
+ if (i == currentOrdinal) { // query state for this ordinal is already loaded
+ return;
+ }
+ System.arraycopy(queryVectors.vectorValue(i), 0, quantizedQuery, 0, quantizedQuery.length); // copy into the scorer-owned buffer
+ queryCorrections = queryVectors.getCorrectiveTerms(i); // score() treats a null value here as "ordinal not set"
+ currentOrdinal = i;
+ }
+ }
+
+ private static float quantizedScore( // scoring math shared by the query-vector scorer and BinarizedRandomVectorScorer
+ int dims, // callers pass dimension() of the binarized target values
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp, // callers pass getCentroidDP() of the binarized target values
+ byte[] q, // quantized query bytes
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ byte[] d, // stored bytes of the target vector being scored
+ OptimizedScalarQuantizer.QuantizationResult indexCorrections
+ ) {
+ float qcDist = ESVectorUtil.ipByteBinByte(q, d);
+ float x1 = indexCorrections.quantizedComponentSum();
+ float ax = indexCorrections.lowerInterval();
+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+ float lx = indexCorrections.upperInterval() - ax;
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score;
+ return Math.max(1 / (1f + score), 0); // result is never below 0
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - centroidDp;
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
}
+ return Math.max((1f + score) / 2f, 0); // every other similarity: (1 + score) / 2, never below 0
}
}
}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
index 932925ea423ba..8dba2dbee9f5b 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
@@ -44,8 +44,8 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
-import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
@@ -763,6 +763,10 @@ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int target
);
}
+ int quantizedDimension() {
+ return byteBuffer.array().length; // length of the array backing byteBuffer; used to size the scorer's query buffer
+ }
+
public int size() {
return size;
}
@@ -887,8 +891,8 @@ static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRa
}
@Override
- public RandomVectorScorer scorer(int ord) throws IOException {
- return supplier.scorer(ord);
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return supplier.scorer(); // plain delegation to the wrapped supplier
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java
index e88c9724edba1..860205ebb23bc 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/PointsSortedDocsProducer.java
@@ -58,7 +58,7 @@ DocIdSet processLeaf(CompositeValuesCollectorQueue queue, LeafReaderContext cont
}
upperBucket = (Long) upperValue;
}
- DocIdSetBuilder builder = fillDocIdSet ? new DocIdSetBuilder(context.reader().maxDoc(), values, field) : null;
+ DocIdSetBuilder builder = fillDocIdSet ? new DocIdSetBuilder(context.reader().maxDoc(), values) : null;
Visitor visitor = new Visitor(context, queue, builder, values.getBytesPerDimension(), lowerBucket, upperBucket);
try {
values.intersect(visitor);
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java
index bf6fb39d43c4b..ff3f93995dbae 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FiltersAggregator.java
@@ -328,7 +328,11 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
hasOtherBucket
);
}
- return new MultiFilterLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, usesCompetitiveIterator, hasOtherBucket);
+ if (usesCompetitiveIterator) {
+ return new MultiFilterCompetitiveLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, hasOtherBucket);
+ } else {
+ return new MultiFilterLeafCollector(sub, filterWrappers, numFilters, totalNumKeys, hasOtherBucket);
+ }
}
}
@@ -400,21 +404,20 @@ public DocIdSetIterator competitiveIterator() throws IOException {
}
}
- private class MultiFilterLeafCollector extends AbstractLeafCollector {
+ private final class MultiFilterLeafCollector extends AbstractLeafCollector {
// A DocIdSetIterator heap with one entry for each filter, ordered by doc ID
- final DisiPriorityQueue filterIterators;
+ DisiPriorityQueue filterIterators;
MultiFilterLeafCollector(
LeafBucketCollector sub,
List filterWrappers,
int numFilters,
int totalNumKeys,
- boolean usesCompetitiveIterator,
boolean hasOtherBucket
) {
- super(sub, numFilters, totalNumKeys, usesCompetitiveIterator, hasOtherBucket);
- filterIterators = filterWrappers.isEmpty() ? null : new DisiPriorityQueue(filterWrappers.size());
+ super(sub, numFilters, totalNumKeys, false, hasOtherBucket);
+ filterIterators = filterWrappers.isEmpty() ? null : DisiPriorityQueue.ofMaxSize(filterWrappers.size());
for (FilterMatchingDisiWrapper wrapper : filterWrappers) {
filterIterators.add(wrapper);
}
@@ -423,7 +426,7 @@ private class MultiFilterLeafCollector extends AbstractLeafCollector {
public void collect(int doc, long bucket) throws IOException {
boolean matched = false;
if (filterIterators != null) {
- // Advance filters if necessary. Filters will already be advanced if used as a competitive iterator.
+ // Advance filters if necessary.
DisiWrapper top = filterIterators.top();
while (top.doc < doc) {
top.doc = top.approximation.advance(doc);
@@ -448,16 +451,51 @@ public void collect(int doc, long bucket) throws IOException {
}
@Override
- public DocIdSetIterator competitiveIterator() throws IOException {
- if (usesCompetitiveIterator) {
- // A DocIdSetIterator view of the filterIterators heap
- assert filterIterators != null;
- return new DisjunctionDISIApproximation(filterIterators);
- }
+ public DocIdSetIterator competitiveIterator() {
return null;
}
}
+ /**
+  * Leaf collector for the case where the filters also act as a competitive iterator: a single
+  * {@link DisjunctionDISIApproximation} built over all filter wrappers is used both to find the
+  * filters positioned on a collected doc and as the value returned by {@link #competitiveIterator()}.
+  */
+ private final class MultiFilterCompetitiveLeafCollector extends AbstractLeafCollector {
+
+     private final DisjunctionDISIApproximation disjunctionDisi;
+
+     /**
+      * @param filterWrappers one wrapper per filter; must not be empty (asserted)
+      */
+     MultiFilterCompetitiveLeafCollector(
+         LeafBucketCollector sub,
+         List<FilterMatchingDisiWrapper> filterWrappers,
+         int numFilters,
+         int totalNumKeys,
+         boolean hasOtherBucket
+     ) {
+         super(sub, numFilters, totalNumKeys, true, hasOtherBucket);
+         assert filterWrappers.isEmpty() == false;
+         disjunctionDisi = DisjunctionDISIApproximation.of(filterWrappers, Long.MAX_VALUE);
+     }
+
+     @Override
+     public void collect(int doc, long bucket) throws IOException {
+         // The disjunction is also handed out as the competitive iterator, so it may already be
+         // positioned on (or beyond) doc. DocIdSetIterator#advance is undefined for
+         // target <= docID(), so only advance while the iterator is still behind.
+         int current = disjunctionDisi.docID();
+         if (current < doc) {
+             current = disjunctionDisi.advance(doc);
+         }
+
+         boolean matched = false;
+         if (current == doc) {
+             // Walk every wrapper positioned on doc and confirm the match per filter.
+             for (DisiWrapper w = disjunctionDisi.topList(); w != null; w = w.next) {
+                 FilterMatchingDisiWrapper topMatch = (FilterMatchingDisiWrapper) w;
+                 if (topMatch.checkDocForMatch(doc)) {
+                     collectBucket(sub, doc, bucketOrd(bucket, topMatch.filterOrd));
+                     matched = true;
+                 }
+             }
+         }
+
+         // No filter matched: collect under the "other" bucket, addressed with filter ordinal numFilters.
+         if (hasOtherBucket && false == matched) {
+             collectBucket(sub, doc, bucketOrd(bucket, numFilters));
+         }
+     }
+
+     @Override
+     public DocIdSetIterator competitiveIterator() {
+         return disjunctionDisi;
+     }
+ }
+
private static class FilterMatchingDisiWrapper extends DisiWrapper {
final int filterOrd;
diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java
index 285cb8a9564be..fdcfe5e3720f8 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java
@@ -56,7 +56,6 @@
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.suggest.document.Completion101PostingsFormat;
-import org.apache.lucene.search.suggest.document.CompletionPostingsFormat;
import org.apache.lucene.search.suggest.document.SuggestField;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
@@ -331,7 +330,7 @@ public void testCompletionField() throws Exception {
@Override
public PostingsFormat getPostingsFormatForField(String field) {
if (field.startsWith("suggest_")) {
- return new Completion101PostingsFormat(randomFrom(CompletionPostingsFormat.FSTLoadMode.values()));
+ return new Completion101PostingsFormat();
} else {
return super.postingsFormat();
}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java
index 0bebe16f468ce..b897a004a9781 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java
@@ -26,6 +26,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.simdvec.ESVectorUtil;
@@ -120,28 +121,48 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu
}
@Override
- public RandomVectorScorer scorer(int ord) throws IOException {
- byte[] vector = queryVectors.vectorValue(ord);
- int quantizedSum = queryVectors.sumQuantizedValues(ord);
- float distanceToCentroid = queryVectors.getCentroidDistance(ord);
- float lower = queryVectors.getLower(ord);
- float width = queryVectors.getWidth(ord);
- float normVmC = 0f;
- float vDotC = 0f;
- if (similarityFunction != EUCLIDEAN) {
- normVmC = queryVectors.getNormVmC(ord);
- vDotC = queryVectors.getVDotC(ord);
- }
- BinaryQueryVector binaryQueryVector = new BinaryQueryVector(
- vector,
- new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC)
- );
- return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ // Returns a scorer whose query side is selected later through setScoringOrdinal. Each scorer owns
+ // its query buffer (sized from the supplier's query vectors) and its own copies of both vector sources.
+ byte[] queryVector = new byte[queryVectors.quantizedDimension()];
+ return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetVectors) {
+ // Ordinal currently loaded into queryVector/factors; -1 until setScoringOrdinal is first called.
+ private int ord = -1;
+ // Correction factors for the loaded ordinal; null means no ordinal has been set yet.
+ private BinaryQuantizer.QueryFactors factors = null;
+ // These two fields shadow the supplier's fields of the same name with private copies.
+ private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors =
+ BinarizedRandomVectorScorerSupplier.this.queryVectors.copy();
+ private final BinarizedByteVectorValues targetVectors = BinarizedRandomVectorScorerSupplier.this.targetVectors.copy();
+
+ @Override
+ public void setScoringOrdinal(int i) throws IOException {
+ // Same ordinal as last time: the cached query bytes and factors are still valid.
+ if (i == ord) {
+ return;
+ }
+ ord = i;
+ // Copy the bytes instead of keeping the returned array.
+ // NOTE(review): presumably vectorValue reuses an internal buffer between calls - confirm.
+ System.arraycopy(queryVectors.vectorValue(i), 0, queryVector, 0, queryVector.length);
+ int quantizedSum = queryVectors.sumQuantizedValues(ord);
+ float distanceToCentroid = queryVectors.getCentroidDistance(ord);
+ float lower = queryVectors.getLower(ord);
+ float width = queryVectors.getWidth(ord);
+ // normVmC and vDotC are only read for non-Euclidean similarities; they stay 0 otherwise.
+ float normVmC = 0f;
+ float vDotC = 0f;
+ if (similarityFunction != EUCLIDEAN) {
+ normVmC = queryVectors.getNormVmC(ord);
+ vDotC = queryVectors.getVDotC(ord);
+ }
+ factors = new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC);
+ }
+
+ @Override
+ public float score(int i) throws IOException {
+ // Scoring needs a loaded query ordinal; fail loudly rather than score against an empty buffer.
+ if (factors == null) {
+ throw new IllegalStateException("setScoringOrdinal must be called before score");
+ }
+ return quantizedScore(this.targetVectors, i, queryVector, factors, similarityFunction);
+ }
+ };
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
- return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction);
+ return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction);
}
}
@@ -154,9 +175,6 @@ static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRand
private final BinarizedByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;
- private final float sqrtDimensions;
- private final float maxX1;
-
BinarizedRandomVectorScorer(
BinaryQueryVector queryVectors,
BinarizedByteVectorValues targetVectors,
@@ -166,91 +184,97 @@ static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRand
this.queryVector = queryVectors;
this.targetVectors = targetVectors;
this.similarityFunction = similarityFunction;
- // FIXME: precompute this once?
- this.sqrtDimensions = targetVectors.sqrtDimensions();
- this.maxX1 = targetVectors.maxX1();
}
@Override
public float score(int targetOrd) throws IOException {
- byte[] quantizedQuery = queryVector.vector();
- int quantizedSum = queryVector.factors().quantizedSum();
- float lower = queryVector.factors().lower();
- float width = queryVector.factors().width();
- float distanceToCentroid = queryVector.factors().distToC();
- if (similarityFunction == EUCLIDEAN) {
- return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width);
- }
+ return quantizedScore(targetVectors, targetOrd, queryVector.vector(), queryVector.factors(), similarityFunction);
+ }
+ }
- float vmC = queryVector.factors().normVmC();
- float vDotC = queryVector.factors().vDotC();
- float cDotC = targetVectors.getCentroidDP();
- byte[] binaryCode = targetVectors.vectorValue(targetOrd);
- float ooq = targetVectors.getOOQ(targetOrd);
- float normOC = targetVectors.getNormOC(targetOrd);
- float oDotC = targetVectors.getODotC(targetOrd);
+ /**
+  * Scores the quantized query against the target vector at {@code targetOrd}, dispatching on the
+  * similarity function: {@code EUCLIDEAN} goes to {@code euclideanQuantizedScore}, every other
+  * similarity goes to {@code dotProductQuantizedScore}.
+  */
+ private static float quantizedScore(
+     BinarizedByteVectorValues targetVectors,
+     int targetOrd,
+     byte[] quantizedQuery,
+     BinaryQuantizer.QueryFactors queryFactors,
+     VectorSimilarityFunction similarityFunction
+ ) throws IOException {
+     final boolean isEuclidean = similarityFunction == EUCLIDEAN;
+     return isEuclidean
+         ? euclideanQuantizedScore(targetVectors, targetOrd, queryFactors, quantizedQuery)
+         : dotProductQuantizedScore(targetVectors, targetOrd, quantizedQuery, queryFactors, similarityFunction);
+ }
- float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+ private static float dotProductQuantizedScore(
+ BinarizedByteVectorValues targetVectors,
+ int targetOrd,
+ byte[] quantizedQuery,
+ BinaryQuantizer.QueryFactors queryFactors,
+ VectorSimilarityFunction similarityFunction
+ ) throws IOException {
+ float vmC = queryFactors.normVmC();
+ float vDotC = queryFactors.vDotC();
+ float cDotC = targetVectors.getCentroidDP();
+ byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+ float ooq = targetVectors.getOOQ(targetOrd);
+ float normOC = targetVectors.getNormOC(targetOrd);
+ float oDotC = targetVectors.getODotC(targetOrd);
- // FIXME: pre-compute these only once for each target vector
- // ... pull this out or use a similar cache mechanism as do in score
- float xbSum = (float) BQVectorUtils.popcount(binaryCode);
- final float dist;
- // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away
- // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector
- if (normOC == 0 || ooq == 0) {
- dist = oDotC + vDotC - cDotC;
- } else {
- // If ||o-c|| != 0, we should assume that `ooq` is finite
- assert Float.isFinite(ooq);
- float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions
- * quantizedSum - sqrtDimensions * lower) / ooq;
- dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC;
- }
- assert Float.isFinite(dist);
+ float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
- float ooqSqr = (float) Math.pow(ooq, 2);
- float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr)));
- float score = Float.isFinite(errorBound) ? dist - errorBound : dist;
- if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
- return VectorUtil.scaleMaxInnerProductScore(score);
- }
- return Math.max((1f + score) / 2f, 0);
+ float xbSum = (float) BQVectorUtils.popcount(binaryCode);
+ final float dist;
+ // If ||o-c|| == 0 the doc vector is the centroid, so the rest of the equation can be thrown away
+ // and `oDotC + vDotC - cDotC` used directly
+ if (normOC == 0 || ooq == 0) {
+ dist = oDotC + vDotC - cDotC;
+ } else {
+ // If ||o-c|| != 0, we should assume that `ooq` is finite
+ assert Float.isFinite(ooq);
+ float estimatedDot = (2 * queryFactors.width() / targetVectors.sqrtDimensions() * qcDist + 2 * queryFactors.lower()
+ / targetVectors.sqrtDimensions() * xbSum - queryFactors.width() / targetVectors.sqrtDimensions() * queryFactors
+ .quantizedSum() - targetVectors.sqrtDimensions() * queryFactors.lower()) / ooq;
+ dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC;
+ }
+ assert Float.isFinite(dist);
+
+ float ooqSqr = (float) Math.pow(ooq, 2);
+ float errorBound = (float) (vmC * normOC * (targetVectors.maxX1() * Math.sqrt((1 - ooqSqr) / ooqSqr)));
+ float score = Float.isFinite(errorBound) ? dist - errorBound : dist;
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
}
+ return Math.max((1f + score) / 2f, 0);
+ }
- private float euclideanScore(
- int targetOrd,
- float sqrtDimensions,
- byte[] quantizedQuery,
- float distanceToCentroid,
- float lower,
- int quantizedSum,
- float width
- ) throws IOException {
- byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+ private static float euclideanQuantizedScore(
+ BinarizedByteVectorValues targetVectors,
+ int targetOrd,
+ BinaryQuantizer.QueryFactors factors,
+ byte[] quantizedQuery
+ ) throws IOException {
+ byte[] binaryCode = targetVectors.vectorValue(targetOrd);
- // FIXME: pre-compute these only once for each target vector
- // .. not sure how to enumerate the target ordinals but that's what we did in PoC
- float targetDistToC = targetVectors.getCentroidDistance(targetOrd);
- float x0 = targetVectors.getVectorMagnitude(targetOrd);
- float sqrX = targetDistToC * targetDistToC;
- double xX0 = targetDistToC / x0;
+ float targetDistToC = targetVectors.getCentroidDistance(targetOrd);
+ float x0 = targetVectors.getVectorMagnitude(targetOrd);
+ float sqrX = targetDistToC * targetDistToC;
+ double xX0 = targetDistToC / x0;
- // TODO maybe store?
- float xbSum = (float) BQVectorUtils.popcount(binaryCode);
- float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension()));
- float factorIP = (float) (-2.0 / sqrtDimensions * xX0);
+ float xbSum = (float) BQVectorUtils.popcount(binaryCode);
+ float factorPPC = (float) (-2.0 / targetVectors.sqrtDimensions() * xX0 * (xbSum * 2.0 - targetVectors.dimension()));
+ float factorIP = (float) (-2.0 / targetVectors.sqrtDimensions() * xX0);
- long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
- float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width;
- float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC);
- float error = 2.0f * maxX1 * projectionDist;
- float y = (float) Math.sqrt(distanceToCentroid);
- float errorBound = y * error;
- if (Float.isFinite(errorBound)) {
- score = score + errorBound;
- }
- return Math.max(1 / (1f + score), 0);
+ long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+ float score = sqrX + factors.distToC() + factorPPC * factors.lower() + (qcDist * 2 - factors.quantizedSum()) * factorIP * factors
+ .width();
+ float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC);
+ float error = 2.0f * targetVectors.maxX1() * projectionDist;
+ float y = (float) Math.sqrt(factors.distToC());
+ float errorBound = y * error;
+ if (Float.isFinite(errorBound)) {
+ score = score + errorBound;
}
+ return Math.max(1 / (1f + score), 0);
}
+
}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
index 61bd5323b5b43..ae7eea79dd29e 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
@@ -45,8 +45,8 @@
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
-import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
@@ -822,6 +822,10 @@ public int dimension() {
return dimension;
}
+ /** Returns the length of the array backing {@code byteBuffer}. */
+ public int quantizedDimension() {
+     var backing = byteBuffer.array();
+     return backing.length;
+ }
+
public OffHeapBinarizedQueryVectorValues copy() throws IOException {
return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size, vectorSimilarityFunction);
}
@@ -959,8 +963,8 @@ static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRa
}
@Override
- public RandomVectorScorer scorer(int ord) throws IOException {
- return supplier.scorer(ord);
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return supplier.scorer();
}
@Override
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java
index f3de4fa124c44..e5947df2b417b 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java
@@ -15,6 +15,8 @@
import org.apache.lucene.search.Weight;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
@@ -36,11 +38,12 @@ public QueryFeatureExtractor(List featureNames, List weights) {
}
this.featureNames = featureNames;
this.weights = weights;
- this.subScorers = new DisiPriorityQueue(weights.size());
+ this.subScorers = DisiPriorityQueue.ofMaxSize(weights.size());
}
@Override
public void setNextReader(LeafReaderContext segmentContext) throws IOException {
+ Collection disiWrappers = new ArrayList<>();
subScorers.clear();
for (int i = 0; i < weights.size(); i++) {
var weight = weights.get(i);
@@ -51,11 +54,13 @@ public void setNextReader(LeafReaderContext segmentContext) throws IOException {
if (scorerSupplier != null) {
var scorer = scorerSupplier.get(0L);
if (scorer != null) {
- subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
+ FeatureDisiWrapper featureDisiWrapper = new FeatureDisiWrapper(scorer, featureNames.get(i));
+ subScorers.add(featureDisiWrapper);
+ disiWrappers.add(featureDisiWrapper);
}
}
}
- approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
+ approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(disiWrappers, Long.MAX_VALUE) : null;
}
@Override
@@ -69,7 +74,7 @@ public void addFeatures(Map featureMap, int docId) throws IOExce
if (approximation.docID() != docId) {
return;
}
- var w = (FeatureDisiWrapper) subScorers.topList();
+ var w = (FeatureDisiWrapper) approximation.topList();
for (; w != null; w = (FeatureDisiWrapper) w.next) {
if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
featureMap.put(w.featureName, w.scorable.score());
diff --git a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java
index 77e275555deb4..5f9d5029aaf7c 100644
--- a/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java
+++ b/x-pack/plugin/wildcard/src/main/java/org/elasticsearch/xpack/wildcard/mapper/WildcardFieldMapper.java
@@ -407,6 +407,9 @@ public Query regexpQuery(
public static Query toApproximationQuery(RegExp r) throws IllegalArgumentException {
Query result = null;
switch (r.kind) {
+ case REGEXP_CHAR_CLASS:
+ result = createCharacterClassQuery(r);
+ break;
case REGEXP_UNION:
result = createUnionQuery(r);
break;
@@ -426,7 +429,6 @@ public static Query toApproximationQuery(RegExp r) throws IllegalArgumentExcepti
// Repeat is zero or more times so zero matches = match all
result = new MatchAllDocsQuery();
break;
-
case REGEXP_REPEAT_MIN:
case REGEXP_REPEAT_MINMAX:
if (r.min > 0) {
@@ -458,7 +460,7 @@ public static Query toApproximationQuery(RegExp r) throws IllegalArgumentExcepti
case REGEXP_INTERVAL:
case REGEXP_EMPTY:
case REGEXP_AUTOMATON:
- case REGEXP_PRE_CLASS:
+ // case REGEXP_PRE_CLASS:
result = new MatchAllDocsQuery();
break;
}
@@ -496,11 +498,35 @@ private static Query createConcatenationQuery(RegExp r) {
}
+ /**
+  * Builds the approximation query for a character-class node.
+  * <p>
+  * Only entries whose {@code from} and {@code to} values are equal (a single character) are handled:
+  * each becomes a lower-cased {@link TermQuery}, and the collected queries are combined by
+  * {@code formQuery}. The method gives up and returns a {@link MatchAllDocsQuery} as soon as an entry
+  * is a real range, or when the class has more entries than {@code MAX_CLAUSES_IN_APPROXIMATION_QUERY}.
+  */
+ private static Query createCharacterClassQuery(RegExp r) {
+     // TODO: consider expanding this to allow for character ranges as well (need additional tests and performance eval)
+     if (r.from.length > MAX_CLAUSES_IN_APPROXIMATION_QUERY) {
+         return new MatchAllDocsQuery();
+     }
+     List<Query> queries = new ArrayList<>(r.from.length);
+     for (int i = 0; i < r.from.length; i++) {
+         if (r.from[i] != r.to[i]) {
+             // immediately exit because we can't currently optimize a combination of range and classes
+             return new MatchAllDocsQuery();
+         }
+         // single character (from == to): add one term for its normalized form
+         String cs = Character.toString(r.from[i]);
+         String normalizedChar = toLowerCase(cs);
+         queries.add(new TermQuery(new Term("", normalizedChar)));
+     }
+     return formQuery(queries);
+ }
+
private static Query createUnionQuery(RegExp r) {
// Create an OR of clauses
- ArrayList queries = new ArrayList<>();
+ List queries = new ArrayList<>();
findLeaves(r.exp1, org.apache.lucene.util.automaton.RegExp.Kind.REGEXP_UNION, queries);
findLeaves(r.exp2, org.apache.lucene.util.automaton.RegExp.Kind.REGEXP_UNION, queries);
+ return formQuery(queries);
+ }
+
+ private static Query formQuery(List queries) {
BooleanQuery.Builder bOr = new BooleanQuery.Builder();
HashSet uniqueClauses = new HashSet<>();
for (Query query : queries) {