Skip to content

Commit e7178bc

Browse files
committed
try another formulation of vector handling
1 parent 955d083 commit e7178bc

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,43 @@ Node getNode(int ord) throws IOException {
160160
vector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 3));
161161
}
162162

163+
MemorySegment getRawVector(int ord) throws IOException {
164+
checkOrdinal(ord);
165+
long byteOffset = (long) ord * nodeSize;
166+
MemorySegment vector = input.segmentSliceOrNull(byteOffset, nodeSize);
167+
if (vector != null) {
168+
return vector;
169+
}
170+
171+
if (scratch == null) {
172+
scratch = new byte[nodeSize];
173+
}
174+
input.readBytes(byteOffset, scratch, 0, nodeSize);
175+
return MemorySegment.ofArray(scratch);
176+
}
177+
178+
@SuppressWarnings("restricted")
179+
MemorySegment getVector(MemorySegment rawVector) {
180+
return rawVector.reinterpret(vectorByteSize);
181+
}
182+
183+
float getLowerInterval(MemorySegment rawVector) {
184+
return Float.intBitsToFloat(rawVector.get(INT_UNALIGNED_LE, vectorByteSize));
185+
}
186+
187+
float getUpperInterval(MemorySegment rawVector) {
188+
return Float.intBitsToFloat(rawVector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES));
189+
}
190+
191+
float getAdditionalCorrection(MemorySegment rawVector) {
192+
return Float.intBitsToFloat(
193+
rawVector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 2));
194+
}
195+
196+
int getComponentSum(MemorySegment rawVector) {
197+
return rawVector.get(INT_UNALIGNED_LE, vectorByteSize + Integer.BYTES * 3);
198+
}
199+
163200
OptimizedScalarQuantizedVectorSimilarity getSimilarity() {
164201
return similarity;
165202
}
@@ -199,13 +236,14 @@ private static class RandomVectorScorerImpl extends RandomVectorScorerBase {
199236

200237
@Override
201238
public float score(int node) throws IOException {
202-
Node doc = getNode(node);
239+
MemorySegment rawDoc = getRawVector(node);
240+
MemorySegment docVector = getVector(rawDoc);
203241
float dotProduct =
204242
switch (getScalarEncoding()) {
205-
case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, doc.vector);
206-
case SEVEN_BIT -> PanamaVectorUtilSupport.uint8DotProduct(query, doc.vector);
243+
case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, docVector);
244+
case SEVEN_BIT -> PanamaVectorUtilSupport.uint8DotProduct(query, docVector);
207245
case PACKED_NIBBLE ->
208-
PanamaVectorUtilSupport.int4DotProductSinglePacked(query, doc.vector);
246+
PanamaVectorUtilSupport.int4DotProductSinglePacked(query, docVector);
209247
};
210248
// Call getCorrectiveTerms() after computing dot product since corrective terms
211249
// bytes appear after the vector bytes, so this sequence of calls is more cache
@@ -214,10 +252,10 @@ public float score(int node) throws IOException {
214252
.score(
215253
dotProduct,
216254
queryCorrectiveTerms,
217-
doc.lowerInterval,
218-
doc.upperInterval,
219-
doc.additionalCorrection,
220-
doc.componentSum);
255+
getLowerInterval(rawDoc),
256+
getUpperInterval(rawDoc),
257+
getAdditionalCorrection(rawDoc),
258+
getComponentSum(rawDoc));
221259
}
222260
}
223261

0 commit comments

Comments
 (0)