Skip to content

Commit 628828f

Browse files
authored
[DiskBBQ] Write the raw centroid on the posting list file instead of the centroids file (elastic#131421)
1 parent f739673 commit 628828f

File tree

3 files changed

+31
-44
lines changed

3 files changed

+31
-44
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -68,35 +68,23 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
6868
return new CentroidQueryScorer() {
6969
int currentCentroid = -1;
7070
long postingListOffset;
71-
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
7271
private final float[] centroidCorrectiveValues = new float[3];
73-
private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
74-
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES;
72+
private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
73+
+ Short.BYTES);
7574

7675
@Override
7776
public int size() {
7877
return numCentroids;
7978
}
8079

81-
@Override
82-
public float[] centroid(int centroidOrdinal) throws IOException {
83-
readDataIfNecessary(centroidOrdinal);
84-
return centroid;
85-
}
86-
8780
@Override
8881
public long postingListOffset(int centroidOrdinal) throws IOException {
89-
readDataIfNecessary(centroidOrdinal);
90-
return postingListOffset;
91-
}
92-
93-
private void readDataIfNecessary(int centroidOrdinal) throws IOException {
9482
if (centroidOrdinal != currentCentroid) {
95-
centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
96-
centroids.readFloats(centroid, 0, centroid.length);
83+
centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
9784
postingListOffset = centroids.readLong();
9885
currentCentroid = centroidOrdinal;
9986
}
87+
return postingListOffset;
10088
}
10189

10290
public void bulkScore(NeighborQueue queue) throws IOException {
@@ -193,7 +181,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
193181
int vectors;
194182
boolean quantized = false;
195183
float centroidDp;
196-
float[] centroid;
184+
final float[] centroid;
197185
long slicePos;
198186
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
199187
DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -217,7 +205,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
217205
this.entry = entry;
218206
this.fieldInfo = fieldInfo;
219207
this.needsScoring = needsScoring;
220-
208+
centroid = new float[fieldInfo.getVectorDimension()];
221209
scratch = new float[target.length];
222210
quantizationScratch = new int[target.length];
223211
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
@@ -229,12 +217,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
229217
}
230218

231219
@Override
232-
public int resetPostingsScorer(long offset, float[] centroid) throws IOException {
220+
public int resetPostingsScorer(long offset) throws IOException {
233221
quantized = false;
234222
indexInput.seek(offset);
235-
vectors = indexInput.readVInt();
223+
indexInput.readFloats(centroid, 0, centroid.length);
236224
centroidDp = Float.intBitsToFloat(indexInput.readInt());
237-
this.centroid = centroid;
225+
vectors = indexInput.readVInt();
238226
// read the doc ids
239227
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
240228
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,25 @@ LongValues buildAndWritePostingsLists(
9292
fieldInfo.getVectorDimension(),
9393
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
9494
);
95+
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
9596
for (int c = 0; c < centroidSupplier.size(); c++) {
9697
float[] centroid = centroidSupplier.centroid(c);
9798
int[] cluster = assignmentsByCluster[c];
98-
// TODO align???
99-
offsets.add(postingsOutput.getFilePointer());
99+
offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
100+
buffer.asFloatBuffer().put(centroid);
101+
// write raw centroid for quantizing the query vectors
102+
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
103+
// write centroid dot product for quantizing the query vectors
104+
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
100105
int size = cluster.length;
106+
// write docIds
101107
postingsOutput.writeVInt(size);
102-
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
103108
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
104109
// TODO we might want to consider putting the docIds in a separate file
105110
// to aid with only having to fetch vectors from slower storage when they are required
106111
// keeping them in the same file indicates we pull the entire file into cache
107112
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
113+
// write vectors
108114
bulkWriter.writeVectors(onHeapQuantizedVectors);
109115
}
110116

@@ -209,20 +215,26 @@ LongValues buildAndWritePostingsLists(
209215
);
210216
DocIdsWriter docIdsWriter = new DocIdsWriter();
211217
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
218+
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
212219
for (int c = 0; c < centroidSupplier.size(); c++) {
213220
float[] centroid = centroidSupplier.centroid(c);
214221
int[] cluster = assignmentsByCluster[c];
215222
boolean[] isOverspill = isOverspillByCluster[c];
216-
offsets.add(postingsOutput.getFilePointer());
223+
offsets.add(postingsOutput.alignFilePointer(Float.BYTES));
224+
// write raw centroid for quantizing the query vectors
225+
buffer.asFloatBuffer().put(centroid);
226+
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
227+
// write centroid dot product for quantizing the query vectors
228+
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
229+
// write docIds
217230
int size = cluster.length;
218-
// TODO align???
219231
postingsOutput.writeVInt(size);
220-
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
221232
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
222233
// TODO we might want to consider putting the docIds in a separate file
223234
// to aid with only having to fetch vectors from slower storage when they are required
224235
// keeping them in the same file indicates we pull the entire file into cache
225236
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
237+
// write vectors
226238
bulkWriter.writeVectors(offHeapQuantizedVectors);
227239
}
228240

@@ -298,13 +310,8 @@ void writeCentroids(
298310
}
299311
writeQuantizedValue(centroidOutput, quantized, result);
300312
}
301-
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
313+
// write the centroid offsets at the end of the file
302314
for (int i = 0; i < centroidSupplier.size(); i++) {
303-
float[] centroid = centroidSupplier.centroid(i);
304-
buffer.asFloatBuffer().put(centroid);
305-
// write the centroids
306-
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
307-
// write the offset of this posting list
308315
centroidOutput.writeLong(offsets.get(i));
309316
}
310317
}

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
266266
int centroidOrdinal = centroidQueue.pop();
267267
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
268268
// is enough?
269-
expectedDocs += scorer.resetPostingsScorer(
270-
centroidQueryScorer.postingListOffset(centroidOrdinal),
271-
centroidQueryScorer.centroid(centroidOrdinal)
272-
);
269+
expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
273270
actualDocs += scorer.visit(knnCollector);
274271
}
275272
if (acceptDocs != null) {
@@ -278,10 +275,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
278275
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
279276
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
280277
int centroidOrdinal = centroidQueue.pop();
281-
scorer.resetPostingsScorer(
282-
centroidQueryScorer.postingListOffset(centroidOrdinal),
283-
centroidQueryScorer.centroid(centroidOrdinal)
284-
);
278+
scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
285279
actualDocs += scorer.visit(knnCollector);
286280
}
287281
}
@@ -332,8 +326,6 @@ abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postin
332326
interface CentroidQueryScorer {
333327
int size();
334328

335-
float[] centroid(int centroidOrdinal) throws IOException;
336-
337329
long postingListOffset(int centroidOrdinal) throws IOException;
338330

339331
void bulkScore(NeighborQueue queue) throws IOException;
@@ -343,7 +335,7 @@ interface PostingVisitor {
343335
// TODO maybe we can not specifically pass the centroid...
344336

345337
/** returns the number of documents in the posting list */
346-
int resetPostingsScorer(long offset, float[] centroid) throws IOException;
338+
int resetPostingsScorer(long offset) throws IOException;
347339

348340
/** returns the number of scored documents */
349341
int visit(KnnCollector collector) throws IOException;

0 commit comments

Comments
 (0)