diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
index 257a1340eeff1..2dd4acc5056e2 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java
@@ -22,7 +22,6 @@
 import org.apache.lucene.store.DataOutput;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.IntsRef;
-import org.apache.lucene.util.LongsRef;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
 
 import java.io.IOException;
@@ -43,7 +42,6 @@ final class DocIdsWriter {
     private static final byte BPV_32 = (byte) 32;
 
     private int[] scratch = new int[0];
-    private final LongsRef scratchLongs = new LongsRef();
 
     /**
      * IntsRef to be used to iterate over the scratch buffer. A single instance is reused to avoid
@@ -63,6 +61,172 @@ final class DocIdsWriter {
 
     DocIdsWriter() {}
 
+    /**
+     * Calculate the single encoding that will be used to write every block of {@code blockSize} doc ids.
+     * The encoding is shared by all the blocks, so it must be able to represent the least compressible
+     * block: the id range and the maximum id are tracked across all blocks rather than merging per-block
+     * choices, because a block with a small range may still hold ids too large for a fixed-width encoding.
+     *
+     * @param docIds function to access the doc ids
+     * @param count number of doc ids
+     * @param blockSize the block size
+     * @return the byte encoding to use for the blocks
+     */
+    byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) {
+        if (count == 0) {
+            return CONTINUOUS_IDS;
+        }
+        boolean allContinuous = true;
+        int maxRange = 0;
+        int maxValue = 0;
+        for (int start = 0; start < count; start += blockSize) {
+            // the last block may be shorter than blockSize
+            final int length = Math.min(blockSize, count - start);
+            boolean strictlySorted = true;
+            int previous = docIds.apply(start);
+            int min = previous;
+            int max = previous;
+            for (int i = 1; i < length; ++i) {
+                final int current = docIds.apply(start + i);
+                if (previous >= current) {
+                    strictlySorted = false;
+                }
+                min = Math.min(min, current);
+                max = Math.max(max, current);
+                previous = current;
+            }
+            final int min2max = max - min + 1;
+            // continuous ids typically happen when the segment is sorted
+            allContinuous &= strictlySorted && min2max == length;
+            maxRange = Math.max(maxRange, min2max);
+            maxValue = Math.max(maxValue, max);
+        }
+        if (allContinuous) {
+            return CONTINUOUS_IDS;
+        }
+        if (maxRange <= 0xFFFF) {
+            return DELTA_BPV_16;
+        }
+        if (maxValue <= 0x1FFFFF) {
+            return BPV_21;
+        }
+        if (maxValue <= 0xFFFFFF) {
+            return BPV_24;
+        }
+        return BPV_32;
+    }
+
+    /**
+     * Write {@code count} doc ids using the given encoding. The encoding byte itself is not written,
+     * so the caller must persist it and pass it back when reading.
+     */
+    void writeDocIds(IntToIntFunction docIds, int count, byte encoding, DataOutput out) throws IOException {
+        if (count == 0) {
+            return;
+        }
+        if (count > scratch.length) {
+            scratch = new int[count];
+        }
+        int min = docIds.apply(0);
+        for (int i = 1; i < count; ++i) {
+            int current = docIds.apply(i);
+            min = Math.min(min, current);
+        }
+        switch (encoding) {
+            case CONTINUOUS_IDS:
+                writeContinuousIds(docIds, count, out);
+                break;
+            case DELTA_BPV_16:
+                writeDelta16(docIds, count, min, out);
+                break;
+            case BPV_21:
+                write21(docIds, count, min, out);
+                break;
+            case BPV_24:
+                write24(docIds, count, min, out);
+                break;
+            case BPV_32:
+                write32(docIds, count, min, out);
+                break;
+            default:
+                throw new IOException("Unsupported number of bits per value: " + encoding);
+        }
+    }
+
+    private static void writeContinuousIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
+        out.writeVInt(docIds.apply(0));
+    }
+
+    private void writeDelta16(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
+        for (int i = 0; i < count; i++) {
+            scratch[i] = docIds.apply(i) - min;
+        }
+        out.writeVInt(min);
+        final int halfLen = count >> 1;
+        for (int i = 0; i < halfLen; ++i) {
+            scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
+        }
+        for (int i = 0; i < halfLen; i++) {
+            out.writeInt(scratch[i]);
+        }
+        if ((count & 1) == 1) {
+            out.writeShort((short) scratch[count - 1]);
+        }
+    }
+
+    private void write21(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
+        final int oneThird = floorToMultipleOf16(count / 3);
+        final int numInts = oneThird * 2;
+        for (int i = 0; i < numInts; i++) {
+            scratch[i] = docIds.apply(i) << 11;
+        }
+        for (int i = 0; i < oneThird; i++) {
+            final int longIdx = i + numInts;
+            scratch[i] |= docIds.apply(longIdx) & 0x7FF;
+            scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
+        }
+        for (int i = 0; i < numInts; i++) {
+            out.writeInt(scratch[i]);
+        }
+        int i = oneThird * 3;
+        for (; i < count - 2; i += 3) {
+            out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
+        }
+        for (; i < count; ++i) {
+            out.writeShort((short) docIds.apply(i));
+            out.writeByte((byte) (docIds.apply(i) >>> 16));
+        }
+    }
+
+    private void write24(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
+
+        // encode the docs in a layout that can be decoded in a vectorized way.
+        final int quarter = count >> 2;
+        final int numInts = quarter * 3;
+        for (int i = 0; i < numInts; i++) {
+            scratch[i] = docIds.apply(i) << 8;
+        }
+        for (int i = 0; i < quarter; i++) {
+            final int longIdx = i + numInts;
+            scratch[i] |= docIds.apply(longIdx) & 0xFF;
+            scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
+            scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
+        }
+        for (int i = 0; i < numInts; i++) {
+            out.writeInt(scratch[i]);
+        }
+        for (int i = quarter << 2; i < count; ++i) {
+            out.writeShort((short) docIds.apply(i));
+            out.writeByte((byte) (docIds.apply(i) >>> 16));
+        }
+    }
+
+    private void write32(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
+        for (int i = 0; i < count; i++) {
+            out.writeInt(docIds.apply(i));
+        }
+    }
+
     void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
         if (count == 0) {
             return;
@@ -89,91 +253,39 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
         if (strictlySorted && min2max == count) {
             // continuous ids, typically happens when segment is sorted
             out.writeByte(CONTINUOUS_IDS);
-            out.writeVInt(docIds.apply(0));
+            writeContinuousIds(docIds, count, out);
             return;
         }
 
         if (min2max <= 0xFFFF) {
             out.writeByte(DELTA_BPV_16);
-            for (int i = 0; i < count; i++) {
-                scratch[i] = docIds.apply(i) - min;
-            }
-            out.writeVInt(min);
-            final int halfLen = count >> 1;
-            for (int i = 0; i < halfLen; ++i) {
-                scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
-            }
-            for (int i = 0; i < halfLen; i++) {
-                out.writeInt(scratch[i]);
-            }
-            if ((count & 1) == 1) {
-                out.writeShort((short) scratch[count - 1]);
-            }
+            writeDelta16(docIds, count, min, out);
         } else {
             if (max <= 0x1FFFFF) {
                 out.writeByte(BPV_21);
-                final int oneThird = floorToMultipleOf16(count / 3);
-                final int numInts = oneThird * 2;
-                for (int i = 0; i < numInts; i++) {
-                    scratch[i] = docIds.apply(i) << 11;
-                }
-                for (int i = 0; i < oneThird; i++) {
-                    final int longIdx = i + numInts;
-                    scratch[i] |= docIds.apply(longIdx) & 0x7FF;
-                    scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
-                }
-                for (int i = 0; i < numInts; i++) {
-                    out.writeInt(scratch[i]);
-                }
-                int i = oneThird * 3;
-                for (; i < count - 2; i += 3) {
-                    out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
-                }
-                for (; i < count; ++i) {
-                    out.writeShort((short) docIds.apply(i));
-                    out.writeByte((byte) (docIds.apply(i) >>> 16));
-                }
+                write21(docIds, count, min, out);
             } else if (max <= 0xFFFFFF) {
                 out.writeByte(BPV_24);
-
-                // encode the docs in the format that can be vectorized decoded.
-                final int quarter = count >> 2;
-                final int numInts = quarter * 3;
-                for (int i = 0; i < numInts; i++) {
-                    scratch[i] = docIds.apply(i) << 8;
-                }
-                for (int i = 0; i < quarter; i++) {
-                    final int longIdx = i + numInts;
-                    scratch[i] |= docIds.apply(longIdx) & 0xFF;
-                    scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
-                    scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
-                }
-                for (int i = 0; i < numInts; i++) {
-                    out.writeInt(scratch[i]);
-                }
-                for (int i = quarter << 2; i < count; ++i) {
-                    out.writeShort((short) docIds.apply(i));
-                    out.writeByte((byte) (docIds.apply(i) >>> 16));
-                }
+                write24(docIds, count, min, out);
             } else {
                 out.writeByte(BPV_32);
-                for (int i = 0; i < count; i++) {
-                    out.writeInt(docIds.apply(i));
-                }
+                write32(docIds, count, min, out);
             }
         }
     }
 
-    /** Read {@code count} integers into {@code docIDs}. */
-    void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
+    /**
+     * Read {@code count} integers into {@code docIDs} using the given encoding, which must be the one
+     * the ids were written with. No encoding byte is read from {@code in}.
+     */
+    void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOException {
         if (count == 0) {
             return;
         }
         if (count > scratch.length) {
             scratch = new int[count];
         }
-        final int bpv = in.readByte();
-        switch (bpv) {
+        switch (encoding) {
             case CONTINUOUS_IDS:
                 readContinuousIds(in, count, docIDs);
                 break;
@@ -190,8 +302,20 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
                 readInts32(in, count, docIDs);
                 break;
             default:
-                throw new IOException("Unsupported number of bits per value: " + bpv);
+                throw new IOException("Unsupported number of bits per value: " + encoding);
+        }
+    }
+
+    /** Read {@code count} integers into {@code docIDs}. */
+    void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
+        if (count == 0) {
+            return;
         }
+        if (count > scratch.length) {
+            scratch = new int[count];
+        }
+        final int bpv = in.readByte();
+        readInts(in, count, (byte) bpv, docIDs);
     }
 
     private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java
index 8f26369e6ded4..9235823777b01 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java
@@ -124,6 +124,14 @@ public void testContinuousIds() throws Exception {
     }
 
     private void test(Directory dir, int[] ints) throws Exception {
+        if (random().nextBoolean()) {
+            testSingleBlock(dir, ints);
+        } else {
+            testMultiBlock(dir, ints);
+        }
+    }
+
+    private void testSingleBlock(Directory dir, int[] ints) throws Exception {
         final long len;
         // It is hard to get BPV24-encoded docs in TextLuceneXXPointsFormat, test bwc here as well.
         DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -143,6 +151,66 @@ private void test(Directory dir, int[] ints) throws Exception {
         dir.deleteFile("tmp");
     }
 
+    private void testMultiBlock(Directory dir, int[] ints) throws Exception {
+        final long len;
+        final int blockSize = 16 + random().nextInt(100);
+        DocIdsWriter docIdsWriter = new DocIdsWriter();
+        try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) {
+            byte encoding = docIdsWriter.calculateBlockEncoding(i -> ints[i], ints.length, blockSize);
+            out.writeByte(encoding);
+            int limit = ints.length - blockSize + 1;
+            int i = 0;
+            for (; i < limit; i += blockSize) {
+                int offset = i;
+                docIdsWriter.writeDocIds(d -> ints[d + offset], blockSize, encoding, out);
+            }
+            // handle tail
+            if (i < ints.length) {
+                int offset = i;
+                docIdsWriter.writeDocIds(d -> ints[d + offset], ints.length - i, encoding, out);
+            }
+            len = out.getFilePointer();
+            if (random().nextBoolean()) {
+                out.writeLong(0); // garbage
+            }
+        }
+        try (IndexInput in = dir.openInput("tmp", IOContext.READONCE)) {
+            int[] read = new int[ints.length];
+            int[] block = new int[blockSize];
+            int limit = ints.length - blockSize + 1;
+            byte encoding = in.readByte();
+            int i = 0;
+            for (; i < limit; i += blockSize) {
+                int offset = i;
+                docIdsWriter.readInts(in, blockSize, encoding, block);
+                System.arraycopy(block, 0, read, offset, blockSize);
+            }
+            // handle tail
+            if (i < ints.length) {
+                int offset = i;
+                docIdsWriter.readInts(in, ints.length - i, encoding, block);
+                System.arraycopy(block, 0, read, offset, ints.length - i);
+            }
+            assertArrayEquals(ints, read);
+            assertEquals(len, in.getFilePointer());
+        }
+        dir.deleteFile("tmp");
+    }
+
+    public void testBlockEncodingFitsEveryBlock() {
+        final int blockSize = 16;
+        final int[] ints = new int[2 * blockSize];
+        for (int i = 0; i < blockSize; i++) {
+            // first block: tiny id range, but the ids themselves need more than 24 bits
+            ints[i] = (1 << 24) + 2 * i;
+            // second block: small ids spread over a range wider than 16 bits
+            ints[blockSize + i] = i << 13;
+        }
+        byte encoding = new DocIdsWriter().calculateBlockEncoding(i -> ints[i], ints.length, blockSize);
+        // only 32 bits per value can represent the ids of both blocks
+        assertEquals(32, encoding);
+    }
+
     // This simple test tickles a JVM C2 JIT crash on JDK's less than 21.0.1
     // Crashes only when run with HotSpot C2.
     // Regardless of whether C2 is enabled or not, the test should never fail.