From a0129cf87c0ef065549cb2a3d6f7f967ac852345 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Sun, 25 May 2025 16:06:41 -0700 Subject: [PATCH 01/18] initial late interaction field --- .../lucene/document/LateInteractionField.java | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java diff --git a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java new file mode 100644 index 000000000000..56d697ac0b46 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.document; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.lucene.util.BytesRef; + +/** + * A field for storing multi-vector values for late interaction models. + * + *

The value is stored as a binary payload, and can be retrieved using {@link + * LateInteractionField#decode(BytesRef)}. Multi-vectors are expected to have the same dimension for + * each composing token vector. This is stored along with the token vectors in the first 4 bytes of + * the payload. + * + *

Note: This field does not ensure consistency in token vector dimensions for values across + * documents. + */ +public class LateInteractionField extends BinaryDocValuesField { + + /** + * Creates a new {@link LateInteractionField} from the provided multi-vector matrix. + * + * @param name field name + * @param value multi-vector value + */ + public LateInteractionField(String name, float[][] value) { + super(name, encode(value)); + } + + /** + * Encodes provided multi-vector matrix into a {@link BytesRef} that can be stored in the {@link + * LateInteractionField}. + * + *

Composing token vectors for the multi-vector are expected to have the same dimension, which + * is stored along with the token vectors in the first 4 bytes of the payload. Use {@link + * LateInteractionField#decode(BytesRef)} to retrieve the multi-vector. + * + * @param value Multi-Vector to encode + * @return BytesRef representation for provided multi-vector + */ + public static BytesRef encode(float[][] value) { + if (value.length == 0) { + throw new IllegalArgumentException("Provided value is empty"); + } + final int tokenVectorDimension = value[0].length; + final ByteBuffer buffer = + ByteBuffer.allocate(Integer.BYTES + value.length * tokenVectorDimension * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + // TODO: Should we store dimension in FieldType to ensure consistency across all documents? + buffer.putInt(tokenVectorDimension); + for (int i = 0; i < value.length; i++) { + if (value[i].length != tokenVectorDimension) { + throw new IllegalArgumentException( + "Composing token vectors should have the same dimension. " + + "Mismatching dimensions detected between token[0] and token[" + + i + + "], " + + value[0].length + + " != " + + value[i].length); + } + buffer.asFloatBuffer().put(value[i]); + } + return new BytesRef(buffer.array()); + } + + /** + * Decodes provided {@link BytesRef} into a multi-vector matrix. + * + *

The token vectors are expected to have the same dimension, which is stored along with the + * token vectors in the first 4 bytes of the payload. Meant to be used as a counterpart to {@link + * LateInteractionField#encode(float[][])} + * + * @param payload to decode into multi-vector value + * @return + */ + public static float[][] decode(BytesRef payload) { + final ByteBuffer buffer = ByteBuffer.wrap(payload.bytes, payload.offset, payload.length); + buffer.order(ByteOrder.LITTLE_ENDIAN); + final int tokenVectorDimension = buffer.getInt(); + int numVectors = (payload.length - 4) / tokenVectorDimension; + if (numVectors * tokenVectorDimension + 4 != payload.length) { + throw new IllegalArgumentException( + "Provided payload does not appear to have been encoded via LateInteractionField.encode. " + + "Payload length should be equal to 4 + numVectors * tokenVectorDimension, " + + "got " + + payload.length + + " != 4 + " + + numVectors + + " * " + + tokenVectorDimension); + } + var floatBuffer = buffer.asFloatBuffer(); + float[][] value = new float[numVectors][]; + for (int i = 0; i < numVectors; i++) { + value[i] = new float[tokenVectorDimension]; + floatBuffer.get(value[i]); + } + return value; + } +} From ee66d0dadc7918ad158caf196434239b6e53bf05 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Mon, 26 May 2025 17:48:07 -0700 Subject: [PATCH 02/18] impl for LateI values source, query and field --- .../lucene/document/LateInteractionField.java | 19 +- .../search/LateInteractionValuesSource.java | 171 ++++++++++++++++++ .../queries/function/FunctionScoreQuery.java | 26 +++ 3 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java diff --git a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java index 56d697ac0b46..224f9d8993ff 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java +++ b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java @@ -44,6 +44,21 @@ public LateInteractionField(String name, float[][] value) { super(name, encode(value)); } + /** + * Set multi-vector value for the field. + * + *

Value should not be null or empty. Composing token vectors for provided multi-vector value + * should have the same dimension. + */ + public void setValue(float[][] value) { + this.fieldsData = encode(value); + } + + /** Returns the multi-vector value stored in this field */ + public float[][] getValue() { + return decode((BytesRef) fieldsData); + } + /** * Encodes provided multi-vector matrix into a {@link BytesRef} that can be stored in the {@link * LateInteractionField}. @@ -56,8 +71,8 @@ public LateInteractionField(String name, float[][] value) { * @return BytesRef representation for provided multi-vector */ public static BytesRef encode(float[][] value) { - if (value.length == 0) { - throw new IllegalArgumentException("Provided value is empty"); + if (value == null || value.length == 0) { + throw new IllegalArgumentException("Value should not be null or empty"); } final int tokenVectorDimension = value[0].length; final ByteBuffer buffer = diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java new file mode 100644 index 000000000000..33893b5135da --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.document.LateInteractionField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; + +public class LateInteractionValuesSource extends DoubleValuesSource { + + private final String fieldName; + private final float[][] queryVector; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final ScoreFunction scoreFunction; + + public LateInteractionValuesSource( + String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + this(fieldName, queryVector, vectorSimilarityFunction, ScoreFunction.SUM_MAX_SIM); + } + + public LateInteractionValuesSource( + String fieldName, + float[][] queryVector, + VectorSimilarityFunction vectorSimilarityFunction, + ScoreFunction scoreFunction) { + if (fieldName == null) { + throw new IllegalArgumentException("fieldName must not be null"); + } + if (queryVector == null || queryVector.length == 0) { + throw new IllegalArgumentException("queryVector must not be null or empty"); + } + this.fieldName = fieldName; + this.queryVector = queryVector; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.scoreFunction = scoreFunction; + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + BinaryDocValues values = ctx.reader().getBinaryDocValues(fieldName); + if (values == null) { + return DoubleValues.EMPTY; + } + + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + return scoreFunction.compare( + queryVector, + LateInteractionField.decode(values.binaryValue()), + vectorSimilarityFunction); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return values.advanceExact(doc); + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public int hashCode() { + return Objects.hash( + fieldName, Arrays.deepHashCode(queryVector), vectorSimilarityFunction, scoreFunction); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + LateInteractionValuesSource other = (LateInteractionValuesSource) obj; + return Objects.equals(fieldName, other.fieldName) + && vectorSimilarityFunction == other.vectorSimilarityFunction + && scoreFunction == other.scoreFunction + && Arrays.deepEquals(queryVector, other.queryVector); + } + + @Override + public String toString() { + return "LateInteractionValuesSource(fieldName=" + + fieldName + + " similarityFunction=" + + vectorSimilarityFunction + + " scoreFunction=" + + scoreFunction.name() + + " queryVector=" + + Arrays.deepToString(queryVector) + + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + + /** Defines the function to compute similarity score between query and document multi-vectors */ + public enum ScoreFunction { + + /** Computes the sum of max similarity between query and document vectors */ + SUM_MAX_SIM { + @Override + public float compare( + float[][] queryVector, + float[][] docVector, + VectorSimilarityFunction vectorSimilarityFunction) { + if (docVector.length == 0) { + return Float.MIN_VALUE; + } + float result = 0f; + for (float[] q : queryVector) { + float maxSim = Float.MIN_VALUE; + for (float[] d : docVector) { + if (q.length != d.length) { + throw new IllegalArgumentException( + "Provided multi-vectors are incompatible. " + + "Their composing token vectors should have the same dimension, got " + + q.length + + " != " + + d.length); + } + maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(q, d)); + } + result += maxSim; + } + return result; + } + }; + + /** + * Computes similarity between two multi-vectors using provided {@link VectorSimilarityFunction} + * + *

Provided multi-vectors can have varying number of composing token vectors, but their token + * vectors should have the same dimension. + * + * @param outer a multi-vector + * @param inner another multi-vector + * @return similarity score between two multi-vectors + */ + public abstract float compare( + float[][] outer, float[][] inner, VectorSimilarityFunction vectorSimilarityFunction); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index d16749f73868..75280c3ff02d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilterScorer; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LateInteractionValuesSource; import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -71,6 +73,30 @@ public DoubleValuesSource getSource() { return source; } + /** + * Returns a FunctionScoreQuery that re-scores hits from the wrapped query using late-interaction + * scores between provided query and indexed document multi-vectors. + * + *

Document multi-vectors are indexed using {@link + * org.apache.lucene.document.LateInteractionField}. Documents are scored using {@link + * LateInteractionValuesSource.ScoreFunction#SUM_MAX_SIM} computation on provided vector + * similarity function. + * + * @param in the query to re-score + * @param fieldName field containing document multi-vectors for re-scoring + * @param queryVector query multi-vector + * @param vectorSimilarityFunction vector similarity function used for computing scores + */ + public static FunctionScoreQuery lateInteractionRerankQuery( + Query in, + String fieldName, + float[][] queryVector, + VectorSimilarityFunction vectorSimilarityFunction) { + LateInteractionValuesSource scoreSource = + new LateInteractionValuesSource(fieldName, queryVector, vectorSimilarityFunction); + return new FunctionScoreQuery(in, scoreSource); + } + /** * Returns a FunctionScoreQuery where the scores of a wrapped query are multiplied by the value of * a DoubleValuesSource. From 1132e67b1a0427f502e4090baaf2f44df1a84990 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Mon, 26 May 2025 23:26:26 -0700 Subject: [PATCH 03/18] tests for late interaction field --- .../lucene/document/LateInteractionField.java | 12 ++- .../document/TestLateInteractionField.java | 77 +++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java diff --git a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java index 224f9d8993ff..dd22ecd6d539 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java +++ b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java @@ -19,6 +19,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.FloatBuffer; + import org.apache.lucene.util.BytesRef; /** @@ -74,12 +76,16 @@ public static BytesRef encode(float[][] value) { if (value == null || value.length == 0) { throw new IllegalArgumentException("Value should not be null or empty"); } + if (value[0] == null || value[0].length== 0) { + throw new IllegalArgumentException("Composing token vectors should not be null or empty"); + } final int tokenVectorDimension = value[0].length; final ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES + value.length * tokenVectorDimension * Float.BYTES) .order(ByteOrder.LITTLE_ENDIAN); // TODO: Should we store dimension in FieldType to ensure consistency across all documents? buffer.putInt(tokenVectorDimension); + FloatBuffer floatBuffer = buffer.asFloatBuffer(); for (int i = 0; i < value.length; i++) { if (value[i].length != tokenVectorDimension) { throw new IllegalArgumentException( @@ -91,7 +97,7 @@ public static BytesRef encode(float[][] value) { + " != " + value[i].length); } - buffer.asFloatBuffer().put(value[i]); + floatBuffer.put(value[i]); } return new BytesRef(buffer.array()); } @@ -110,8 +116,8 @@ public static float[][] decode(BytesRef payload) { final ByteBuffer buffer = ByteBuffer.wrap(payload.bytes, payload.offset, payload.length); buffer.order(ByteOrder.LITTLE_ENDIAN); final int tokenVectorDimension = buffer.getInt(); - int numVectors = (payload.length - 4) / tokenVectorDimension; - if (numVectors * tokenVectorDimension + 4 != payload.length) { + int numVectors = (payload.length - Integer.BYTES) / (tokenVectorDimension * Float.BYTES); + if (numVectors * tokenVectorDimension * Float.BYTES + Integer.BYTES != payload.length) { throw new IllegalArgumentException( "Provided payload does not appear to have been encoded via LateInteractionField.encode. " + "Payload length should be equal to 4 + numVectors * tokenVectorDimension, " diff --git a/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java b/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java new file mode 100644 index 000000000000..758188ebc765 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java @@ -0,0 +1,77 @@ +package org.apache.lucene.document; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TestVectorUtil; + +public class TestLateInteractionField extends LuceneTestCase { + + public void testEncodeDecode() { + float[][] value = new float[random().nextInt(3, 12)][]; + final int dim = 128; + for (int i = 0; i < value.length; i++) { + value[i] = TestVectorUtil.randomVector(dim); + } + final LateInteractionField field = new LateInteractionField("test", value); + BytesRef encoded = LateInteractionField.encode(value); + float[][] decoded = LateInteractionField.decode(encoded); + assertEqualArrays(value, decoded); + assertEqualArrays(value, field.getValue()); + } + + public void testSetterGetter() { + final int dim = 128; + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + value[i] = TestVectorUtil.randomVector(dim); + } + final LateInteractionField field = new LateInteractionField("test", value); + + float[][] value2 = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value2.length; i++) { + value2[i] = TestVectorUtil.randomVector(dim); + } + assertEqualArrays(field.getValue(), value); + field.setValue(value2); + assertEqualArrays(field.getValue(), value2); + } + + public void testInputValidation() { + expectThrows(IllegalArgumentException.class, + () -> LateInteractionField.encode(null)); + expectThrows(IllegalArgumentException.class, + () -> new LateInteractionField("test", null)); + expectThrows(IllegalArgumentException.class, + () -> LateInteractionField.encode(new float[0][])); + expectThrows(IllegalArgumentException.class, + () -> LateInteractionField.encode(new float[3][])); + + float[][] emptyTokens = new float[1][]; + emptyTokens[0] = new float[0]; + expectThrows(IllegalArgumentException.class, + () -> LateInteractionField.encode(emptyTokens)); + + final int dim = 128; + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + if (random().nextBoolean()) { + value[i] = TestVectorUtil.randomVector(dim); + } else { + value[i] = TestVectorUtil.randomVector(dim + 1); + } + } + expectThrows(IllegalArgumentException.class, + () -> LateInteractionField.encode(value)); + } + + private void assertEqualArrays(float[][] a, float[][] b) { + assertEquals(a.length, b.length); + for (int i = 0; i < a.length; i++) { + assertEquals(a[i].length, b[i].length); + for (int j = 0; j < a[i].length; j++) { + assertEquals(a[i][j], b[i][j], 1e-5f); + } + } + } + +} From d8c3ecce54d61a3d1caabb17ba6dc0caf711015b Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Tue, 27 May 2025 16:35:50 -0700 Subject: [PATCH 04/18] LateI values source test and tidy --- .../lucene/document/LateInteractionField.java | 3 +- .../search/LateInteractionValuesSource.java | 28 ++- .../document/TestLateInteractionField.java | 36 ++-- .../TestLateInteractionValuesSource.java | 159 ++++++++++++++++++ 4 files changed, 209 insertions(+), 17 deletions(-) create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java diff --git a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java index dd22ecd6d539..e1446eb15f9d 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java +++ b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java @@ -20,7 +20,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; - import org.apache.lucene.util.BytesRef; /** @@ -76,7 +75,7 @@ public static BytesRef encode(float[][] value) { if (value == null || value.length == 0) { throw new IllegalArgumentException("Value should not be null or empty"); } - if (value[0] == null || value[0].length== 0) { + if (value[0] == null || value[0].length == 0) { throw new IllegalArgumentException("Composing token vectors should not be null or empty"); } final int tokenVectorDimension = value[0].length; diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java index 33893b5135da..173e3cf23e42 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java @@ -32,6 +32,10 @@ public class LateInteractionValuesSource extends DoubleValuesSource { private final VectorSimilarityFunction vectorSimilarityFunction; private final ScoreFunction scoreFunction; + public LateInteractionValuesSource(String fieldName, float[][] queryVector) { + this(fieldName, queryVector, VectorSimilarityFunction.COSINE, ScoreFunction.SUM_MAX_SIM); + } + public LateInteractionValuesSource( String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { this(fieldName, queryVector, vectorSimilarityFunction, ScoreFunction.SUM_MAX_SIM); @@ -45,8 +49,12 @@ public LateInteractionValuesSource( if (fieldName == null) { throw new IllegalArgumentException("fieldName must not be null"); } - if (queryVector == null || queryVector.length == 0) { - throw new IllegalArgumentException("queryVector must not be null or empty"); + validateQueryVector(queryVector); + if (vectorSimilarityFunction == null) { + throw new IllegalArgumentException("vectorSimilarityFunction must not be null"); + } + if (scoreFunction == null) { + throw new IllegalArgumentException("scoreFunction must not be null"); } this.fieldName = fieldName; this.queryVector = queryVector; @@ -54,6 +62,22 @@ public LateInteractionValuesSource( this.scoreFunction = scoreFunction; } + private void validateQueryVector(float[][] queryVector) { + if (queryVector == null || queryVector.length == 0) { + throw new IllegalArgumentException("queryVector must not be null or empty"); + } + if (queryVector[0] == null || queryVector[0].length == 0) { + throw new IllegalArgumentException( + "composing token vectors in provided query vector should not be null or empty"); + } + for (int i = 1; i < queryVector.length; i++) { + if (queryVector[i] == null || queryVector[i].length != queryVector[0].length) { + throw new IllegalArgumentException( + "all composing token vectors in provided query vector should have the same length"); + } + } + } + @Override public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { BinaryDocValues values = ctx.reader().getBinaryDocValues(fieldName); diff --git a/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java b/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java index 758188ebc765..5adf59daf899 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestLateInteractionField.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.lucene.document; import org.apache.lucene.tests.util.LuceneTestCase; @@ -37,19 +54,14 @@ public void testSetterGetter() { } public void testInputValidation() { - expectThrows(IllegalArgumentException.class, - () -> LateInteractionField.encode(null)); - expectThrows(IllegalArgumentException.class, - () -> new LateInteractionField("test", null)); - expectThrows(IllegalArgumentException.class, - () -> LateInteractionField.encode(new float[0][])); - expectThrows(IllegalArgumentException.class, - () -> LateInteractionField.encode(new float[3][])); + expectThrows(IllegalArgumentException.class, () -> LateInteractionField.encode(null)); + expectThrows(IllegalArgumentException.class, () -> new LateInteractionField("test", null)); + expectThrows(IllegalArgumentException.class, () -> LateInteractionField.encode(new float[0][])); + expectThrows(IllegalArgumentException.class, () -> LateInteractionField.encode(new float[3][])); float[][] emptyTokens = new float[1][]; emptyTokens[0] = new float[0]; - expectThrows(IllegalArgumentException.class, - () -> LateInteractionField.encode(emptyTokens)); + expectThrows(IllegalArgumentException.class, () -> LateInteractionField.encode(emptyTokens)); final int dim = 128; float[][] value = new float[random().nextInt(3, 12)][]; @@ -60,8 +72,7 @@ public void testInputValidation() { value[i] = TestVectorUtil.randomVector(dim + 1); } } - expectThrows(IllegalArgumentException.class, - () -> LateInteractionField.encode(value)); + expectThrows(IllegalArgumentException.class, () -> LateInteractionField.encode(value)); } private void assertEqualArrays(float[][] a, float[][] b) { @@ -73,5 +84,4 @@ private void assertEqualArrays(float[][] a, float[][] b) { } } } - } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java new file mode 100644 index 000000000000..42291fe5b9b9 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.LateInteractionField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.TestVectorUtil; + +public class TestLateInteractionValuesSource extends LuceneTestCase { + + private static final int dimension = 128; + private static final String LATE_I_FIELD = "lateIF"; + + public void testValidations() { + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource(null, new float[0][], null)); + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource("fieldName", null, null)); + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource("fieldName", new float[0][], null)); + + float[][] emptyTokens = new float[1][]; + emptyTokens[0] = new float[0]; + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource("fieldName", emptyTokens, null)); + + float[][] valueBad = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < valueBad.length; i++) { + if (random().nextBoolean()) { + valueBad[i] = TestVectorUtil.randomVector(dimension); + } else { + valueBad[i] = TestVectorUtil.randomVector(dimension + 1); + } + } + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource("fieldName", valueBad, null)); + + float[][] value = createMultiVector(); + expectThrows( + IllegalArgumentException.class, + () -> new LateInteractionValuesSource("fieldName", value, null)); + expectThrows( + IllegalArgumentException.class, + () -> + new LateInteractionValuesSource( + "fieldName", value, VectorSimilarityFunction.COSINE, null)); + } + + public void testValues() throws IOException { + List corpus = new ArrayList<>(); + final int numDocs = atLeast(1000); + final int numSegments = random().nextInt(2, 10); + final VectorSimilarityFunction vectorSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + LateInteractionValuesSource.ScoreFunction scoreFunction = + LateInteractionValuesSource.ScoreFunction.values()[ + random().nextInt(LateInteractionValuesSource.ScoreFunction.values().length)]; + + try (Directory dir = newDirectory()) { + int id = 0; + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip value for some docs to create sparse field + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + } else { + float[][] value = createMultiVector(); + corpus.add(value); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new LateInteractionField(LATE_I_FIELD, value)); + doc.add(new IntField("has_li_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + + float[][] queryVector = createMultiVector(); + LateInteractionValuesSource source = + new LateInteractionValuesSource( + LATE_I_FIELD, queryVector, vectorSimilarityFunction, scoreFunction); + try (IndexReader reader = DirectoryReader.open(dir)) { + for (LeafReaderContext ctx : reader.leaves()) { + DoubleValues values = source.getValues(ctx, null); + final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(LATE_I_FIELD); + if (fi == null) { + assertEquals(values, DoubleValues.EMPTY); + continue; + } + BinaryDocValues disi = ctx.reader().getBinaryDocValues(LATE_I_FIELD); + StoredFields storedFields = ctx.reader().storedFields(); + while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int doc = disi.docID(); + int idValue = Integer.parseInt(storedFields.document(doc).get("id")); + float[][] docVector = corpus.get(idValue); + float expected = + scoreFunction.compare(queryVector, docVector, vectorSimilarityFunction); + values.advanceExact(doc); + assertEquals(expected, values.doubleValue(), 0.0001); + } + } + } + } + } + + private float[][] createMultiVector() { + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + value[i] = TestVectorUtil.randomVector(dimension); + } + return value; + } +} From 9f8f0e5fce9d8f78c0ae9bcddcd15173fcb5501c Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Tue, 27 May 2025 23:03:54 -0700 Subject: [PATCH 05/18] remove some null tests --- .../search/LateInteractionValuesSource.java | 21 ++++++------------- .../TestLateInteractionValuesSource.java | 17 --------------- 2 files changed, 6 insertions(+), 32 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java index 173e3cf23e42..5c58f6a34008 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java @@ -46,23 +46,13 @@ public LateInteractionValuesSource( float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction, ScoreFunction scoreFunction) { - if (fieldName == null) { - throw new IllegalArgumentException("fieldName must not be null"); - } - validateQueryVector(queryVector); - if (vectorSimilarityFunction == null) { - throw new IllegalArgumentException("vectorSimilarityFunction must not be null"); - } - if (scoreFunction == null) { - throw new IllegalArgumentException("scoreFunction must not be null"); - } - this.fieldName = fieldName; - this.queryVector = queryVector; - this.vectorSimilarityFunction = vectorSimilarityFunction; - this.scoreFunction = scoreFunction; + this.fieldName = Objects.requireNonNull(fieldName); + this.queryVector = validateQueryVector(queryVector); + this.vectorSimilarityFunction = Objects.requireNonNull(vectorSimilarityFunction); + this.scoreFunction = Objects.requireNonNull(scoreFunction); } - private void validateQueryVector(float[][] queryVector) { + private float[][] validateQueryVector(float[][] queryVector) { if (queryVector == null || queryVector.length == 0) { throw new IllegalArgumentException("queryVector must not be null or empty"); } @@ -76,6 +66,7 @@ private void validateQueryVector(float[][] queryVector) { "all composing token vectors in provided query vector should have the same length"); } } + return queryVector; } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java index 42291fe5b9b9..ac287ce452b0 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java @@ -42,16 +42,12 @@ public class TestLateInteractionValuesSource extends LuceneTestCase { private static final String LATE_I_FIELD = "lateIF"; public void testValidations() { - expectThrows( - IllegalArgumentException.class, - () -> new LateInteractionValuesSource(null, new float[0][], null)); expectThrows( IllegalArgumentException.class, () -> new LateInteractionValuesSource("fieldName", null, null)); expectThrows( IllegalArgumentException.class, () -> new LateInteractionValuesSource("fieldName", new float[0][], null)); - float[][] emptyTokens = new float[1][]; emptyTokens[0] = new float[0]; expectThrows( @@ -66,19 +62,6 @@ public void testValidations() { valueBad[i] = TestVectorUtil.randomVector(dimension + 1); } } - expectThrows( - IllegalArgumentException.class, - () -> new LateInteractionValuesSource("fieldName", valueBad, null)); - - float[][] value = createMultiVector(); - expectThrows( - IllegalArgumentException.class, - () -> new LateInteractionValuesSource("fieldName", value, null)); - expectThrows( - IllegalArgumentException.class, - () -> - new LateInteractionValuesSource( - "fieldName", value, VectorSimilarityFunction.COSINE, null)); } public void testValues() throws IOException { From 047ad7761a7275e587ffc6b85061495a5bbd686b Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Tue, 27 May 2025 23:17:52 -0700 Subject: [PATCH 06/18] test sumMaxSim score fn --- .../search/TestLateInteractionValuesSource.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java index ac287ce452b0..a0ce85bb6ee1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java @@ -132,6 +132,23 @@ public void testValues() throws IOException { } } + public void testSumMaxSim() { + float[][] queryVector = createMultiVector(); + // docVector has all the token vectors in the query and some random token vectors + float[][] docVector = new float[queryVector.length + 4][]; + for (int i = 0; i < queryVector.length; i++) { + docVector[i] = new float[queryVector[i].length]; + System.arraycopy(queryVector[i], 0, docVector[i], 0, queryVector[i].length); + } + for (int i = queryVector.length; i < docVector.length; i++) { + docVector[i] = TestVectorUtil.randomVector(queryVector[0].length); + } + float score = + LateInteractionValuesSource.ScoreFunction.SUM_MAX_SIM.compare( + queryVector, docVector, VectorSimilarityFunction.COSINE); + assertEquals(queryVector.length, score, 1e-5); + } + private float[][] createMultiVector() { float[][] value = new float[random().nextInt(3, 12)][]; for (int i = 0; i < value.length; i++) { From 0b07aba9c566e2e33bfddbcd2434e60b54bd80b2 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:28:59 -0700 Subject: [PATCH 07/18] late I query test --- .../function/TestFunctionScoreQuery.java | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java index 85bcc05cbf78..39c22d3b4475 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java @@ -18,9 +18,19 @@ package org.apache.lucene.queries.function; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.LateInteractionField; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.TextField; import org.apache.lucene.expressions.Expression; @@ -31,16 +41,21 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.LateInteractionValuesSource; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; @@ -377,4 +392,89 @@ public void testQueryMatchesCount() throws Exception { } assertEquals(searchCount, weightCount); } + + public void testLateInteractionQuery() throws Exception { + final String LATE_I_FIELD = "li_vector"; + final String KNN_FIELD = "knn_vector"; + List corpus = new ArrayList<>(); + final int numDocs = atLeast(1000); + final int numSegments = random().nextInt(2, 10); + final int dim = 128; + final VectorSimilarityFunction vectorSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + LateInteractionValuesSource.ScoreFunction scoreFunction = + LateInteractionValuesSource.ScoreFunction.values()[ + random().nextInt(LateInteractionValuesSource.ScoreFunction.values().length)]; + + try (Directory dir = newDirectory()) { + int id = 0; + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip value for some docs to create sparse field + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + } else { + float[][] value = createMultiVector(dim); + corpus.add(value); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new LateInteractionField(LATE_I_FIELD, value)); + doc.add(new KnnFloatVectorField(KNN_FIELD, randomVector(dim))); + doc.add(new IntField("has_li_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + + float[][] lateIQueryVector = createMultiVector(dim); + float[] knnQueryVector = randomVector(dim); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50); + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher s = new IndexSearcher(reader); + TopDocs knnHits = s.search(knnQuery, 50); + Set knnHitDocs = Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + FunctionScoreQuery lateIQuery = + FunctionScoreQuery.lateInteractionRerankQuery(knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); + TopDocs lateIHits = s.search(lateIQuery, 10); + StoredFields storedFields = reader.storedFields(); + for (ScoreDoc hit : lateIHits.scoreDocs) { + assertTrue(knnHitDocs.contains(hit.doc)); + int idValue = Integer.parseInt(storedFields.document(hit.doc).get("id")); + float[][] docVector = corpus.get(idValue); + float expected = + scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction); + assertEquals(expected, hit.score, 1e-5); + } + } + } + } + + private float[] randomVector(int dim) { + float[] v = new float[dim]; + Random random = random(); + for (int i = 0; i < dim; i++) { + v[i] = random.nextFloat(); + } + return v; + } + + private float[][] createMultiVector(int dimension) { + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + value[i] = randomVector(dimension); + } + return value; + } } From f3f64c284576900961eb52313ac02d53d80dac74 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:30:26 -0700 Subject: [PATCH 08/18] missing docstring --- .../org/apache/lucene/document/LateInteractionField.java | 2 +- .../lucene/queries/function/TestFunctionScoreQuery.java | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java index e1446eb15f9d..d6d9d75a88af 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java +++ b/lucene/core/src/java/org/apache/lucene/document/LateInteractionField.java @@ -109,7 +109,7 @@ public static BytesRef encode(float[][] value) { * LateInteractionField#encode(float[][])} * * @param payload to decode into multi-vector value - * @return + * @return decoded multi-vector value */ public static float[][] decode(BytesRef payload) { final ByteBuffer buffer = ByteBuffer.wrap(payload.bytes, payload.offset, payload.length); diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java index 39c22d3b4475..03591f7c363e 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java @@ -25,7 +25,6 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; - import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntField; @@ -444,9 +443,11 @@ public void testLateInteractionQuery() throws Exception { try (IndexReader reader = DirectoryReader.open(dir)) { IndexSearcher s = new IndexSearcher(reader); TopDocs knnHits = s.search(knnQuery, 50); - Set knnHitDocs = Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + Set knnHitDocs = + Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); FunctionScoreQuery lateIQuery = - FunctionScoreQuery.lateInteractionRerankQuery(knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); + FunctionScoreQuery.lateInteractionRerankQuery( + knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); TopDocs lateIHits = s.search(lateIQuery, 10); StoredFields storedFields = reader.storedFields(); for (ScoreDoc hit : lateIHits.scoreDocs) { From b59f853e91a25f65403552009454f1b13697b1b2 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:33:51 -0700 Subject: [PATCH 09/18] rename files --- ... => LateInteractionFloatValuesSource.java} | 12 +++++------ ...TestLateInteractionFloatValuesSource.java} | 20 +++++++++---------- .../queries/function/FunctionScoreQuery.java | 10 +++++----- .../function/TestFunctionScoreQuery.java | 10 +++++----- 4 files changed, 26 insertions(+), 26 deletions(-) rename lucene/core/src/java/org/apache/lucene/search/{LateInteractionValuesSource.java => LateInteractionFloatValuesSource.java} (94%) rename lucene/core/src/test/org/apache/lucene/search/{TestLateInteractionValuesSource.java => TestLateInteractionFloatValuesSource.java} (89%) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java similarity index 94% rename from lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java rename to lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java index 5c58f6a34008..14e499f62c9e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java @@ -25,23 +25,23 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; -public class LateInteractionValuesSource extends DoubleValuesSource { +public class LateInteractionFloatValuesSource extends DoubleValuesSource { private final String fieldName; private final float[][] queryVector; private final VectorSimilarityFunction vectorSimilarityFunction; private final ScoreFunction scoreFunction; - public LateInteractionValuesSource(String fieldName, float[][] queryVector) { + public LateInteractionFloatValuesSource(String fieldName, float[][] queryVector) { this(fieldName, queryVector, VectorSimilarityFunction.COSINE, ScoreFunction.SUM_MAX_SIM); } - public LateInteractionValuesSource( + public LateInteractionFloatValuesSource( String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { this(fieldName, queryVector, vectorSimilarityFunction, ScoreFunction.SUM_MAX_SIM); } - public LateInteractionValuesSource( + public LateInteractionFloatValuesSource( String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction, @@ -112,7 +112,7 @@ public int hashCode() { public boolean equals(Object obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; - LateInteractionValuesSource other = (LateInteractionValuesSource) obj; + LateInteractionFloatValuesSource other = (LateInteractionFloatValuesSource) obj; return Objects.equals(fieldName, other.fieldName) && vectorSimilarityFunction == other.vectorSimilarityFunction && scoreFunction == other.scoreFunction @@ -121,7 +121,7 @@ public boolean equals(Object obj) { @Override public String toString() { - return "LateInteractionValuesSource(fieldName=" + return "LateInteractionFloatValuesSource(fieldName=" + fieldName + " similarityFunction=" + vectorSimilarityFunction diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionFloatValuesSource.java similarity index 89% rename from lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java rename to lucene/core/src/test/org/apache/lucene/search/TestLateInteractionFloatValuesSource.java index a0ce85bb6ee1..6d72830a0558 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionValuesSource.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionFloatValuesSource.java @@ -36,7 +36,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.TestVectorUtil; -public class TestLateInteractionValuesSource extends LuceneTestCase { +public class TestLateInteractionFloatValuesSource extends LuceneTestCase { private static final int dimension = 128; private static final String LATE_I_FIELD = "lateIF"; @@ -44,15 +44,15 @@ public class TestLateInteractionValuesSource extends LuceneTestCase { public void testValidations() { expectThrows( IllegalArgumentException.class, - () -> new LateInteractionValuesSource("fieldName", null, null)); + () -> new LateInteractionFloatValuesSource("fieldName", null, null)); expectThrows( IllegalArgumentException.class, - () -> new LateInteractionValuesSource("fieldName", new float[0][], null)); + () -> new LateInteractionFloatValuesSource("fieldName", new float[0][], null)); float[][] emptyTokens = new float[1][]; emptyTokens[0] = new float[0]; expectThrows( IllegalArgumentException.class, - () -> new LateInteractionValuesSource("fieldName", emptyTokens, null)); + () -> new LateInteractionFloatValuesSource("fieldName", emptyTokens, null)); float[][] valueBad = new float[random().nextInt(3, 12)][]; for (int i = 0; i < valueBad.length; i++) { @@ -71,9 +71,9 @@ public void testValues() throws IOException { final VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length)]; - LateInteractionValuesSource.ScoreFunction scoreFunction = - LateInteractionValuesSource.ScoreFunction.values()[ - random().nextInt(LateInteractionValuesSource.ScoreFunction.values().length)]; + LateInteractionFloatValuesSource.ScoreFunction scoreFunction = + LateInteractionFloatValuesSource.ScoreFunction.values()[ + random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)]; try (Directory dir = newDirectory()) { int id = 0; @@ -105,8 +105,8 @@ public void testValues() throws IOException { } float[][] queryVector = createMultiVector(); - LateInteractionValuesSource source = - new LateInteractionValuesSource( + LateInteractionFloatValuesSource source = + new LateInteractionFloatValuesSource( LATE_I_FIELD, queryVector, vectorSimilarityFunction, scoreFunction); try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { @@ -144,7 +144,7 @@ public void testSumMaxSim() { docVector[i] = TestVectorUtil.randomVector(queryVector[0].length); } float score = - LateInteractionValuesSource.ScoreFunction.SUM_MAX_SIM.compare( + LateInteractionFloatValuesSource.ScoreFunction.SUM_MAX_SIM.compare( queryVector, docVector, VectorSimilarityFunction.COSINE); assertEquals(queryVector.length, score, 1e-5); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index 75280c3ff02d..182e60a2516f 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -27,7 +27,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilterScorer; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.LateInteractionValuesSource; +import org.apache.lucene.search.LateInteractionFloatValuesSource; import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -79,7 +79,7 @@ public DoubleValuesSource getSource() { * *

Document multi-vectors are indexed using {@link * org.apache.lucene.document.LateInteractionField}. Documents are scored using {@link - * LateInteractionValuesSource.ScoreFunction#SUM_MAX_SIM} computation on provided vector + * LateInteractionFloatValuesSource.ScoreFunction#SUM_MAX_SIM} computation on provided vector * similarity function. * * @param in the query to re-score @@ -87,13 +87,13 @@ public DoubleValuesSource getSource() { * @param queryVector query multi-vector * @param vectorSimilarityFunction vector similarity function used for computing scores */ - public static FunctionScoreQuery lateInteractionRerankQuery( + public static FunctionScoreQuery lateInteractionFloatRerankQuery( Query in, String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { - LateInteractionValuesSource scoreSource = - new LateInteractionValuesSource(fieldName, queryVector, vectorSimilarityFunction); + LateInteractionFloatValuesSource scoreSource = + new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); return new FunctionScoreQuery(in, scoreSource); } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java index 03591f7c363e..ce5034a50ed0 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java @@ -50,7 +50,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.LateInteractionValuesSource; +import org.apache.lucene.search.LateInteractionFloatValuesSource; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; @@ -402,9 +402,9 @@ public void testLateInteractionQuery() throws Exception { final VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.values()[ random().nextInt(VectorSimilarityFunction.values().length)]; - LateInteractionValuesSource.ScoreFunction scoreFunction = - LateInteractionValuesSource.ScoreFunction.values()[ - random().nextInt(LateInteractionValuesSource.ScoreFunction.values().length)]; + LateInteractionFloatValuesSource.ScoreFunction scoreFunction = + LateInteractionFloatValuesSource.ScoreFunction.values()[ + random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)]; try (Directory dir = newDirectory()) { int id = 0; @@ -446,7 +446,7 @@ public void testLateInteractionQuery() throws Exception { Set knnHitDocs = Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); FunctionScoreQuery lateIQuery = - FunctionScoreQuery.lateInteractionRerankQuery( + FunctionScoreQuery.lateInteractionFloatRerankQuery( knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); TopDocs lateIHits = s.search(lateIQuery, 10); StoredFields storedFields = reader.storedFields(); From 46c209a63a26cdf18b478a7d6671ee4b08f17d5f Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:36:21 -0700 Subject: [PATCH 10/18] docstring fix --- .../apache/lucene/queries/function/FunctionScoreQuery.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index 182e60a2516f..e8af8b5d0b37 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -78,9 +78,7 @@ public DoubleValuesSource getSource() { * scores between provided query and indexed document multi-vectors. * *

Document multi-vectors are indexed using {@link - * org.apache.lucene.document.LateInteractionField}. Documents are scored using {@link - * LateInteractionFloatValuesSource.ScoreFunction#SUM_MAX_SIM} computation on provided vector - * similarity function. + * org.apache.lucene.document.LateInteractionField}. * * @param in the query to re-score * @param fieldName field containing document multi-vectors for re-scoring From 1d67b85d9f65fd2b0d66e692361c68bc8f3831b8 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:43:13 -0700 Subject: [PATCH 11/18] add lateI value source docString --- .../lucene/search/LateInteractionFloatValuesSource.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java index 14e499f62c9e..21d16a2eeb7a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java @@ -25,6 +25,15 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; + +/** + * A {@link DoubleValuesSource} that scores documents using similarity between a multi-vector query, + * and indexed document multi-vectors. + * + *

This is useful re-ranking query results using late interaction models, where documents and queries + * are represented as multi-vectors of composing token vectors. Document vectors are indexed + * using {@link org.apache.lucene.document.LateInteractionField}. + */ public class LateInteractionFloatValuesSource extends DoubleValuesSource { private final String fieldName; From 22a57732bcffdec2af7ab0ed9d727d804749ec0b Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 16:45:47 -0700 Subject: [PATCH 12/18] tidy --- .../lucene/search/LateInteractionFloatValuesSource.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java index 21d16a2eeb7a..84392339cfd2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java @@ -25,13 +25,12 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; - /** * A {@link DoubleValuesSource} that scores documents using similarity between a multi-vector query, * and indexed document multi-vectors. * - *

This is useful re-ranking query results using late interaction models, where documents and queries - * are represented as multi-vectors of composing token vectors. Document vectors are indexed + *

This is useful re-ranking query results using late interaction models, where documents and + * queries are represented as multi-vectors of composing token vectors. Document vectors are indexed * using {@link org.apache.lucene.document.LateInteractionField}. */ public class LateInteractionFloatValuesSource extends DoubleValuesSource { From 6163060df6968b9301b51ca000cae7bcd9692795 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Wed, 28 May 2025 17:14:32 -0700 Subject: [PATCH 13/18] changes entry --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f4115ba9874b..88c1bebad65e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -121,6 +121,8 @@ New Features * GITHUB#14009: Add a new Query that can rescore other Query based on a generic DoubleValueSource and trim the results down to top N (Anh Dung Bui) +* GITHUB#14729: Support for Re-Ranking Queries using Late Interaction Model Multi-Vectors. (Vigya Sharma, Jim Ferenczi) + Improvements --------------------- * GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi) From 94ee3319c015733071aa1680123f59ec299f704b Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Sun, 22 Jun 2025 22:17:07 -0700 Subject: [PATCH 14/18] separate LI Rescorer --- .../search/LateInteractionRescorer.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java new file mode 100644 index 000000000000..efa949ffa450 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java @@ -0,0 +1,34 @@ +package org.apache.lucene.search; + +import org.apache.lucene.index.VectorSimilarityFunction; + +public class LateInteractionRescorer extends DoubleValuesSourceRescorer { + + public LateInteractionRescorer(LateInteractionFloatValuesSource valuesSource) { + super(valuesSource); + } + + public LateInteractionRescorer(String fieldName, float[][] queryVector) { + this(fieldName, queryVector, VectorSimilarityFunction.COSINE); + } + + public LateInteractionRescorer(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + super(valuesSource); + } + + @Override + protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { + return valuePresent ? (float) sourceValue : 0f; + } + + public static LateInteractionRescorer withFallback(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + return new LateInteractionRescorer(valuesSource) { + @Override + protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { + return valuePresent ? (float) sourceValue : firstPassScore; + } + }; + } +} From 9b7f9bb98ecbccc6e41c8744eec58eaf54ea92f0 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Tue, 1 Jul 2025 00:17:40 -0700 Subject: [PATCH 15/18] add late interaction rescorer and rescoreTopNQuery --- .../LateInteractionFloatValuesSource.java | 2 + .../search/LateInteractionRescorer.java | 47 ++++++++++++++++--- .../lucene/search/RescoreTopNQuery.java | 17 +++++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java index 84392339cfd2..f45cd4930454 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java @@ -32,6 +32,8 @@ *

This is useful re-ranking query results using late interaction models, where documents and * queries are represented as multi-vectors of composing token vectors. Document vectors are indexed * using {@link org.apache.lucene.document.LateInteractionField}. + * + * @lucene.experimental */ public class LateInteractionFloatValuesSource extends DoubleValuesSource { diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java index efa949ffa450..6e5ec33164a2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java @@ -2,27 +2,62 @@ import org.apache.lucene.index.VectorSimilarityFunction; +/** + * Rescores top N results from a first pass query using a {@link LateInteractionFloatValuesSource} + * + *

Typically, you run a low-cost first pass query to collect results from across the index, then + * use this rescorer to rerank top N hits using multi-vectors, usually from a late interaction model. + * Multi-vectors should be indexed in the {@link org.apache.lucene.document.LateInteractionField} + * provided to rescorer. + * + * @lucene.experimental + */ public class LateInteractionRescorer extends DoubleValuesSourceRescorer { public LateInteractionRescorer(LateInteractionFloatValuesSource valuesSource) { super(valuesSource); } - public LateInteractionRescorer(String fieldName, float[][] queryVector) { - this(fieldName, queryVector, VectorSimilarityFunction.COSINE); + /** + * Creates a LateInteractionRescorer for provided query vector. + */ + public static LateInteractionRescorer create(String fieldName, float[][] queryVector) { + return create(fieldName, queryVector, VectorSimilarityFunction.COSINE); } - public LateInteractionRescorer(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + /** + * Creates a LateInteractionRescorer for provided query vector. + * + *

Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and + * the multi-vector indexed in {@code fieldName}. If document does not have a value indexed in {@code fieldName}, + * a 0f score is assigned. + * + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking. + * @param queryVector query multi-vector to use for similarity comparison + * @param vectorSimilarityFunction function used for vector similarity comparisons + */ + public static LateInteractionRescorer create(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); - super(valuesSource); - } + return new LateInteractionRescorer(valuesSource); + } @Override protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { return valuePresent ? (float) sourceValue : 0f; } - public static LateInteractionRescorer withFallback(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + /** + * Creates a LateInteractionRescorer for provided query vector. + * + *

Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and + * the multi-vector indexed in {@code fieldName}. Falls back to score from the first pass query if a document + * does not have a value indexed in {@code fieldName}. + * + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking. + * @param queryVector query multi-vector to use for similarity comparison + * @param vectorSimilarityFunction function used for vector similarity comparisons. + */ + public static LateInteractionRescorer withFallbackToFirstPassScore(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); return new LateInteractionRescorer(valuesSource) { @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java index 3b9a8fa045c6..1a6902d2174a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Objects; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.VectorSimilarityFunction; /** * A Query that re-scores another Query with a {@link DoubleValuesSource} function and cut-off the @@ -148,4 +149,20 @@ public static Query createFullPrecisionRescorerQuery( new FullPrecisionFloatVectorSimilarityValuesSource(targetVector, field); return new RescoreTopNQuery(in, valuaSource, n); } + + /** + * Creates a {@code RescoreTopNQuery} that computes top N results using multi-vector similarity + * comparisons against a late interaction field. + * + * @param in the inner Query to rescore + * @param n number of results to keep + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} for recomputing top N hits + * @param queryVector query multi-vector to use for similarity comparisons + * @param vectorSimilarityFunction function to use for vector similarity comparisons. + */ + public static Query createLateInteractionQuery( + Query in, int n, String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + return new RescoreTopNQuery(in, valuesSource, n); + } } From 9108ad1a168755159475a772b23124800b08c8c6 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Thu, 3 Jul 2025 11:51:32 -0700 Subject: [PATCH 16/18] late I recore query with test and tidy --- .../search/LateInteractionRescorer.java | 40 +++---- .../lucene/search/RescoreTopNQuery.java | 17 ++- .../lucene/search/TestRescoreTopNQuery.java | 88 +++++++++++++++ .../queries/function/FunctionScoreQuery.java | 24 ----- .../function/TestFunctionScoreQuery.java | 101 ------------------ 5 files changed, 124 insertions(+), 146 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java index 6e5ec33164a2..99a6ca02935f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java @@ -6,9 +6,9 @@ * Rescores top N results from a first pass query using a {@link LateInteractionFloatValuesSource} * *

Typically, you run a low-cost first pass query to collect results from across the index, then - * use this rescorer to rerank top N hits using multi-vectors, usually from a late interaction model. - * Multi-vectors should be indexed in the {@link org.apache.lucene.document.LateInteractionField} - * provided to rescorer. + * use this rescorer to rerank top N hits using multi-vectors, usually from a late interaction + * model. Multi-vectors should be indexed in the {@link + * org.apache.lucene.document.LateInteractionField} provided to rescorer. * * @lucene.experimental */ @@ -18,9 +18,7 @@ public LateInteractionRescorer(LateInteractionFloatValuesSource valuesSource) { super(valuesSource); } - /** - * Creates a LateInteractionRescorer for provided query vector. - */ + /** Creates a LateInteractionRescorer for provided query vector. */ public static LateInteractionRescorer create(String fieldName, float[][] queryVector) { return create(fieldName, queryVector, VectorSimilarityFunction.COSINE); } @@ -28,16 +26,19 @@ public static LateInteractionRescorer create(String fieldName, float[][] queryVe /** * Creates a LateInteractionRescorer for provided query vector. * - *

Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and - * the multi-vector indexed in {@code fieldName}. If document does not have a value indexed in {@code fieldName}, - * a 0f score is assigned. + *

Top N results from a first pass query are rescored based on the similarity between {@code + * queryVector} and the multi-vector indexed in {@code fieldName}. If document does not have a + * value indexed in {@code fieldName}, a 0f score is assigned. * - * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking. + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for + * reranking. * @param queryVector query multi-vector to use for similarity comparison * @param vectorSimilarityFunction function used for vector similarity comparisons */ - public static LateInteractionRescorer create(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { - final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + public static LateInteractionRescorer create( + String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = + new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); return new LateInteractionRescorer(valuesSource); } @@ -49,16 +50,19 @@ protected float combine(float firstPassScore, boolean valuePresent, double sourc /** * Creates a LateInteractionRescorer for provided query vector. * - *

Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and - * the multi-vector indexed in {@code fieldName}. Falls back to score from the first pass query if a document - * does not have a value indexed in {@code fieldName}. + *

Top N results from a first pass query are rescored based on the similarity between {@code + * queryVector} and the multi-vector indexed in {@code fieldName}. Falls back to score from the + * first pass query if a document does not have a value indexed in {@code fieldName}. * - * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking. + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for + * reranking. * @param queryVector query multi-vector to use for similarity comparison * @param vectorSimilarityFunction function used for vector similarity comparisons. */ - public static LateInteractionRescorer withFallbackToFirstPassScore(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { - final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + public static LateInteractionRescorer withFallbackToFirstPassScore( + String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = + new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); return new LateInteractionRescorer(valuesSource) { @Override protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { diff --git a/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java index 1a6902d2174a..523d194c0e87 100644 --- a/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java @@ -154,15 +154,26 @@ public static Query createFullPrecisionRescorerQuery( * Creates a {@code RescoreTopNQuery} that computes top N results using multi-vector similarity * comparisons against a late interaction field. * + *

Note: This query computes late interaction field similarity for the entire match-set of + * wrapped query, and returns a new query with only top-N hits in the match-set. This is typically + * useful in combining a query's results with other queries for hybrid search. To simply rerank + * the top N hits without scoring entire match-set, see {@link LateInteractionRescorer}. + * * @param in the inner Query to rescore * @param n number of results to keep - * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} for recomputing top N hits + * @param fieldName the {@link org.apache.lucene.document.LateInteractionField} for recomputing + * top N hits * @param queryVector query multi-vector to use for similarity comparisons * @param vectorSimilarityFunction function to use for vector similarity comparisons. */ public static Query createLateInteractionQuery( - Query in, int n, String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) { - final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); + Query in, + int n, + String fieldName, + float[][] queryVector, + VectorSimilarityFunction vectorSimilarityFunction) { + final LateInteractionFloatValuesSource valuesSource = + new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); return new RescoreTopNQuery(in, valuesSource, n); } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java index e1be2c81822a..df21442e2b5f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java @@ -17,18 +17,25 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntField; import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.LateInteractionField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ByteBuffersDirectory; @@ -156,6 +163,87 @@ public void testMissingDoubleValues() throws IOException { } } + public void testLateInteractionQuery() throws Exception { + final String LATE_I_FIELD = "li_vector"; + final String KNN_FIELD = "knn_vector"; + List corpus = new ArrayList<>(); + final int numDocs = atLeast(1000); + final int numSegments = random().nextInt(2, 10); + final int dim = 128; + final VectorSimilarityFunction vectorSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + LateInteractionFloatValuesSource.ScoreFunction scoreFunction = + LateInteractionFloatValuesSource.ScoreFunction.values()[ + random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)]; + + try (Directory dir = newDirectory()) { + int id = 0; + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip value for some docs to create sparse field + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + } else { + float[][] value = createMultiVector(dim); + corpus.add(value); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new LateInteractionField(LATE_I_FIELD, value)); + doc.add(new KnnFloatVectorField(KNN_FIELD, randomFloatVector(dim, random()))); + doc.add(new IntField("has_li_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + + float[][] lateIQueryVector = createMultiVector(dim); + float[] knnQueryVector = randomFloatVector(dim, random()); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50); + + try (IndexReader reader = DirectoryReader.open(dir)) { + final int topN = 10; + IndexSearcher s = new IndexSearcher(reader); + TopDocs knnHits = s.search(knnQuery, 5 * topN); + Set knnHitDocs = + Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + Query lateIQuery = + RescoreTopNQuery.createLateInteractionQuery( + knnQuery, topN, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); + TopDocs lateIHits = s.search(lateIQuery, 3 * topN); + // total match-set for RescoreTopNQuery is topN + assertEquals(topN, lateIHits.scoreDocs.length); + StoredFields storedFields = reader.storedFields(); + for (ScoreDoc hit : lateIHits.scoreDocs) { + assertTrue(knnHitDocs.contains(hit.doc)); + int idValue = Integer.parseInt(storedFields.document(hit.doc).get("id")); + float[][] docVector = corpus.get(idValue); + float expected = + scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction); + assertEquals(expected, hit.score, 1e-5); + } + } + } + } + + private float[][] createMultiVector(int dimension) { + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + value[i] = randomFloatVector(dimension, random()); + } + return value; + } + private float[] randomFloatVector(int dimension, Random random) { float[] vector = new float[dimension]; for (int i = 0; i < dimension; i++) { diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index e8af8b5d0b37..d16749f73868 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -20,14 +20,12 @@ import java.io.IOException; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilterScorer; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.LateInteractionFloatValuesSource; import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -73,28 +71,6 @@ public DoubleValuesSource getSource() { return source; } - /** - * Returns a FunctionScoreQuery that re-scores hits from the wrapped query using late-interaction - * scores between provided query and indexed document multi-vectors. - * - *

Document multi-vectors are indexed using {@link - * org.apache.lucene.document.LateInteractionField}. - * - * @param in the query to re-score - * @param fieldName field containing document multi-vectors for re-scoring - * @param queryVector query multi-vector - * @param vectorSimilarityFunction vector similarity function used for computing scores - */ - public static FunctionScoreQuery lateInteractionFloatRerankQuery( - Query in, - String fieldName, - float[][] queryVector, - VectorSimilarityFunction vectorSimilarityFunction) { - LateInteractionFloatValuesSource scoreSource = - new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction); - return new FunctionScoreQuery(in, scoreSource); - } - /** * Returns a FunctionScoreQuery where the scores of a wrapped query are multiplied by the value of * a DoubleValuesSource. diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java index ce5034a50ed0..85bcc05cbf78 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java @@ -18,18 +18,9 @@ package org.apache.lucene.queries.function; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; -import java.util.Set; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; -import org.apache.lucene.document.IntField; -import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.document.LateInteractionField; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.TextField; import org.apache.lucene.expressions.Expression; @@ -40,21 +31,16 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.LateInteractionFloatValuesSource; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; @@ -391,91 +377,4 @@ public void testQueryMatchesCount() throws Exception { } assertEquals(searchCount, weightCount); } - - public void testLateInteractionQuery() throws Exception { - final String LATE_I_FIELD = "li_vector"; - final String KNN_FIELD = "knn_vector"; - List corpus = new ArrayList<>(); - final int numDocs = atLeast(1000); - final int numSegments = random().nextInt(2, 10); - final int dim = 128; - final VectorSimilarityFunction vectorSimilarityFunction = - VectorSimilarityFunction.values()[ - random().nextInt(VectorSimilarityFunction.values().length)]; - LateInteractionFloatValuesSource.ScoreFunction scoreFunction = - LateInteractionFloatValuesSource.ScoreFunction.values()[ - random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)]; - - try (Directory dir = newDirectory()) { - int id = 0; - try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { - for (int j = 0; j < numSegments; j++) { - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - if (random().nextInt(100) < 30) { - // skip value for some docs to create sparse field - doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); - } else { - float[][] value = createMultiVector(dim); - corpus.add(value); - doc.add(new IntField("id", id++, Field.Store.YES)); - doc.add(new LateInteractionField(LATE_I_FIELD, value)); - doc.add(new KnnFloatVectorField(KNN_FIELD, randomVector(dim))); - doc.add(new IntField("has_li_vector", 1, Field.Store.YES)); - } - w.addDocument(doc); - w.flush(); - } - } - // add a segment with no vectors - for (int i = 0; i < 100; i++) { - Document doc = new Document(); - doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); - w.addDocument(doc); - } - w.flush(); - } - - float[][] lateIQueryVector = createMultiVector(dim); - float[] knnQueryVector = randomVector(dim); - KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50); - - try (IndexReader reader = DirectoryReader.open(dir)) { - IndexSearcher s = new IndexSearcher(reader); - TopDocs knnHits = s.search(knnQuery, 50); - Set knnHitDocs = - Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); - FunctionScoreQuery lateIQuery = - FunctionScoreQuery.lateInteractionFloatRerankQuery( - knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); - TopDocs lateIHits = s.search(lateIQuery, 10); - StoredFields storedFields = reader.storedFields(); - for (ScoreDoc hit : lateIHits.scoreDocs) { - assertTrue(knnHitDocs.contains(hit.doc)); - int idValue = Integer.parseInt(storedFields.document(hit.doc).get("id")); - float[][] docVector = corpus.get(idValue); - float expected = - scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction); - assertEquals(expected, hit.score, 1e-5); - } - } - } - } - - private float[] randomVector(int dim) { - float[] v = new float[dim]; - Random random = random(); - for (int i = 0; i < dim; i++) { - v[i] = random.nextFloat(); - } - return v; - } - - private float[][] createMultiVector(int dimension) { - float[][] value = new float[random().nextInt(3, 12)][]; - for (int i = 0; i < value.length; i++) { - value[i] = randomVector(dimension); - } - return value; - } } From df32233721adf932a7c08c62880ed095c571e11b Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Thu, 3 Jul 2025 12:40:46 -0700 Subject: [PATCH 17/18] test for lateI rescorer --- .../search/LateInteractionRescorer.java | 17 ++ .../search/TestLateInteractionRescorer.java | 180 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestLateInteractionRescorer.java diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java index 99a6ca02935f..921614e9d625 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.lucene.search; import org.apache.lucene.index.VectorSimilarityFunction; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionRescorer.java b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionRescorer.java new file mode 100644 index 000000000000..c3685434591a --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestLateInteractionRescorer.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import static org.apache.lucene.search.LateInteractionFloatValuesSource.ScoreFunction; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.LateInteractionField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestLateInteractionRescorer extends LuceneTestCase { + + private final String LATE_I_FIELD = "li_vector"; + private final String KNN_FIELD = "knn_vector"; + private final int DIMENSION = 128; + + public void testBasic() throws Exception { + List corpus = new ArrayList<>(); + final VectorSimilarityFunction vectorSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + ScoreFunction scoreFunction = ScoreFunction.SUM_MAX_SIM; + + try (Directory dir = newDirectory()) { + indexMultiVectors(dir, corpus); + float[][] lateIQueryVector = createMultiVector(DIMENSION); + float[] knnQueryVector = randomFloatVector(DIMENSION, random()); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50); + + try (IndexReader reader = DirectoryReader.open(dir)) { + final int topN = 10; + IndexSearcher s = new IndexSearcher(reader); + TopDocs knnHits = s.search(knnQuery, 5 * topN); + LateInteractionRescorer rescorer = + LateInteractionRescorer.create( + LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction); + TopDocs rerankedHits = rescorer.rescore(s, knnHits, topN); + Set knnHitDocs = + Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + assertEquals(topN, rerankedHits.scoreDocs.length); + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < rerankedHits.scoreDocs.length; i++) { + assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc)); + int idValue = + Integer.parseInt(storedFields.document(rerankedHits.scoreDocs[i].doc).get("id")); + float[][] docVector = corpus.get(idValue); + float expected = + scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction); + assertEquals(expected, rerankedHits.scoreDocs[i].score, 1e-5); + if (i > 0) { + assertTrue(rerankedHits.scoreDocs[i].score <= rerankedHits.scoreDocs[i - 1].score); + } + } + } + } + } + + public void testMissingLateIValues() throws Exception { + List corpus = new ArrayList<>(); + final VectorSimilarityFunction vectorSimilarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; + + try (Directory dir = newDirectory()) { + indexMultiVectors(dir, corpus); + float[][] lateIQueryVector = createMultiVector(DIMENSION); + float[] knnQueryVector = randomFloatVector(DIMENSION, random()); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50); + + try (IndexReader reader = DirectoryReader.open(dir)) { + final int topN = 10; + IndexSearcher s = new IndexSearcher(reader); + TopDocs knnHits = s.search(knnQuery, 5 * topN); + LateInteractionRescorer rescorer = + LateInteractionRescorer.create( + "bad-test-field", lateIQueryVector, vectorSimilarityFunction); + TopDocs rerankedHits = rescorer.rescore(s, knnHits, topN); + Set knnHitDocs = + Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + assertEquals(topN, rerankedHits.scoreDocs.length); + for (int i = 0; i < rerankedHits.scoreDocs.length; i++) { + assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc)); + assertEquals(0f, rerankedHits.scoreDocs[i].score, 1e-5); + } + + LateInteractionRescorer rescorerWithFallback = + LateInteractionRescorer.withFallbackToFirstPassScore( + "bad-test-field", lateIQueryVector, vectorSimilarityFunction); + knnHits = s.search(knnQuery, 5 * topN); + rerankedHits = rescorerWithFallback.rescore(s, knnHits, topN); + knnHitDocs = Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet()); + assertEquals(topN, rerankedHits.scoreDocs.length); + for (int i = 0; i < rerankedHits.scoreDocs.length; i++) { + assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc)); + assertEquals(knnHits.scoreDocs[i].score, rerankedHits.scoreDocs[i].score, 1e-5); + } + } + } + } + + private void indexMultiVectors(Directory dir, List corpus) throws IOException { + final int numDocs = atLeast(1000); + final int numSegments = random().nextInt(2, 10); + int id = 0; + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextInt(100) < 30) { + // skip value for some docs to create sparse field + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + } else { + float[][] value = createMultiVector(DIMENSION); + corpus.add(value); + doc.add(new IntField("id", id++, Field.Store.YES)); + doc.add(new LateInteractionField(LATE_I_FIELD, value)); + doc.add(new KnnFloatVectorField(KNN_FIELD, randomFloatVector(DIMENSION, random()))); + doc.add(new IntField("has_li_vector", 1, Field.Store.YES)); + } + w.addDocument(doc); + w.flush(); + } + } + // add a segment with no vectors + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntField("has_li_vector", 0, Field.Store.YES)); + w.addDocument(doc); + } + w.flush(); + } + } + + private float[][] createMultiVector(int dimension) { + float[][] value = new float[random().nextInt(3, 12)][]; + for (int i = 0; i < value.length; i++) { + value[i] = randomFloatVector(dimension, random()); + } + return value; + } + + private float[] randomFloatVector(int dimension, Random random) { + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) { + vector[i] = random.nextFloat(); + } + return vector; + } +} From e824de394848ca8eeda58202ec09d0f25c0da109 Mon Sep 17 00:00:00 2001 From: vigyasharma Date: Thu, 3 Jul 2025 14:06:42 -0700 Subject: [PATCH 18/18] use multi-vec similarity interface --- .../LateInteractionFloatValuesSource.java | 21 ++-------- .../lucene/search/MultiVectorSimilarity.java | 41 +++++++++++++++++++ 2 files changed, 45 insertions(+), 17 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/MultiVectorSimilarity.java diff --git a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java index f45cd4930454..3bf3bcb3d0b9 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/LateInteractionFloatValuesSource.java @@ -40,7 +40,7 @@ public class LateInteractionFloatValuesSource extends DoubleValuesSource { private final String fieldName; private final float[][] queryVector; private final VectorSimilarityFunction vectorSimilarityFunction; - private final ScoreFunction scoreFunction; + private final MultiVectorSimilarity scoreFunction; public LateInteractionFloatValuesSource(String fieldName, float[][] queryVector) { this(fieldName, queryVector, VectorSimilarityFunction.COSINE, ScoreFunction.SUM_MAX_SIM); @@ -55,7 +55,7 @@ public LateInteractionFloatValuesSource( String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction, - ScoreFunction scoreFunction) { + MultiVectorSimilarity scoreFunction) { this.fieldName = Objects.requireNonNull(fieldName); this.queryVector = validateQueryVector(queryVector); this.vectorSimilarityFunction = Objects.requireNonNull(vectorSimilarityFunction); @@ -136,7 +136,7 @@ public String toString() { + " similarityFunction=" + vectorSimilarityFunction + " scoreFunction=" - + scoreFunction.name() + + scoreFunction.getClass() + " queryVector=" + Arrays.deepToString(queryVector) + ")"; @@ -148,7 +148,7 @@ public boolean isCacheable(LeafReaderContext ctx) { } /** Defines the function to compute similarity score between query and document multi-vectors */ - public enum ScoreFunction { + public enum ScoreFunction implements MultiVectorSimilarity { /** Computes the sum of max similarity between query and document vectors */ SUM_MAX_SIM { @@ -179,18 +179,5 @@ public float compare( return result; } }; - - /** - * Computes similarity between two multi-vectors using provided {@link VectorSimilarityFunction} - * - *

Provided multi-vectors can have varying number of composing token vectors, but their token - * vectors should have the same dimension. - * - * @param outer a multi-vector - * @param inner another multi-vector - * @return similarity score between two multi-vectors - */ - public abstract float compare( - float[][] outer, float[][] inner, VectorSimilarityFunction vectorSimilarityFunction); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiVectorSimilarity.java b/lucene/core/src/java/org/apache/lucene/search/MultiVectorSimilarity.java new file mode 100644 index 000000000000..2816f97b14a2 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/MultiVectorSimilarity.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import org.apache.lucene.index.VectorSimilarityFunction; + +/** + * Interface to define the similarity function between multi-vectors + * + * @lucene.experimental + */ +public interface MultiVectorSimilarity { + + /** + * Computes similarity between two multi-vectors using provided {@link VectorSimilarityFunction} + * + *

Provided multi-vectors can have varying number of composing token vectors, but their token + * vectors should have the same dimension. + * + * @param outer a multi-vector + * @param inner another multi-vector + * @return similarity score between two multi-vectors + */ + float compare( + float[][] outer, float[][] inner, VectorSimilarityFunction vectorSimilarityFunction); +}