Skip to content

Commit 5e18d3a

Browse files
committed
try flattening the corrective terms into the node
1 parent 78388d3 commit 5e18d3a

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizedVectorSimilarity.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ public float score(
105105
float ly = (queryCorrections.upperInterval() - ay) * queryScale;
106106
float y1 = queryCorrections.quantizedComponentSum();
107107
float score = ax * ay * dimensions + ay * lx * x1 + ax * ly * y1 + lx * ly * dotProduct;
108-
// For euclidean, we need to invert the score and apply the additional correction, which is
108+
// For euclidean, we need to invert the score and apply the additional
109+
// correction, which is
109110
// assumed to be the squared l2norm of the centroid centered vectors.
110111
if (similarityFunction == EUCLIDEAN) {
111112
score =
@@ -114,8 +115,10 @@ public float score(
114115
- 2 * score;
115116
return Math.max(1 / (1f + score), 0);
116117
} else {
117-
// For cosine and max inner product, we need to apply the additional correction, which is
118-
// assumed to be the non-centered dot-product between the vector and the centroid
118+
// For cosine and max inner product, we need to apply the additional correction,
119+
// which is
120+
// assumed to be the non-centered dot-product between the vector and the
121+
// centroid
119122
score +=
120123
queryCorrections.additionalCorrection()
121124
+ indexCorrections.additionalCorrection()
@@ -126,4 +129,52 @@ public float score(
126129
return Math.max((1f + score) / 2f, 0);
127130
}
128131
}
132+
133+
// XXX DO NOT MERGE: this method duplicates the score(...) overload above.
134+
/**
135+
* Computes the similarity score between a 'query' and an 'index' quantized vector, given the dot
136+
* product of the two vectors and their corrective factors.
137+
*
138+
* @param dotProduct - dot product of the two quantized vectors.
139+
* @param queryCorrections - corrective factors for the query vector 'y'.
140+
* @param indexLowerInterval - lower interval corrective term for the index vector 'x'.
141+
* @param indexUpperInterval - upper interval corrective term for the index vector 'x'.
142+
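* @param indexAdditionalCorrection - additional correction for the index vector 'x' (assumed squared l2norm of the centroid-centered vector for euclidean; otherwise the non-centered dot product with the centroid).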
143+
* @param indexQuantizedComponentSum - sum of the quantized components of the index vector 'x'.
144+
* @return - a similarity score value between 0 and 1; higher values are better.
145+
*/
146+
public float score(
147+
float dotProduct,
148+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
149+
float indexLowerInterval,
150+
float indexUpperInterval,
151+
float indexAdditionalCorrection,
152+
int indexQuantizedComponentSum) {
153+
float x1 = indexQuantizedComponentSum;
154+
float ax = indexLowerInterval;
155+
// Here we must scale according to the bits
156+
float lx = (indexUpperInterval - ax) * indexScale;
157+
float ay = queryCorrections.lowerInterval();
158+
float ly = (queryCorrections.upperInterval() - ay) * queryScale;
159+
float y1 = queryCorrections.quantizedComponentSum();
160+
float score = ax * ay * dimensions + ay * lx * x1 + ax * ly * y1 + lx * ly * dotProduct;
161+
// For euclidean, we need to invert the score and apply the additional
162+
// correction, which is
163+
// assumed to be the squared l2norm of the centroid centered vectors.
164+
if (similarityFunction == EUCLIDEAN) {
165+
score = queryCorrections.additionalCorrection() + indexAdditionalCorrection - 2 * score;
166+
return Math.max(1 / (1f + score), 0);
167+
} else {
168+
// For cosine and max inner product, we need to apply the additional correction,
169+
// which is
170+
// assumed to be the non-centered dot-product between the vector and the
171+
// centroid
172+
score +=
173+
queryCorrections.additionalCorrection() + indexAdditionalCorrection - centroidDotProduct;
174+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
175+
return VectorUtil.scaleMaxInnerProductScore(score);
176+
}
177+
return Math.max((1f + score) / 2f, 0);
178+
}
179+
}
129180
}

lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) throws I
168168
}
169169

170170
record Node(
171-
MemorySegment vector, OptimizedScalarQuantizer.QuantizationResult correctiveTerms) {}
171+
MemorySegment vector,
172+
float lowerInterval,
173+
float upperInterval,
174+
float additionalCorrection,
175+
int componentSum) {}
172176

173177
@SuppressWarnings("restricted")
174178
Node getNode(int ord) throws IOException {
@@ -182,14 +186,17 @@ Node getNode(int ord) throws IOException {
182186
input.readBytes(byteOffset, scratch, 0, nodeSize);
183187
vector = MemorySegment.ofArray(scratch);
184188
}
185-
var correctiveTerms =
186-
new OptimizedScalarQuantizer.QuantizationResult(
187-
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize)),
188-
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES)),
189-
Float.intBitsToFloat(
190-
vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 2)),
191-
vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 3));
192-
return new Node(vector.reinterpret(vectorByteSize), correctiveTerms);
189+
// XXX investigate reordering the vector so that corrective terms appear first.
190+
// we're forced to read them immediately to avoid creating a second memory
191+
// segment which is
192+
// not cheap, so they might as well be read first to avoid additional memory
193+
// latency.
194+
return new Node(
195+
vector.reinterpret(vectorByteSize),
196+
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize)),
197+
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES)),
198+
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 2)),
199+
vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 3));
193200
}
194201

195202
OptimizedScalarQuantizedVectorSimilarity getSimilarity() {
@@ -242,7 +249,14 @@ public float score(int node) throws IOException {
242249
// Call getCorrectiveTerms() after computing dot product since corrective terms
243250
// bytes appear after the vector bytes, so this sequence of calls is more cache
244251
// friendly.
245-
return getSimilarity().score(dotProduct, queryCorrectiveTerms, doc.correctiveTerms);
252+
return getSimilarity()
253+
.score(
254+
dotProduct,
255+
queryCorrectiveTerms,
256+
doc.lowerInterval,
257+
doc.upperInterval,
258+
doc.additionalCorrection,
259+
doc.componentSum);
246260
}
247261
}
248262

0 commit comments

Comments
 (0)