
Commit 4fcc6f8

iter
1 parent 5dcf12d commit 4fcc6f8

File tree

2 files changed: +153 -61 lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 138 additions & 29 deletions
@@ -18,6 +18,7 @@
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
 import org.elasticsearch.logging.LogManager;
@@ -81,28 +82,27 @@ long[] buildAndWritePostingsLists(
         }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
-        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
         DocIdsWriter docIdsWriter = new DocIdsWriter();
-        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-            ES91OSQVectorsScorer.BULK_SIZE,
-            quantizer,
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
-            postingsOutput
+            fieldInfo.getVectorDimension(),
+            new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
-            // TODO: add back in sorting vectors by distance to centroid
             int[] cluster = assignmentsByCluster[c];
             // TODO align???
             offsets[c] = postingsOutput.getFilePointer();
             int size = cluster.length;
             postingsOutput.writeVInt(size);
             postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
             // TODO we might want to consider putting the docIds in a separate file
             // to aid with only having to fetch vectors from slower storage when they are required
             // keeping them in the same file indicates we pull the entire file into cache
             docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
         }

         if (logger.isDebugEnabled()) {
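The hunk above switches the flush path from the push-style writeOrds(ords, count, centroid) call to a pull-style source: per-cluster state is installed with reset(...), and the bulk writer then drains the source. A minimal sketch of that consumption contract, assuming only the QuantizedVectorValues interface introduced later in this diff (the drain helper is hypothetical, for illustration):

    // Illustrative only: how a consumer is expected to drive QuantizedVectorValues.
    // reset(...) rebinds the source to one cluster; next() returns a scratch buffer
    // that is overwritten on each call, so it must be consumed before the next call.
    static void drain(DefaultIVFVectorsWriter.QuantizedVectorValues qvv, IndexOutput out) throws IOException {
        for (int i = 0; i < qvv.count(); i++) {
            byte[] packed = qvv.next();               // packed 1-bit vector, scratch-backed
            out.writeBytes(packed, packed.length);
            OptimizedScalarQuantizer.QuantizationResult corr = qvv.getCorrections(); // only valid after next()
            // the real writer batches corrections per BULK_SIZE block (see DiskBBQBulkWriter)
        }
    }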
@@ -161,6 +161,35 @@ long[] buildAndWritePostingsLists(
                 mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
             }
         }
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
+                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
+                }
+            }
+        }
         // now we can read the quantized vectors from the temporary file
         try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
             final long[] offsets = new long[centroidSupplier.size()];
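The new bucketing runs two counting passes so that a vector ordinal can appear under both its primary centroid and its SOAR overspill centroid. A small worked example with invented values:

    // Hypothetical inputs, for illustration only.
    int[] assignments          = {0, 1, 0};    // primary centroid per vector ordinal
    int[] overspillAssignments = {1, -1, 1};   // SOAR centroid per ordinal, -1 = none
    // Pass 1 counts members:    centroidVectorCount = {2, 3}
    // Pass 2 fills the buckets: assignmentsByCluster[0] = {0, 2}
    //                           assignmentsByCluster[1] = {0, 1, 2}
    //                           isOverspillByCluster[1] = {true, false, true}
    // Ordinals 0 and 2 are written twice: once for centroid 0, once (as overspill) for centroid 1.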
@@ -169,26 +198,22 @@ long[] buildAndWritePostingsLists(
                 fieldInfo.getVectorDimension()
             );
             DocIdsWriter docIdsWriter = new DocIdsWriter();
-            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-                ES91OSQVectorsScorer.BULK_SIZE,
-                quantizer,
-                floatVectorValues,
-                postingsOutput
-            );
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
             for (int c = 0; c < centroidSupplier.size(); c++) {
                 float[] centroid = centroidSupplier.centroid(c);
-                // TODO: add back in sorting vectors by distance to centroid
                 int[] cluster = assignmentsByCluster[c];
+                boolean[] isOverspill = isOverspillByCluster[c];
                 // TODO align???
                 offsets[c] = postingsOutput.getFilePointer();
                 int size = cluster.length;
                 postingsOutput.writeVInt(size);
                 postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
                 // TODO we might want to consider putting the docIds in a separate file
                 // to aid with only having to fetch vectors from slower storage when they are required
                 // keeping them in the same file indicates we pull the entire file into cache
                 docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-                bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
             }

             if (logger.isDebugEnabled()) {
@@ -370,47 +395,131 @@ public float[] centroid(int centroidOrdinal) throws IOException {
         }
     }

-    static class OffHeapQuantizedVectors {
+    interface QuantizedVectorValues {
+        int count();
+
+        byte[] next() throws IOException;
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
+    }
+
+    interface IntToBooleanFunction {
+        boolean apply(int ord);
+    }
+
+    static class OnHeapQuantizedVectors implements QuantizedVectorValues {
+        private final FloatVectorValues vectorValues;
+        private final OptimizedScalarQuantizer quantizer;
+        private final byte[] quantizedVector;
+        private final int[] quantizedVectorScratch;
+        private OptimizedScalarQuantizer.QuantizationResult corrections;
+        private float[] currentCentroid;
+        private IntToIntFunction ordTransformer = null;
+        private int currOrd = -1;
+        private int count;
+
+        OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
+            this.vectorValues = vectorValues;
+            this.quantizer = quantizer;
+            this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.quantizedVectorScratch = new int[dimension];
+            this.corrections = null;
+        }
+
+        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+            this.currentCentroid = centroid;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+            this.count = count;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count() - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            float[] vector = vectorValues.vectorValue(ord);
+            corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
+            BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
+            return quantizedVector;
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return corrections;
+        }
+    }
+
+    static class OffHeapQuantizedVectors implements QuantizedVectorValues {
         private final IndexInput quantizedVectorsInput;
         private final byte[] binaryScratch;
         private final float[] corrections = new float[3];

         private final int vectorByteSize;
         private short bitSum;
         private int currOrd = -1;
-        private boolean isOverspill = false;
+        private int count;
+        private IntToBooleanFunction isOverspill = null;
+        private IntToIntFunction ordTransformer = null;

         OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
             this.quantizedVectorsInput = quantizedVectorsInput;
             this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
             this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
         }

-        byte[] getVector(int ord, boolean isOverspill) throws IOException {
-            readQuantizedVector(ord, isOverspill);
-            return binaryScratch;
+        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+            this.count = count;
+            this.isOverspill = isOverspill;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
         }

-        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            boolean isOverspill = this.isOverspill.apply(currOrd);
+            return getVector(ord, isOverspill);
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
             if (currOrd == -1) {
                 throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
             }
             return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
         }

+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
         public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
-            if (ord == currOrd && isOverspill == this.isOverspill) {
-                return; // no need to read again
-            }
-            long offset = (long) ord * (vectorByteSize * 2) + (isOverspill ? vectorByteSize : 0);
+            long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
             quantizedVectorsInput.seek(offset);
             quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
             quantizedVectorsInput.readFloats(corrections, 0, 3);
             bitSum = quantizedVectorsInput.readShort();
-            if (ord != currOrd) {
-                currOrd = ord;
-            }
-            this.isOverspill = isOverspill;
         }
     }
 }
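On the merge path, OffHeapQuantizedVectors replays quantized vectors from the temporary file rather than re-quantizing them. The addressing in readQuantizedVector implies a fixed record layout: every source ordinal owns two consecutive records (primary, then SOAR overspill), each holding the packed bits plus three float corrections and a short component sum. A sketch of that arithmetic, restated from the constructor and readQuantizedVector above (variable names illustrative):

    // Record layout assumed by the seek in readQuantizedVector:
    //   [ packed bits: discretize(dim, 64) / 8 bytes | 3 floats | 1 short ]
    // with two records per ordinal: primary first, overspill second.
    int packedBytes    = BQVectorUtils.discretize(dimension, 64) / 8;   // e.g. dim = 128 -> 16 bytes
    int vectorByteSize = packedBytes + 3 * Float.BYTES + Short.BYTES;   // 16 + 12 + 2 = 30 bytes
    long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);

The rewrite also drops the old same-ord caching short-circuit, which is no longer needed since the cursor-style next() advances on every call, and keeps the two-record stride in long arithmetic.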

server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java

Lines changed: 15 additions & 32 deletions
@@ -9,34 +9,25 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.hnsw.IntToIntFunction;
 
 import java.io.IOException;
 
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
-
 /**
  * Base class for bulk writers that write vectors to disk using the BBQ encoding.
  * This class provides the structure for writing vectors in bulk, with specific
  * implementations for different bit-size strategies.
  */
-public abstract class DiskBBQBulkWriter {
+abstract class DiskBBQBulkWriter {
     protected final int bulkSize;
-    protected final OptimizedScalarQuantizer quantizer;
     protected final IndexOutput out;
-    protected final FloatVectorValues fvv;
 
-    protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+    protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
         this.bulkSize = bulkSize;
-        this.quantizer = quantizer;
         this.out = out;
-        this.fvv = fvv;
     }
 
-    public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
+    abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
 
     private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
         for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
@@ -64,39 +55,31 @@ private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult
         out.writeShort((short) targetComponentSum);
     }
 
-    public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
-        private final byte[] binarized;
-        private final int[] initQuantized;
+    static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
 
-        public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
-            super(bulkSize, quantizer, fvv, out);
-            this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
-            this.initQuantized = new int[fvv.dimension()];
+        OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
+            super(bulkSize, out);
             this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
         }
 
         @Override
-        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
-            int limit = count - bulkSize + 1;
+        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
+            int limit = qvv.count() - bulkSize + 1;
             int i = 0;
             for (; i < limit; i += bulkSize) {
                 for (int j = 0; j < bulkSize; j++) {
-                    int ord = ords.apply(i + j);
-                    float[] fv = fvv.vectorValue(ord);
-                    corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                    packAsBinary(initQuantized, binarized);
-                    out.writeBytes(binarized, binarized.length);
+                    byte[] qv = qvv.next();
+                    corrections[j] = qvv.getCorrections();
+                    out.writeBytes(qv, qv.length);
                 }
                 writeCorrections(corrections, out);
             }
             // write tail
-            for (; i < count; ++i) {
-                int ord = ords.apply(i);
-                float[] fv = fvv.vectorValue(ord);
-                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                packAsBinary(initQuantized, binarized);
-                out.writeBytes(binarized, binarized.length);
+            for (; i < qvv.count(); ++i) {
+                byte[] qv = qvv.next();
+                OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
+                out.writeBytes(qv, qv.length);
                 writeCorrection(correction, out);
             }
         }
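After this change OneBitDiskBBQBulkWriter no longer quantizes anything itself; it only interleaves vector bodies with their corrections, writing full BULK_SIZE blocks (vectors first, then the block's corrections as a group) followed by a one-by-one tail. A hedged sketch of how the two call sites in DefaultIVFVectorsWriter drive it, condensed from the hunks above rather than verbatim:

    // Condensed from this commit's two call sites; not verbatim code.
    DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
    // Flush path: quantize on the fly against the cluster's centroid.
    onHeapQuantizedVectors.reset(centroid, cluster.length, ord -> cluster[ord]);
    bulkWriter.writeVectors(onHeapQuantizedVectors);
    // Merge path: replay pre-quantized records from the temp file.
    offHeapQuantizedVectors.reset(cluster.length, ord -> isOverspill[ord], ord -> cluster[ord]);
    bulkWriter.writeVectors(offHeapQuantizedVectors);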
