Skip to content

Commit 7de24ce

Browse files
committed
Support for Re-Ranking Queries using Late Interaction Model Multi-Vectors. (#14729)
1 parent 2a5ae4b commit 7de24ce

File tree

10 files changed

+996
-0
lines changed

10 files changed

+996
-0
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ New Features
3232
* GITHUB#14009: Add a new Query that can rescore other Query based on a generic DoubleValueSource
3333
and trim the results down to top N (Anh Dung Bui)
3434

35+
* GITHUB#14729: Support for Re-Ranking Queries using Late Interaction Model Multi-Vectors. (Vigya Sharma, Jim Ferenczi)
36+
3537
Improvements
3638
---------------------
3739
* GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.document;
19+
20+
import java.nio.ByteBuffer;
21+
import java.nio.ByteOrder;
22+
import java.nio.FloatBuffer;
23+
import org.apache.lucene.util.BytesRef;
24+
25+
/**
26+
* A field for storing multi-vector values for late interaction models.
27+
*
28+
* <p>The value is stored as a binary payload, and can be retrieved using {@link
29+
* LateInteractionField#decode(BytesRef)}. Multi-vectors are expected to have the same dimension for
30+
* each composing token vector. This is stored along with the token vectors in the first 4 bytes of
31+
* the payload.
32+
*
33+
* <p>Note: This field does not ensure consistency in token vector dimensions for values across
34+
* documents.
35+
*/
36+
public class LateInteractionField extends BinaryDocValuesField {
37+
38+
/**
39+
* Creates a new {@link LateInteractionField} from the provided multi-vector matrix.
40+
*
41+
* @param name field name
42+
* @param value multi-vector value
43+
*/
44+
public LateInteractionField(String name, float[][] value) {
45+
super(name, encode(value));
46+
}
47+
48+
/**
49+
* Set multi-vector value for the field.
50+
*
51+
* <p>Value should not be null or empty. Composing token vectors for provided multi-vector value
52+
* should have the same dimension.
53+
*/
54+
public void setValue(float[][] value) {
55+
this.fieldsData = encode(value);
56+
}
57+
58+
/** Returns the multi-vector value stored in this field */
59+
public float[][] getValue() {
60+
return decode((BytesRef) fieldsData);
61+
}
62+
63+
/**
64+
* Encodes provided multi-vector matrix into a {@link BytesRef} that can be stored in the {@link
65+
* LateInteractionField}.
66+
*
67+
* <p>Composing token vectors for the multi-vector are expected to have the same dimension, which
68+
* is stored along with the token vectors in the first 4 bytes of the payload. Use {@link
69+
* LateInteractionField#decode(BytesRef)} to retrieve the multi-vector.
70+
*
71+
* @param value Multi-Vector to encode
72+
* @return BytesRef representation for provided multi-vector
73+
*/
74+
public static BytesRef encode(float[][] value) {
75+
if (value == null || value.length == 0) {
76+
throw new IllegalArgumentException("Value should not be null or empty");
77+
}
78+
if (value[0] == null || value[0].length == 0) {
79+
throw new IllegalArgumentException("Composing token vectors should not be null or empty");
80+
}
81+
final int tokenVectorDimension = value[0].length;
82+
final ByteBuffer buffer =
83+
ByteBuffer.allocate(Integer.BYTES + value.length * tokenVectorDimension * Float.BYTES)
84+
.order(ByteOrder.LITTLE_ENDIAN);
85+
// TODO: Should we store dimension in FieldType to ensure consistency across all documents?
86+
buffer.putInt(tokenVectorDimension);
87+
FloatBuffer floatBuffer = buffer.asFloatBuffer();
88+
for (int i = 0; i < value.length; i++) {
89+
if (value[i].length != tokenVectorDimension) {
90+
throw new IllegalArgumentException(
91+
"Composing token vectors should have the same dimension. "
92+
+ "Mismatching dimensions detected between token[0] and token["
93+
+ i
94+
+ "], "
95+
+ value[0].length
96+
+ " != "
97+
+ value[i].length);
98+
}
99+
floatBuffer.put(value[i]);
100+
}
101+
return new BytesRef(buffer.array());
102+
}
103+
104+
/**
105+
* Decodes provided {@link BytesRef} into a multi-vector matrix.
106+
*
107+
* <p>The token vectors are expected to have the same dimension, which is stored along with the
108+
* token vectors in the first 4 bytes of the payload. Meant to be used as a counterpart to {@link
109+
* LateInteractionField#encode(float[][])}
110+
*
111+
* @param payload to decode into multi-vector value
112+
* @return decoded multi-vector value
113+
*/
114+
public static float[][] decode(BytesRef payload) {
115+
final ByteBuffer buffer = ByteBuffer.wrap(payload.bytes, payload.offset, payload.length);
116+
buffer.order(ByteOrder.LITTLE_ENDIAN);
117+
final int tokenVectorDimension = buffer.getInt();
118+
int numVectors = (payload.length - Integer.BYTES) / (tokenVectorDimension * Float.BYTES);
119+
if (numVectors * tokenVectorDimension * Float.BYTES + Integer.BYTES != payload.length) {
120+
throw new IllegalArgumentException(
121+
"Provided payload does not appear to have been encoded via LateInteractionField.encode. "
122+
+ "Payload length should be equal to 4 + numVectors * tokenVectorDimension, "
123+
+ "got "
124+
+ payload.length
125+
+ " != 4 + "
126+
+ numVectors
127+
+ " * "
128+
+ tokenVectorDimension);
129+
}
130+
var floatBuffer = buffer.asFloatBuffer();
131+
float[][] value = new float[numVectors][];
132+
for (int i = 0; i < numVectors; i++) {
133+
value[i] = new float[tokenVectorDimension];
134+
floatBuffer.get(value[i]);
135+
}
136+
return value;
137+
}
138+
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import java.io.IOException;
21+
import java.util.Arrays;
22+
import java.util.Objects;
23+
import org.apache.lucene.document.LateInteractionField;
24+
import org.apache.lucene.index.BinaryDocValues;
25+
import org.apache.lucene.index.LeafReaderContext;
26+
import org.apache.lucene.index.VectorSimilarityFunction;
27+
28+
/**
29+
* A {@link DoubleValuesSource} that scores documents using similarity between a multi-vector query,
30+
* and indexed document multi-vectors.
31+
*
32+
* <p>This is useful re-ranking query results using late interaction models, where documents and
33+
* queries are represented as multi-vectors of composing token vectors. Document vectors are indexed
34+
* using {@link org.apache.lucene.document.LateInteractionField}.
35+
*
36+
* @lucene.experimental
37+
*/
38+
public class LateInteractionFloatValuesSource extends DoubleValuesSource {
39+
40+
private final String fieldName;
41+
private final float[][] queryVector;
42+
private final VectorSimilarityFunction vectorSimilarityFunction;
43+
private final MultiVectorSimilarity scoreFunction;
44+
45+
public LateInteractionFloatValuesSource(String fieldName, float[][] queryVector) {
46+
this(fieldName, queryVector, VectorSimilarityFunction.COSINE, ScoreFunction.SUM_MAX_SIM);
47+
}
48+
49+
public LateInteractionFloatValuesSource(
50+
String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
51+
this(fieldName, queryVector, vectorSimilarityFunction, ScoreFunction.SUM_MAX_SIM);
52+
}
53+
54+
public LateInteractionFloatValuesSource(
55+
String fieldName,
56+
float[][] queryVector,
57+
VectorSimilarityFunction vectorSimilarityFunction,
58+
MultiVectorSimilarity scoreFunction) {
59+
this.fieldName = Objects.requireNonNull(fieldName);
60+
this.queryVector = validateQueryVector(queryVector);
61+
this.vectorSimilarityFunction = Objects.requireNonNull(vectorSimilarityFunction);
62+
this.scoreFunction = Objects.requireNonNull(scoreFunction);
63+
}
64+
65+
private float[][] validateQueryVector(float[][] queryVector) {
66+
if (queryVector == null || queryVector.length == 0) {
67+
throw new IllegalArgumentException("queryVector must not be null or empty");
68+
}
69+
if (queryVector[0] == null || queryVector[0].length == 0) {
70+
throw new IllegalArgumentException(
71+
"composing token vectors in provided query vector should not be null or empty");
72+
}
73+
for (int i = 1; i < queryVector.length; i++) {
74+
if (queryVector[i] == null || queryVector[i].length != queryVector[0].length) {
75+
throw new IllegalArgumentException(
76+
"all composing token vectors in provided query vector should have the same length");
77+
}
78+
}
79+
return queryVector;
80+
}
81+
82+
@Override
83+
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
84+
BinaryDocValues values = ctx.reader().getBinaryDocValues(fieldName);
85+
if (values == null) {
86+
return DoubleValues.EMPTY;
87+
}
88+
89+
return new DoubleValues() {
90+
@Override
91+
public double doubleValue() throws IOException {
92+
return scoreFunction.compare(
93+
queryVector,
94+
LateInteractionField.decode(values.binaryValue()),
95+
vectorSimilarityFunction);
96+
}
97+
98+
@Override
99+
public boolean advanceExact(int doc) throws IOException {
100+
return values.advanceExact(doc);
101+
}
102+
};
103+
}
104+
105+
@Override
106+
public boolean needsScores() {
107+
return false;
108+
}
109+
110+
@Override
111+
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
112+
return this;
113+
}
114+
115+
@Override
116+
public int hashCode() {
117+
return Objects.hash(
118+
fieldName, Arrays.deepHashCode(queryVector), vectorSimilarityFunction, scoreFunction);
119+
}
120+
121+
@Override
122+
public boolean equals(Object obj) {
123+
if (this == obj) return true;
124+
if (obj == null || getClass() != obj.getClass()) return false;
125+
LateInteractionFloatValuesSource other = (LateInteractionFloatValuesSource) obj;
126+
return Objects.equals(fieldName, other.fieldName)
127+
&& vectorSimilarityFunction == other.vectorSimilarityFunction
128+
&& scoreFunction == other.scoreFunction
129+
&& Arrays.deepEquals(queryVector, other.queryVector);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
return "LateInteractionFloatValuesSource(fieldName="
135+
+ fieldName
136+
+ " similarityFunction="
137+
+ vectorSimilarityFunction
138+
+ " scoreFunction="
139+
+ scoreFunction.getClass()
140+
+ " queryVector="
141+
+ Arrays.deepToString(queryVector)
142+
+ ")";
143+
}
144+
145+
@Override
146+
public boolean isCacheable(LeafReaderContext ctx) {
147+
return true;
148+
}
149+
150+
/** Defines the function to compute similarity score between query and document multi-vectors */
151+
public enum ScoreFunction implements MultiVectorSimilarity {
152+
153+
/** Computes the sum of max similarity between query and document vectors */
154+
SUM_MAX_SIM {
155+
@Override
156+
public float compare(
157+
float[][] queryVector,
158+
float[][] docVector,
159+
VectorSimilarityFunction vectorSimilarityFunction) {
160+
if (docVector.length == 0) {
161+
return Float.MIN_VALUE;
162+
}
163+
float result = 0f;
164+
for (float[] q : queryVector) {
165+
float maxSim = Float.MIN_VALUE;
166+
for (float[] d : docVector) {
167+
if (q.length != d.length) {
168+
throw new IllegalArgumentException(
169+
"Provided multi-vectors are incompatible. "
170+
+ "Their composing token vectors should have the same dimension, got "
171+
+ q.length
172+
+ " != "
173+
+ d.length);
174+
}
175+
maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(q, d));
176+
}
177+
result += maxSim;
178+
}
179+
return result;
180+
}
181+
};
182+
}
183+
}

0 commit comments

Comments
 (0)