 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.VectorUtil;
@@ -49,8 +50,35 @@ long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
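+        // group vector ordinals by centroid with a two-pass counting sort: size each cluster first,
+        // then fill it; SOAR overspill assignments are included in both passes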
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
         OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -84,6 +112,92 @@ long[] buildAndWritePostingsLists(
         return offsets;
     }
 
+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // first, quantize all the vectors into a temporary file
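+        // layout: two fixed-size records per ordinal (primary, then SOAR overspill or zero padding),
+        // so any record can later be located by arithmetic alone (see OffHeapQuantizedVectors below)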
+        // quantizer is used both to write the temp file below and to build posting lists afterwards
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                // close before deleting so we don't remove an open temp file
+                IOUtils.closeWhileHandlingException(quantizedVectorsTemp);
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
+        // group ordinals by cluster, mirroring the flush path above, so posting lists can be written per centroid
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                int s = overspillAssignments[i];
+                assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
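+            // random-access reader over the pre-quantized vectors written above, so per-cluster
+            // writing does not have to re-quantize each vector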
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
+                ES91OSQVectorsScorer.BULK_SIZE,
+                quantizer,
+                floatVectorValues,
+                postingsOutput
+            );
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                // TODO: add back in sorting vectors by distance to centroid
+                int[] cluster = assignmentsByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                // TODO we might want to consider putting the docIds in a separate file
+                // so that vectors only have to be fetched from slower storage when they are required;
+                // keeping them in the same file means we pull the entire file into the cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
     private static void printClusterQualityStatistics(int[][] clusters) {
         float min = Float.MAX_VALUE;
         float max = Float.MIN_VALUE;
@@ -210,33 +324,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +369,48 @@ public float[] centroid(int centroidOrdinal) throws IOException {
             return scratch;
         }
     }
+
+    static class OffHeapQuantizedVectors {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private boolean isOverspill = false;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
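+            // per-record size matches writeQuantizedValue: packed 1-bit vector, 3 correction floats, bit-sum short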
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            if (ord == currOrd && isOverspill == this.isOverspill) {
+                return; // no need to read again
+            }
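+            // two fixed-size records per ordinal: the primary vector at ord * 2 * size,
+            // the overspill (or zero padding) immediately after it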
+            long offset = (long) ord * (vectorByteSize * 2) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+            currOrd = ord;
+            this.isOverspill = isOverspill;
+        }
+    }
 }
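The seek in readQuantizedVector relies on the fixed-stride layout written during merge: two records per source ordinal, primary first, overspill (or zero padding) second. A minimal, illustrative sketch of that addressing, assuming 1-bit quantization and the record layout above (the class and names here are hypothetical, not part of this change):

```java
// Illustrative only: mirrors the offset arithmetic used by OffHeapQuantizedVectors.
class QuantizedTempFileLayout {
    // Record = packed 1-bit vector + 3 correction floats + bit-sum short.
    static long recordOffset(int ord, boolean overspill, int dimension) {
        int packedBytes = ((dimension + 63) / 64) * 8; // BQVectorUtils.discretize(dimension, 64) / 8
        int recordBytes = packedBytes + 3 * Float.BYTES + Short.BYTES;
        return (long) ord * (2L * recordBytes) + (overspill ? recordBytes : 0);
    }

    public static void main(String[] args) {
        // 768 dims -> 96 packed bytes + 12 correction bytes + 2 bit-sum bytes = 110 per record
        System.out.println(recordOffset(0, false, 768)); // 0
        System.out.println(recordOffset(0, true, 768));  // 110
        System.out.println(recordOffset(1, false, 768)); // 220
    }
}
```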