Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.LongsRef;
import org.apache.lucene.util.hnsw.IntToIntFunction;

import java.io.IOException;
Expand All @@ -43,7 +42,6 @@ final class DocIdsWriter {
private static final byte BPV_32 = (byte) 32;

private int[] scratch = new int[0];
private final LongsRef scratchLongs = new LongsRef();

/**
* IntsRef to be used to iterate over the scratch buffer. A single instance is reused to avoid
Expand All @@ -63,6 +61,175 @@ final class DocIdsWriter {

DocIdsWriter() {}

/**
* Calculate the best encoding that will be used to write blocks of doc ids of blockSize.
* The encoding choice is universal for all the blocks, which means that the encoding is only as
* efficient as the worst block.
* @param docIds function to access the doc ids
* @param count number of doc ids
* @param blockSize the block size
* @return the byte encoding to use for the blocks
*/
byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) {
if (count == 0) {
return CONTINUOUS_IDS;
}
byte encoding = CONTINUOUS_IDS;
int iterationLimit = count - blockSize + 1;
int i = 0;
for (; i < iterationLimit; i += blockSize) {
int offset = i;
encoding = (byte) Math.max(encoding, blockEncoding(d -> docIds.apply(offset + d), blockSize));
}
// check the tail
if (i == count) {
return encoding;
}
int offset = i;
encoding = (byte) Math.max(encoding, blockEncoding(d -> docIds.apply(offset + d), count - i));
return encoding;
}

void writeDocIds(IntToIntFunction docIds, int count, byte encoding, DataOutput out) throws IOException {
if (count == 0) {
return;
}
if (count > scratch.length) {
scratch = new int[count];
}
int min = docIds.apply(0);
for (int i = 1; i < count; ++i) {
int current = docIds.apply(i);
min = Math.min(min, current);
}
switch (encoding) {
case CONTINUOUS_IDS:
writeContinuousIds(docIds, count, out);
break;
case DELTA_BPV_16:
writeDelta16(docIds, count, min, out);
break;
case BPV_21:
write21(docIds, count, min, out);
break;
case BPV_24:
write24(docIds, count, min, out);
break;
case BPV_32:
write32(docIds, count, min, out);
break;
default:
throw new IOException("Unsupported number of bits per value: " + encoding);
}
}

private static void writeContinuousIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
out.writeVInt(docIds.apply(0));
}

private void writeDelta16(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
for (int i = 0; i < count; i++) {
scratch[i] = docIds.apply(i) - min;
}
out.writeVInt(min);
final int halfLen = count >> 1;
for (int i = 0; i < halfLen; ++i) {
scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
}
for (int i = 0; i < halfLen; i++) {
out.writeInt(scratch[i]);
}
if ((count & 1) == 1) {
out.writeShort((short) scratch[count - 1]);
}
}

private void write21(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
final int oneThird = floorToMultipleOf16(count / 3);
final int numInts = oneThird * 2;
for (int i = 0; i < numInts; i++) {
scratch[i] = docIds.apply(i) << 11;
}
for (int i = 0; i < oneThird; i++) {
final int longIdx = i + numInts;
scratch[i] |= docIds.apply(longIdx) & 0x7FF;
scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
}
for (int i = 0; i < numInts; i++) {
out.writeInt(scratch[i]);
}
int i = oneThird * 3;
for (; i < count - 2; i += 3) {
out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
}
for (; i < count; ++i) {
out.writeShort((short) docIds.apply(i));
out.writeByte((byte) (docIds.apply(i) >>> 16));
}
}

private void write24(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {

// encode the docs in the format that can be vectorized decoded.
final int quarter = count >> 2;
final int numInts = quarter * 3;
for (int i = 0; i < numInts; i++) {
scratch[i] = docIds.apply(i) << 8;
}
for (int i = 0; i < quarter; i++) {
final int longIdx = i + numInts;
scratch[i] |= docIds.apply(longIdx) & 0xFF;
scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
}
for (int i = 0; i < numInts; i++) {
out.writeInt(scratch[i]);
}
for (int i = quarter << 2; i < count; ++i) {
out.writeShort((short) docIds.apply(i));
out.writeByte((byte) (docIds.apply(i) >>> 16));
}
}

private void write32(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
for (int i = 0; i < count; i++) {
out.writeInt(docIds.apply(i));
}
}

private static byte blockEncoding(IntToIntFunction docIds, int count) {
// docs can be sorted either when all docs in a block have the same value
// or when a segment is sorted
boolean strictlySorted = true;
int min = docIds.apply(0);
int max = min;
for (int i = 1; i < count; ++i) {
int last = docIds.apply(i - 1);
int current = docIds.apply(i);
if (last >= current) {
strictlySorted = false;
}
min = Math.min(min, current);
max = Math.max(max, current);
}

int min2max = max - min + 1;
if (strictlySorted && min2max == count) {
return CONTINUOUS_IDS;
}
if (min2max <= 0xFFFF) {
return DELTA_BPV_16;
} else {
if (max <= 0x1FFFFF) {
return BPV_21;
} else if (max <= 0xFFFFFF) {
return BPV_24;
} else {
return BPV_32;
}
}
}

void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
if (count == 0) {
return;
Expand All @@ -89,91 +256,35 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
if (strictlySorted && min2max == count) {
// continuous ids, typically happens when segment is sorted
out.writeByte(CONTINUOUS_IDS);
out.writeVInt(docIds.apply(0));
writeContinuousIds(docIds, count, out);
return;
}

if (min2max <= 0xFFFF) {
out.writeByte(DELTA_BPV_16);
for (int i = 0; i < count; i++) {
scratch[i] = docIds.apply(i) - min;
}
out.writeVInt(min);
final int halfLen = count >> 1;
for (int i = 0; i < halfLen; ++i) {
scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
}
for (int i = 0; i < halfLen; i++) {
out.writeInt(scratch[i]);
}
if ((count & 1) == 1) {
out.writeShort((short) scratch[count - 1]);
}
writeDelta16(docIds, count, min, out);
} else {
if (max <= 0x1FFFFF) {
out.writeByte(BPV_21);
final int oneThird = floorToMultipleOf16(count / 3);
final int numInts = oneThird * 2;
for (int i = 0; i < numInts; i++) {
scratch[i] = docIds.apply(i) << 11;
}
for (int i = 0; i < oneThird; i++) {
final int longIdx = i + numInts;
scratch[i] |= docIds.apply(longIdx) & 0x7FF;
scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
}
for (int i = 0; i < numInts; i++) {
out.writeInt(scratch[i]);
}
int i = oneThird * 3;
for (; i < count - 2; i += 3) {
out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
}
for (; i < count; ++i) {
out.writeShort((short) docIds.apply(i));
out.writeByte((byte) (docIds.apply(i) >>> 16));
}
write21(docIds, count, min, out);
} else if (max <= 0xFFFFFF) {
out.writeByte(BPV_24);

// encode the docs in the format that can be vectorized decoded.
final int quarter = count >> 2;
final int numInts = quarter * 3;
for (int i = 0; i < numInts; i++) {
scratch[i] = docIds.apply(i) << 8;
}
for (int i = 0; i < quarter; i++) {
final int longIdx = i + numInts;
scratch[i] |= docIds.apply(longIdx) & 0xFF;
scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
}
for (int i = 0; i < numInts; i++) {
out.writeInt(scratch[i]);
}
for (int i = quarter << 2; i < count; ++i) {
out.writeShort((short) docIds.apply(i));
out.writeByte((byte) (docIds.apply(i) >>> 16));
}
write24(docIds, count, min, out);
} else {
out.writeByte(BPV_32);
for (int i = 0; i < count; i++) {
out.writeInt(docIds.apply(i));
}
write32(docIds, count, min, out);
}
}
}

/** Read {@code count} integers into {@code docIDs}. */
void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOException {
if (count == 0) {
return;
}
if (count > scratch.length) {
scratch = new int[count];
}
final int bpv = in.readByte();
switch (bpv) {
switch (encoding) {
case CONTINUOUS_IDS:
readContinuousIds(in, count, docIDs);
break;
Expand All @@ -190,8 +301,20 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
readInts32(in, count, docIDs);
break;
default:
throw new IOException("Unsupported number of bits per value: " + bpv);
throw new IOException("Unsupported number of bits per value: " + encoding);
}
}

/** Read {@code count} integers into {@code docIDs}. */
void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
if (count == 0) {
return;
}
if (count > scratch.length) {
scratch = new int[count];
}
final int bpv = in.readByte();
readInts(in, count, (byte) bpv, docIDs);
}

private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ public void testContinuousIds() throws Exception {
}

private void test(Directory dir, int[] ints) throws Exception {
if (random().nextBoolean()) {
testSingleBlock(dir, ints);
} else {
testMultiBlock(dir, ints);
}
}

private void testSingleBlock(Directory dir, int[] ints) throws Exception {
final long len;
// It is hard to get BPV24-encoded docs in TextLuceneXXPointsFormat, test bwc here as well.
DocIdsWriter docIdsWriter = new DocIdsWriter();
Expand All @@ -143,6 +151,52 @@ private void test(Directory dir, int[] ints) throws Exception {
dir.deleteFile("tmp");
}

private void testMultiBlock(Directory dir, int[] ints) throws Exception {
final long len;
final int blockSize = 16 + random().nextInt(100);
DocIdsWriter docIdsWriter = new DocIdsWriter();
try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) {
byte encoding = docIdsWriter.calculateBlockEncoding(i -> ints[i], ints.length, blockSize);
out.writeByte(encoding);
int limit = ints.length - blockSize + 1;
int i = 0;
for (; i < limit; i += blockSize) {
int offset = i;
docIdsWriter.writeDocIds(d -> ints[d + offset], blockSize, encoding, out);
}
// handle tail
if (i < ints.length) {
int offset = i;
docIdsWriter.writeDocIds(d -> ints[d + offset], ints.length - i, encoding, out);
}
len = out.getFilePointer();
if (random().nextBoolean()) {
out.writeLong(0); // garbage
}
}
try (IndexInput in = dir.openInput("tmp", IOContext.READONCE)) {
int[] read = new int[ints.length];
int[] block = new int[blockSize];
int limit = ints.length - blockSize + 1;
byte encoding = in.readByte();
int i = 0;
for (; i < limit; i += blockSize) {
int offset = i;
docIdsWriter.readInts(in, blockSize, encoding, block);
System.arraycopy(block, 0, read, offset, blockSize);
}
// handle tail
if (i < ints.length) {
int offset = i;
docIdsWriter.readInts(in, ints.length - i, encoding, block);
System.arraycopy(block, 0, read, offset, ints.length - i);
}
assertArrayEquals(ints, read);
assertEquals(len, in.getFilePointer());
}
dir.deleteFile("tmp");
}

// This simple test tickles a JVM C2 JIT crash on JDK's less than 21.0.1
// Crashes only when run with HotSpot C2.
// Regardless of whether C2 is enabled or not, the test should never fail.
Expand Down