Skip to content

Commit e32ab82

Browse files
committed
ensure doc id sorting keeps soar separate from nearest
1 parent f6caf36 commit e32ab82

File tree

2 files changed

+139
-142
lines changed

2 files changed

+139
-142
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 84 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.apache.lucene.store.IOContext;
1818
import org.apache.lucene.store.IndexInput;
1919
import org.apache.lucene.store.IndexOutput;
20-
import org.apache.lucene.util.IntroSorter;
2120
import org.apache.lucene.util.LongValues;
2221
import org.apache.lucene.util.VectorUtil;
2322
import org.apache.lucene.util.hnsw.IntToIntFunction;
@@ -49,28 +48,57 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec
4948
this.vectorPerCluster = vectorPerCluster;
5049
}
5150

52-
@Override
53-
LongValues buildAndWritePostingsLists(
54-
FieldInfo fieldInfo,
55-
CentroidSupplier centroidSupplier,
56-
FloatVectorValues floatVectorValues,
57-
IndexOutput postingsOutput,
51+
private static void deltaEncode(int[] vals, int size, int[] deltas) {
52+
if (size == 0) {
53+
return;
54+
}
55+
deltas[0] = vals[0];
56+
for (int i = 1; i < size; i++) {
57+
assert vals[i] >= vals[i - 1] : "vals are not sorted: " + vals[i] + " < " + vals[i - 1];
58+
deltas[i] = vals[i] - vals[i - 1];
59+
}
60+
}
61+
62+
private static void translateOrdsToDocs(
63+
int[] ords,
64+
int size,
65+
int[] spillOrds,
66+
int spillSize,
67+
int[] docIds,
68+
int[] spillDocIds,
69+
IntToIntFunction ordToDoc
70+
) {
71+
int ordIdx = 0, spillOrdIdx = 0;
72+
while (ordIdx < size || spillOrdIdx < spillSize) {
73+
int nextOrd = (ordIdx < size) ? ords[ordIdx] : Integer.MAX_VALUE;
74+
int nextSpillOrd = (spillOrdIdx < spillSize) ? spillOrds[spillOrdIdx] : Integer.MAX_VALUE;
75+
if (nextOrd < nextSpillOrd) {
76+
docIds[ordIdx] = ordToDoc.apply(nextOrd);
77+
ordIdx++;
78+
} else {
79+
spillDocIds[spillOrdIdx] = ordToDoc.apply(nextSpillOrd);
80+
spillOrdIdx++;
81+
}
82+
}
83+
}
84+
85+
private static void pivotAssignments(
86+
int centroidCount,
5887
int[] assignments,
59-
int[] overspillAssignments
60-
) throws IOException {
61-
int[] centroidVectorCount = new int[centroidSupplier.size()];
62-
int[] overspillVectorCount = new int[centroidSupplier.size()];
88+
int[] overspillAssignments,
89+
int[][] assignmentsByCluster,
90+
int[][] overspillAssignmentsByCluster
91+
) {
92+
int[] centroidVectorCount = new int[centroidCount];
93+
int[] overspillVectorCount = new int[centroidCount];
6394
for (int i = 0; i < assignments.length; i++) {
6495
centroidVectorCount[assignments[i]]++;
6596
// if soar assignments are present, count them as well
6697
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
6798
overspillVectorCount[overspillAssignments[i]]++;
6899
}
69100
}
70-
71-
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
72-
int[][] overspillAssignmentsByCluster = new int[centroidSupplier.size()][];
73-
for (int c = 0; c < centroidSupplier.size(); c++) {
101+
for (int c = 0; c < centroidCount; c++) {
74102
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
75103
overspillAssignmentsByCluster[c] = new int[overspillVectorCount[c]];
76104
}
@@ -88,14 +116,35 @@ LongValues buildAndWritePostingsLists(
88116
}
89117
}
90118
}
119+
}
120+
121+
@Override
122+
LongValues buildAndWritePostingsLists(
123+
FieldInfo fieldInfo,
124+
CentroidSupplier centroidSupplier,
125+
FloatVectorValues floatVectorValues,
126+
IndexOutput postingsOutput,
127+
int[] assignments,
128+
int[] overspillAssignments
129+
) throws IOException {
130+
91131
// write the posting lists
92132
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
93133
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
134+
// pivot the assignments into clusters
135+
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
136+
int[][] overspillAssignmentsByCluster = new int[centroidSupplier.size()][];
137+
pivotAssignments(centroidSupplier.size(), assignments, overspillAssignments, assignmentsByCluster, overspillAssignmentsByCluster);
94138

95139
int[] docIds = null;
96140
int[] docDeltas = null;
97141
int[] spillDocIds = null;
98142
int[] spillDeltas = null;
143+
final OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
144+
floatVectorValues,
145+
fieldInfo.getVectorDimension(),
146+
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
147+
);
99148
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
100149
for (int c = 0; c < centroidSupplier.size(); c++) {
101150
float[] centroid = centroidSupplier.centroid(c);
@@ -115,42 +164,13 @@ LongValues buildAndWritePostingsLists(
115164
spillDocIds = new int[spillSize];
116165
spillDeltas = new int[spillSize];
117166
}
118-
for (int j = 0; j < size; j++) {
119-
docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
120-
}
121-
for (int j = 0; j < spillSize; j++) {
122-
spillDocIds[j] = floatVectorValues.ordToDoc(overspillCluster[j]);
123-
}
124-
final int[] finalDocs = docIds;
125-
final int[] finalSpillDocs = spillDocIds;
167+
translateOrdsToDocs(cluster, size, overspillCluster, spillSize, docIds, spillDocIds, floatVectorValues::ordToDoc);
126168
// encode doc deltas
127169
if (size > 0) {
128-
docDeltas[0] = finalDocs[0];
129-
for (int j = size - 1; j > 0; j--) {
130-
if (finalDocs[j] < finalDocs[j - 1]) {
131-
throw new IllegalStateException(
132-
"docIds are not sorted: "
133-
+ finalDocs[j]
134-
+ " < "
135-
+ finalDocs[j - 1]
136-
);
137-
}
138-
docDeltas[j] = finalDocs[j] - finalDocs[j - 1];
139-
}
170+
deltaEncode(docIds, size, docDeltas);
140171
}
141172
if (spillSize > 0) {
142-
spillDeltas[0] = finalSpillDocs[0];
143-
for (int j = spillSize - 1; j > 0; j--) {
144-
if (finalSpillDocs[j] < finalSpillDocs[j - 1]) {
145-
throw new IllegalStateException(
146-
"Overspill docIds are not sorted: "
147-
+ finalSpillDocs[j]
148-
+ " < "
149-
+ finalSpillDocs[j - 1]
150-
);
151-
}
152-
spillDeltas[j] = finalSpillDocs[j] - finalSpillDocs[j - 1];
153-
}
173+
deltaEncode(spillDocIds, spillSize, spillDeltas);
154174
}
155175
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
156176
postingsOutput.writeInt(size);
@@ -160,25 +180,16 @@ LongValues buildAndWritePostingsLists(
160180
// keeping them in the same file indicates we pull the entire file into cache
161181
postingsOutput.writeGroupVInts(docDeltas, size);
162182
postingsOutput.writeGroupVInts(spillDeltas, spillSize);
163-
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
164-
floatVectorValues,
165-
fieldInfo.getVectorDimension(),
166-
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
167-
);
168183
onHeapQuantizedVectors.reset(centroid, size, j -> cluster[j]);
169184
bulkWriter.writeVectors(onHeapQuantizedVectors);
170185
// write overspill vectors
171-
onHeapQuantizedVectors = new OnHeapQuantizedVectors(
172-
floatVectorValues,
173-
fieldInfo.getVectorDimension(),
174-
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
175-
);
176186
onHeapQuantizedVectors.reset(centroid, spillSize, j -> overspillCluster[j]);
177187
bulkWriter.writeVectors(onHeapQuantizedVectors);
178188
}
179189

180190
if (logger.isDebugEnabled()) {
181191
printClusterQualityStatistics(assignmentsByCluster);
192+
printClusterQualityStatistics(overspillAssignmentsByCluster);
182193
}
183194

184195
return offsets.build();
@@ -240,40 +251,22 @@ LongValues buildAndWritePostingsLists(
240251
mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
241252
}
242253
}
243-
int[] centroidVectorCount = new int[centroidSupplier.size()];
244-
int[] overspillVectorCount = new int[centroidSupplier.size()];
245-
for (int i = 0; i < assignments.length; i++) {
246-
centroidVectorCount[assignments[i]]++;
247-
// if soar assignments are present, count them as well
248-
if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
249-
overspillVectorCount[overspillAssignments[i]]++;
250-
}
251-
}
252-
253254
int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
254255
int[][] overspillAssignmentsByCluster = new int[centroidSupplier.size()][];
255-
for (int c = 0; c < centroidSupplier.size(); c++) {
256-
assignmentsByCluster[c] = new int[centroidVectorCount[c]];
257-
overspillAssignmentsByCluster[c] = new int[overspillVectorCount[c]];
258-
}
259-
Arrays.fill(centroidVectorCount, 0);
260-
Arrays.fill(overspillVectorCount, 0);
261-
for (int i = 0; i < assignments.length; i++) {
262-
int c = assignments[i];
263-
assignmentsByCluster[c][centroidVectorCount[c]++] = i;
264-
// if soar assignments are present, add them to the cluster as well
265-
if (overspillAssignments.length > i) {
266-
int s = overspillAssignments[i];
267-
if (s != -1) {
268-
overspillAssignmentsByCluster[s][overspillVectorCount[s]++] = i;
269-
}
270-
}
271-
}
256+
// pivot the assignments into clusters
257+
pivotAssignments(centroidSupplier.size(), assignments, overspillAssignments, assignmentsByCluster, overspillAssignmentsByCluster);
272258
// now we can read the quantized vectors from the temporary file
273259
try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
274260
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
275261

276-
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
262+
final DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
263+
ES91OSQVectorsScorer.BULK_SIZE,
264+
postingsOutput
265+
);
266+
final OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
267+
quantizedVectorsInput,
268+
fieldInfo.getVectorDimension()
269+
);
277270
int[] docIds = null;
278271
int[] docDeltas = null;
279272
int[] spillDocIds = null;
@@ -297,26 +290,14 @@ LongValues buildAndWritePostingsLists(
297290
spillDocIds = new int[spillSize];
298291
spillDeltas = new int[spillSize];
299292
}
300-
for (int j = 0; j < size; j++) {
301-
docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
302-
}
303-
for (int j = 0; j < spillSize; j++) {
304-
spillDocIds[j] = floatVectorValues.ordToDoc(overspillCluster[j]);
305-
}
306-
final int[] finalDocs = docIds;
307-
final int[] finalSpillDocs = spillDocIds;
293+
// translate ordinals to docIds
294+
translateOrdsToDocs(cluster, size, overspillCluster, spillSize, docIds, spillDocIds, floatVectorValues::ordToDoc);
308295
// encode doc deltas
309296
if (size > 0) {
310-
docDeltas[0] = finalDocs[0];
311-
for (int j = size - 1; j > 0; j--) {
312-
docDeltas[j] = finalDocs[j] - finalDocs[j - 1];
313-
}
297+
deltaEncode(docIds, size, docDeltas);
314298
}
315299
if (spillSize > 0) {
316-
spillDeltas[0] = finalSpillDocs[0];
317-
for (int j = spillSize - 1; j > 0; j--) {
318-
spillDeltas[j] = finalSpillDocs[j] - finalSpillDocs[j - 1];
319-
}
300+
deltaEncode(spillDocIds, spillSize, spillDeltas);
320301
}
321302
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
322303
postingsOutput.writeInt(size);
@@ -327,22 +308,16 @@ LongValues buildAndWritePostingsLists(
327308
postingsOutput.writeGroupVInts(docDeltas, size);
328309
postingsOutput.writeGroupVInts(spillDeltas, spillSize);
329310
// write overspill vectors
330-
OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
331-
quantizedVectorsInput,
332-
fieldInfo.getVectorDimension()
333-
);
311+
334312
offHeapQuantizedVectors.reset(size, false, j -> cluster[j]);
335313
bulkWriter.writeVectors(offHeapQuantizedVectors);
336-
offHeapQuantizedVectors = new OffHeapQuantizedVectors(
337-
quantizedVectorsInput,
338-
fieldInfo.getVectorDimension()
339-
);
340314
offHeapQuantizedVectors.reset(spillSize, true, j -> overspillCluster[j]);
341315
bulkWriter.writeVectors(offHeapQuantizedVectors);
342316
}
343317

344318
if (logger.isDebugEnabled()) {
345319
printClusterQualityStatistics(assignmentsByCluster);
320+
printClusterQualityStatistics(overspillAssignmentsByCluster);
346321
}
347322
return offsets.build();
348323
}
@@ -506,39 +481,6 @@ public float[] centroid(int centroidOrdinal) throws IOException {
506481
}
507482
}
508483

509-
static class IntSorter extends IntroSorter {
510-
int pivot = -1;
511-
private final int[] arr;
512-
private final IntToIntFunction func;
513-
514-
IntSorter(int[] arr, IntToIntFunction func) {
515-
this.arr = arr;
516-
this.func = func;
517-
}
518-
519-
@Override
520-
protected void setPivot(int i) {
521-
pivot = func.apply(arr[i]);
522-
}
523-
524-
@Override
525-
protected int comparePivot(int j) {
526-
return Integer.compare(pivot, func.apply(arr[j]));
527-
}
528-
529-
@Override
530-
protected int compare(int a, int b) {
531-
return Integer.compare(func.apply(arr[a]), func.apply(arr[b]));
532-
}
533-
534-
@Override
535-
protected void swap(int i, int j) {
536-
final int tmp = arr[i];
537-
arr[i] = arr[j];
538-
arr[j] = tmp;
539-
}
540-
}
541-
542484
interface QuantizedVectorValues {
543485
int count();
544486

0 commit comments

Comments
 (0)