Merged
Changes from 6 commits
@@ -326,12 +326,12 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
final float[] centroid;
long slicePos;
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
DocIdsWriter docIdsWriter = new DocIdsWriter();

final float[] scratch;
final int[] quantizationScratch;
final byte[] quantizedQueryScratch;
final OptimizedScalarQuantizer quantizer;
final DocIdsWriter idsWriter = new DocIdsWriter();
final float[] correctiveValues = new float[3];
final long quantizedVectorByteSize;

@@ -369,7 +369,13 @@ public int resetPostingsScorer(long offset) throws IOException {
vectors = indexInput.readVInt();
// read the doc ids
assert vectors <= docIdsScratch.length;
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
idsWriter.readInts(indexInput, vectors, docIdsScratch);
// reconstitute from the deltas
int sum = 0;
for (int i = 0; i < vectors; i++) {
sum += docIdsScratch[i];
docIdsScratch[i] = sum;
}
slicePos = indexInput.getFilePointer();
return vectors;
}
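For context on the loop above: the write side now stores each posting list's doc ids as a first absolute id followed by gaps, and the read side restores absolute ids with a running prefix sum. A minimal standalone sketch of that round trip (illustrative values, not the PR's code):

// encode: first absolute docId, then gaps between consecutive sorted docIds
int[] docIds = {5, 9, 13, 20};
int[] deltas = new int[docIds.length];
for (int i = 0; i < docIds.length; i++) {
    deltas[i] = i == 0 ? docIds[0] : docIds[i] - docIds[i - 1];
}
// deltas == {5, 4, 4, 7}

// decode: prefix sum in place, mirroring the loop in resetPostingsScorer above
int sum = 0;
for (int i = 0; i < deltas.length; i++) {
    sum += deltas[i];
    deltas[i] = sum;
}
// deltas == {5, 9, 13, 20} again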
@@ -17,6 +17,7 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
@@ -101,14 +102,17 @@ LongValues buildAndWritePostingsLists(
postingsOutput.writeVInt(maxPostingListSize);
// write the posting lists
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
floatVectorValues,
fieldInfo.getVectorDimension(),
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
);
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
int[] docIds = null;
int[] docDeltas = null;
int[] clusterOrds = null;
DocIdsWriter idsWriter = new DocIdsWriter();
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
@@ -121,11 +125,29 @@ LongValues buildAndWritePostingsLists(
int size = cluster.length;
// write docIds
postingsOutput.writeVInt(size);
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
if (docIds == null || docIds.length < cluster.length) {
docIds = new int[cluster.length];
clusterOrds = new int[cluster.length];
docDeltas = new int[cluster.length];
}
for (int j = 0; j < size; j++) {
docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
clusterOrds[j] = j;
}
final int[] finalDocs = docIds;
final int[] finalOrds = clusterOrds;
// sort clusterOrds by docId values, this way cluster ordinals are sorted by docIds
new IntSorter(clusterOrds, i -> finalDocs[i]).sort(0, size);
// encode doc deltas
for (int j = 0; j < size; j++) {
docDeltas[j] = j == 0 ? finalDocs[finalOrds[j]] : finalDocs[finalOrds[j]] - finalDocs[finalOrds[j - 1]];
}
final int[] finalDocDeltas = docDeltas;
onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[finalOrds[ord]]);
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file means we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
idsWriter.writeDocIds(i -> finalDocDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(onHeapQuantizedVectors);
}
@@ -233,12 +255,15 @@ LongValues buildAndWritePostingsLists(
quantizedVectorsInput,
fieldInfo.getVectorDimension()
);
DocIdsWriter docIdsWriter = new DocIdsWriter();
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
// write the max posting list size
postingsOutput.writeVInt(maxPostingListSize);
// write the posting lists
int[] docIds = null;
int[] docDeltas = null;
int[] clusterOrds = null;
DocIdsWriter idsWriter = new DocIdsWriter();
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
@@ -252,11 +277,29 @@
// write docIds
int size = cluster.length;
postingsOutput.writeVInt(size);
offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
if (docIds == null || docIds.length < cluster.length) {
docIds = new int[cluster.length];
clusterOrds = new int[cluster.length];
docDeltas = new int[cluster.length];
}
for (int j = 0; j < size; j++) {
docIds[j] = floatVectorValues.ordToDoc(cluster[j]);
clusterOrds[j] = j;
}
final int[] finalDocs = docIds;
final int[] finalOrds = clusterOrds;
// sort clusterOrds by docId values, this way cluster ordinals are sorted by docIds
new IntSorter(clusterOrds, i -> finalDocs[i]).sort(0, size);
// encode doc deltas
for (int j = 0; j < size; j++) {
docDeltas[j] = j == 0 ? finalDocs[finalOrds[j]] : finalDocs[finalOrds[j]] - finalDocs[finalOrds[j - 1]];
}
final int[] finalDocDeltas = docDeltas;
offHeapQuantizedVectors.reset(size, ord -> isOverspill[finalOrds[ord]], ord -> cluster[finalOrds[ord]]);
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file means we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
idsWriter.writeDocIds(i -> finalDocDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(offHeapQuantizedVectors);
}
@@ -717,4 +760,37 @@ public void readQuantizedVector(int ord, boolean isOverspill) throws IOException
bitSum = quantizedVectorsInput.readShort();
}
}

private static class IntSorter extends IntroSorter {
int pivot = -1;
private final int[] arr;
private final IntToIntFunction func;

private IntSorter(int[] arr, IntToIntFunction func) {
this.arr = arr;
this.func = func;
}

@Override
protected void setPivot(int i) {
pivot = func.apply(arr[i]);
}

@Override
protected int comparePivot(int j) {
return Integer.compare(pivot, func.apply(arr[j]));
}

@Override
protected int compare(int a, int b) {
return Integer.compare(func.apply(arr[a]), func.apply(arr[b]));
}

@Override
protected void swap(int i, int j) {
final int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
}
}
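The IntSorter above sorts an int array in place, comparing entries by a derived key; the writer uses it to reorder cluster ordinals by their docIds before computing deltas. A small usage sketch (illustrative values, assumes visibility of the private class; not the PR's code):

int[] docIds = {42, 7, 19};   // docId for each cluster ordinal
int[] ords = {0, 1, 2};       // identity permutation over cluster ordinals
new IntSorter(ords, i -> docIds[i]).sort(0, ords.length);
// ords == {1, 2, 0}: docIds[ords[0]] <= docIds[ords[1]] <= docIds[ords[2]]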
@@ -23,8 +23,6 @@
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.DocBaseBitSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.LongsRef;
import org.apache.lucene.util.hnsw.IntToIntFunction;
Expand All @@ -42,7 +40,6 @@ final class DocIdsWriter {
public static final int DEFAULT_MAX_POINTS_IN_LEAF_NODE = 512;

private static final byte CONTINUOUS_IDS = (byte) -2;
private static final byte BITSET_IDS = (byte) -1;
private static final byte DELTA_BPV_16 = (byte) 16;
private static final byte BPV_21 = (byte) 21;
private static final byte BPV_24 = (byte) 24;
@@ -92,21 +89,11 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
}

int min2max = max - min + 1;
if (strictlySorted) {
if (min2max == count) {
// continuous ids, typically happens when segment is sorted
out.writeByte(CONTINUOUS_IDS);
out.writeVInt(docIds.apply(0));
return;
} else if (min2max <= (count << 4)) {
assert min2max > count : "min2max: " + min2max + ", count: " + count;
// Only trigger bitset optimization when max - min + 1 <= 16 * count in order to avoid
// expanding too much storage.
// A field with lower cardinality will have higher probability to trigger this optimization.
out.writeByte(BITSET_IDS);
writeIdsAsBitSet(docIds, count, out);
return;
}
if (strictlySorted && min2max == count) {
// continuous ids, typically happens when segment is sorted
out.writeByte(CONTINUOUS_IDS);
out.writeVInt(docIds.apply(0));
return;
}

if (min2max <= 0xFFFF) {
@@ -180,38 +167,6 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
}
}

private static void writeIdsAsBitSet(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
int min = docIds.apply(0);
int max = docIds.apply(count - 1);

final int offsetWords = min >> 6;
final int offsetBits = offsetWords << 6;
final int totalWordCount = FixedBitSet.bits2words(max - offsetBits + 1);
long currentWord = 0;
int currentWordIndex = 0;

out.writeVInt(offsetWords);
out.writeVInt(totalWordCount);
// build bit set streaming
for (int i = 0; i < count; i++) {
final int index = docIds.apply(i) - offsetBits;
final int nextWordIndex = index >> 6;
assert currentWordIndex <= nextWordIndex;
if (currentWordIndex < nextWordIndex) {
out.writeLong(currentWord);
currentWord = 0L;
currentWordIndex++;
while (currentWordIndex < nextWordIndex) {
currentWordIndex++;
out.writeLong(0L);
}
}
currentWord |= 1L << index;
}
out.writeLong(currentWord);
assert currentWordIndex + 1 == totalWordCount;
}

/** Read {@code count} integers into {@code docIDs}. */
void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
if (count == 0) {
@@ -225,9 +180,6 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
case CONTINUOUS_IDS:
readContinuousIds(in, count, docIDs);
break;
case BITSET_IDS:
readBitSet(in, count, docIDs);
break;
case DELTA_BPV_16:
readDelta16(in, count, docIDs);
break;
@@ -245,36 +197,13 @@
}
}

private DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException {
int offsetWords = in.readVInt();
int longLen = in.readVInt();
scratchLongs.longs = ArrayUtil.growNoCopy(scratchLongs.longs, longLen);
in.readLongs(scratchLongs.longs, 0, longLen);
// make ghost bits clear for FixedBitSet.
if (longLen < scratchLongs.length) {
Arrays.fill(scratchLongs.longs, longLen, scratchLongs.longs.length, 0);
}
scratchLongs.length = longLen;
FixedBitSet bitSet = new FixedBitSet(scratchLongs.longs, longLen << 6);
return new DocBaseBitSetIterator(bitSet, count, offsetWords << 6);
}

private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {
int start = in.readVInt();
for (int i = 0; i < count; i++) {
docIDs[i] = start + i;
}
}

private void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException {
DocIdSetIterator iterator = readBitSetIterator(in, count);
int docId, pos = 0;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
docIDs[pos++] = docId;
}
assert pos == count : "pos: " + pos + ", count: " + count;
}

private static void readDelta16(IndexInput in, int count, int[] docIds) throws IOException {
final int min = in.readVInt();
final int half = count >> 1;