Skip to content

Commit 955d083

Browse files
committed
cleanup
1 parent 5e18d3a commit 955d083

File tree

2 files changed

+33
-83
lines changed

2 files changed

+33
-83
lines changed

lucene/core/src/java/org/apache/lucene/util/quantization/OptimizedScalarQuantizedVectorSimilarity.java

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -97,40 +97,15 @@ public float score(
9797
float dotProduct,
9898
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
9999
OptimizedScalarQuantizer.QuantizationResult indexCorrections) {
100-
float x1 = indexCorrections.quantizedComponentSum();
101-
float ax = indexCorrections.lowerInterval();
102-
// Here we must scale according to the bits
103-
float lx = (indexCorrections.upperInterval() - ax) * indexScale;
104-
float ay = queryCorrections.lowerInterval();
105-
float ly = (queryCorrections.upperInterval() - ay) * queryScale;
106-
float y1 = queryCorrections.quantizedComponentSum();
107-
float score = ax * ay * dimensions + ay * lx * x1 + ax * ly * y1 + lx * ly * dotProduct;
108-
// For euclidean, we need to invert the score and apply the additional
109-
// correction, which is
110-
// assumed to be the squared l2norm of the centroid centered vectors.
111-
if (similarityFunction == EUCLIDEAN) {
112-
score =
113-
queryCorrections.additionalCorrection()
114-
+ indexCorrections.additionalCorrection()
115-
- 2 * score;
116-
return Math.max(1 / (1f + score), 0);
117-
} else {
118-
// For cosine and max inner product, we need to apply the additional correction,
119-
// which is
120-
// assumed to be the non-centered dot-product between the vector and the
121-
// centroid
122-
score +=
123-
queryCorrections.additionalCorrection()
124-
+ indexCorrections.additionalCorrection()
125-
- centroidDotProduct;
126-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
127-
return VectorUtil.scaleMaxInnerProductScore(score);
128-
}
129-
return Math.max((1f + score) / 2f, 0);
130-
}
100+
return score(
101+
dotProduct,
102+
queryCorrections,
103+
indexCorrections.lowerInterval(),
104+
indexCorrections.upperInterval(),
105+
indexCorrections.additionalCorrection(),
106+
indexCorrections.quantizedComponentSum());
131107
}
132108

133-
// XXX DO NOT MERGE duplication with above.
134109
/**
135110
* Computes the similarity score between a 'query' and an 'index' quantized vector, given the dot
136111
* product of the two vectors and their corrective factors.

lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -129,44 +129,6 @@ final void checkOrdinal(int ord) {
129129
private static final ValueLayout.OfInt INT_UNALIGNED_LE =
130130
JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
131131

132-
// XXX I need to return something wraps the MemorySegment and can produce the
133-
// corrective terms
134-
// on demand. rep is probably (MemorySegment, MemorySegment) with a slice for
135-
// the corrective terms.
136-
@SuppressWarnings("restricted")
137-
MemorySegment getVector(int ord) throws IOException {
138-
checkOrdinal(ord);
139-
long byteOffset = (long) ord * nodeSize;
140-
MemorySegment vector = input.segmentSliceOrNull(byteOffset, vectorByteSize);
141-
if (vector == null) {
142-
if (scratch == null) {
143-
scratch = new byte[nodeSize];
144-
}
145-
input.readBytes(byteOffset, scratch, 0, nodeSize);
146-
vector = MemorySegment.ofArray(scratch).reinterpret(vectorByteSize);
147-
}
148-
return vector;
149-
}
150-
151-
@SuppressWarnings("restricted")
152-
OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) throws IOException {
153-
checkOrdinal(ord);
154-
long byteOffset = (long) ord * nodeSize + vectorByteSize;
155-
MemorySegment node = input.segmentSliceOrNull(byteOffset, CORRECTIVE_TERMS_SIZE);
156-
if (node == null) {
157-
if (scratch == null) {
158-
scratch = new byte[nodeSize];
159-
}
160-
input.readBytes(byteOffset, scratch, 0, CORRECTIVE_TERMS_SIZE);
161-
node = MemorySegment.ofArray(scratch).reinterpret(CORRECTIVE_TERMS_SIZE);
162-
}
163-
return new OptimizedScalarQuantizer.QuantizationResult(
164-
Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, 0)),
165-
Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, Integer.BYTES)),
166-
Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, Integer.BYTES * 2)),
167-
node.get(INT_UNALIGNED_LE, Integer.BYTES * 3));
168-
}
169-
170132
record Node(
171133
MemorySegment vector,
172134
float lowerInterval,
@@ -188,9 +150,8 @@ Node getNode(int ord) throws IOException {
188150
}
189151
// XXX investigate reordering the vector so that corrective terms appear first.
190152
// we're forced to read them immediately to avoid creating a second memory
191-
// segment which is
192-
// not cheap, so they might as well be read first to avoid additional memory
193-
// latency.
153+
// segment which is not cheap, so they might as well be read first to avoid
154+
// additional memory latency.
194155
return new Node(
195156
vector.reinterpret(vectorByteSize),
196157
Float.intBitsToFloat(vector.get(INT_UNALIGNED_LE, vectorByteSize)),
@@ -260,7 +221,7 @@ public float score(int node) throws IOException {
260221
}
261222
}
262223

263-
private record RandomVectorScorerSupplierImpl(
224+
record RandomVectorScorerSupplierImpl(
264225
VectorSimilarityFunction similarityFunction,
265226
QuantizedByteVectorValues values,
266227
MemorySegmentAccessInput input)
@@ -293,23 +254,37 @@ private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorer
293254
@Override
294255
public void setScoringOrdinal(int ord) throws IOException {
295256
checkOrdinal(ord);
296-
query = getVector(ord);
297-
queryCorrectiveTerms = getCorrectiveTerms(ord);
257+
Node node = getNode(ord);
258+
query = node.vector();
259+
queryCorrectiveTerms =
260+
new OptimizedScalarQuantizer.QuantizationResult(
261+
node.lowerInterval(),
262+
node.upperInterval(),
263+
node.additionalCorrection(),
264+
node.componentSum());
298265
}
299266

300267
@Override
301268
public float score(int node) throws IOException {
302-
MemorySegment doc = getVector(node);
269+
Node doc = getNode(node);
303270
float dotProduct =
304271
switch (getScalarEncoding()) {
305-
case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, doc);
306-
case SEVEN_BIT -> PanamaVectorUtilSupport.uint8DotProduct(query, doc);
307-
case PACKED_NIBBLE -> PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc);
272+
case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, doc.vector());
273+
case SEVEN_BIT -> PanamaVectorUtilSupport.uint8DotProduct(query, doc.vector());
274+
case PACKED_NIBBLE ->
275+
PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc.vector());
308276
};
309277
// Call getCorrectiveTerms() after computing dot product since corrective terms
310-
// bytes appear
311-
// after the vector bytes, so this sequence of calls is more cache friendly.
312-
return getSimilarity().score(dotProduct, queryCorrectiveTerms, getCorrectiveTerms(node));
278+
// bytes appear after the vector bytes, so this sequence of calls is more cache
279+
// friendly.
280+
return getSimilarity()
281+
.score(
282+
dotProduct,
283+
queryCorrectiveTerms,
284+
doc.lowerInterval(),
285+
doc.upperInterval(),
286+
doc.additionalCorrection(),
287+
doc.componentSum());
313288
}
314289
}
315290
}

0 commit comments

Comments (0)