Skip to content

Commit f4792a9

Browse files
committed
[DiskBBQ] Write the raw centroid on the posting list file instead of the centroids file
1 parent efd3110 commit f4792a9

File tree

3 files changed

+29
-43
lines changed

3 files changed

+29
-43
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,35 +68,22 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
6868
return new CentroidQueryScorer() {
6969
int currentCentroid = -1;
7070
long postingListOffset;
71-
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
7271
private final float[] centroidCorrectiveValues = new float[3];
7372
private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
74-
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension() + Long.BYTES;
7573

7674
@Override
7775
public int size() {
7876
return numCentroids;
7977
}
8078

81-
@Override
82-
public float[] centroid(int centroidOrdinal) throws IOException {
83-
readDataIfNecessary(centroidOrdinal);
84-
return centroid;
85-
}
86-
8779
@Override
8880
public long postingListOffset(int centroidOrdinal) throws IOException {
89-
readDataIfNecessary(centroidOrdinal);
90-
return postingListOffset;
91-
}
92-
93-
private void readDataIfNecessary(int centroidOrdinal) throws IOException {
9481
if (centroidOrdinal != currentCentroid) {
95-
centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
96-
centroids.readFloats(centroid, 0, centroid.length);
82+
centroids.seek(rawCentroidsOffset + (long) Long.BYTES * centroidOrdinal);
9783
postingListOffset = centroids.readLong();
9884
currentCentroid = centroidOrdinal;
9985
}
86+
return postingListOffset;
10087
}
10188

10289
public void bulkScore(NeighborQueue queue) throws IOException {
@@ -193,7 +180,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
193180
int vectors;
194181
boolean quantized = false;
195182
float centroidDp;
196-
float[] centroid;
183+
final float[] centroid;
197184
long slicePos;
198185
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
199186
DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -217,7 +204,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
217204
this.entry = entry;
218205
this.fieldInfo = fieldInfo;
219206
this.needsScoring = needsScoring;
220-
207+
centroid = new float[fieldInfo.getVectorDimension()];
221208
scratch = new float[target.length];
222209
quantizationScratch = new int[target.length];
223210
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
@@ -229,12 +216,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
229216
}
230217

231218
@Override
232-
public int resetPostingsScorer(long offset, float[] centroid) throws IOException {
219+
public int resetPostingsScorer(long offset) throws IOException {
233220
quantized = false;
234221
indexInput.seek(offset);
235-
vectors = indexInput.readVInt();
222+
indexInput.readFloats(centroid, 0, centroid.length);
236223
centroidDp = Float.intBitsToFloat(indexInput.readInt());
237-
this.centroid = centroid;
224+
vectors = indexInput.readVInt();
238225
// read the doc ids
239226
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
240227
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,25 @@ long[] buildAndWritePostingsLists(
8989
fieldInfo.getVectorDimension(),
9090
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
9191
);
92+
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
9293
for (int c = 0; c < centroidSupplier.size(); c++) {
9394
float[] centroid = centroidSupplier.centroid(c);
9495
int[] cluster = assignmentsByCluster[c];
95-
// TODO align???
96-
offsets[c] = postingsOutput.getFilePointer();
96+
offsets[c] = postingsOutput.alignFilePointer(Float.BYTES);
97+
buffer.asFloatBuffer().put(centroid);
98+
// write raw centroid for quantizing the query vectors
99+
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
100+
// write centroid dot product for quantizing the query vectors
101+
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
97102
int size = cluster.length;
103+
// write docIds
98104
postingsOutput.writeVInt(size);
99-
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
100105
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
101106
// TODO we might want to consider putting the docIds in a separate file
102107
// to aid with only having to fetch vectors from slower storage when they are required
103108
// keeping them in the same file means we pull the entire file into cache
104109
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
110+
// write vectors
105111
bulkWriter.writeVectors(onHeapQuantizedVectors);
106112
}
107113

@@ -206,20 +212,26 @@ long[] buildAndWritePostingsLists(
206212
);
207213
DocIdsWriter docIdsWriter = new DocIdsWriter();
208214
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
215+
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
209216
for (int c = 0; c < centroidSupplier.size(); c++) {
210217
float[] centroid = centroidSupplier.centroid(c);
211218
int[] cluster = assignmentsByCluster[c];
212219
boolean[] isOverspill = isOverspillByCluster[c];
213-
// TODO align???
214-
offsets[c] = postingsOutput.getFilePointer();
220+
offsets[c] = postingsOutput.alignFilePointer(Float.BYTES);
221+
// write raw centroid for quantizing the query vectors
222+
buffer.asFloatBuffer().put(centroid);
223+
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
224+
// write centroid dot product for quantizing the query vectors
225+
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
226+
// write docIds
215227
int size = cluster.length;
216228
postingsOutput.writeVInt(size);
217-
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
218229
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
219230
// TODO we might want to consider putting the docIds in a separate file
220231
// to aid with only having to fetch vectors from slower storage when they are required
221232
// keeping them in the same file means we pull the entire file into cache
222233
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
234+
// write vectors
223235
bulkWriter.writeVectors(offHeapQuantizedVectors);
224236
}
225237

@@ -295,13 +307,8 @@ void writeCentroids(
295307
}
296308
writeQuantizedValue(centroidOutput, quantized, result);
297309
}
298-
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
310+
// write the centroid offsets at the end of the file
299311
for (int i = 0; i < centroidSupplier.size(); i++) {
300-
float[] centroid = centroidSupplier.centroid(i);
301-
buffer.asFloatBuffer().put(centroid);
302-
// write the centroids
303-
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
304-
// write the offset of this posting list
305312
centroidOutput.writeLong(offsets[i]);
306313
}
307314
}

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
266266
int centroidOrdinal = centroidQueue.pop();
267267
// TODO: do we need direct access to the raw centroid? This is used for quantizing; maybe hydrating and quantizing
268268
// is enough?
269-
expectedDocs += scorer.resetPostingsScorer(
270-
centroidQueryScorer.postingListOffset(centroidOrdinal),
271-
centroidQueryScorer.centroid(centroidOrdinal)
272-
);
269+
expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
273270
actualDocs += scorer.visit(knnCollector);
274271
}
275272
if (acceptDocs != null) {
@@ -278,10 +275,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
278275
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
279276
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
280277
int centroidOrdinal = centroidQueue.pop();
281-
scorer.resetPostingsScorer(
282-
centroidQueryScorer.postingListOffset(centroidOrdinal),
283-
centroidQueryScorer.centroid(centroidOrdinal)
284-
);
278+
scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
285279
actualDocs += scorer.visit(knnCollector);
286280
}
287281
}
@@ -332,8 +326,6 @@ abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postin
332326
interface CentroidQueryScorer {
333327
int size();
334328

335-
float[] centroid(int centroidOrdinal) throws IOException;
336-
337329
long postingListOffset(int centroidOrdinal) throws IOException;
338330

339331
void bulkScore(NeighborQueue queue) throws IOException;
@@ -343,7 +335,7 @@ interface PostingVisitor {
343335
// TODO: maybe we can avoid passing the centroid explicitly...
344336

345337
/** returns the number of documents in the posting list */
346-
int resetPostingsScorer(long offset, float[] centroid) throws IOException;
338+
int resetPostingsScorer(long offset) throws IOException;
347339

348340
/** returns the number of scored documents */
349341
int visit(KnnCollector collector) throws IOException;

0 commit comments

Comments
 (0)