Skip to content

Commit 9aee716

Browse files
committed
Refactor doc-id encoding for DiskBBQ to allow doc ids to be encoded in blocks
1 parent e785661 commit 9aee716

File tree

2 files changed

+243
-66
lines changed

2 files changed

+243
-66
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java

Lines changed: 189 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.apache.lucene.store.DataOutput;
2323
import org.apache.lucene.store.IndexInput;
2424
import org.apache.lucene.util.IntsRef;
25-
import org.apache.lucene.util.LongsRef;
2625
import org.apache.lucene.util.hnsw.IntToIntFunction;
2726

2827
import java.io.IOException;
@@ -43,7 +42,6 @@ final class DocIdsWriter {
4342
private static final byte BPV_32 = (byte) 32;
4443

4544
private int[] scratch = new int[0];
46-
private final LongsRef scratchLongs = new LongsRef();
4745

4846
/**
4947
* IntsRef to be used to iterate over the scratch buffer. A single instance is reused to avoid
@@ -63,6 +61,175 @@ final class DocIdsWriter {
6361

6462
DocIdsWriter() {}
6563

64+
/**
65+
* Calculate the best encoding that will be used to write blocks of doc ids of blockSize.
66+
* The encoding choice is universal for all the blocks, which means that the encoding is only as
67+
* efficient as the worst block.
68+
* @param docIds function to access the doc ids
69+
* @param count number of doc ids
70+
* @param blockSize the block size
71+
* @return the byte encoding to use for the blocks
72+
*/
73+
byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) {
74+
if (count == 0) {
75+
return CONTINUOUS_IDS;
76+
}
77+
byte encoding = CONTINUOUS_IDS;
78+
int iterationLimit = count - blockSize + 1;
79+
int i = 0;
80+
for (; i < iterationLimit; i += blockSize) {
81+
int offset = i;
82+
encoding = (byte) Math.max(encoding, blockEncoding(d -> docIds.apply(offset + d), blockSize));
83+
}
84+
// check the tail
85+
if (i == count) {
86+
return encoding;
87+
}
88+
int offset = i;
89+
encoding = (byte) Math.max(encoding, blockEncoding(d -> docIds.apply(offset + d), count - i));
90+
return encoding;
91+
}
92+
93+
void writeDocIds(IntToIntFunction docIds, int count, byte encoding, DataOutput out) throws IOException {
94+
if (count == 0) {
95+
return;
96+
}
97+
if (count > scratch.length) {
98+
scratch = new int[count];
99+
}
100+
int min = docIds.apply(0);
101+
for (int i = 1; i < count; ++i) {
102+
int current = docIds.apply(i);
103+
min = Math.min(min, current);
104+
}
105+
switch (encoding) {
106+
case CONTINUOUS_IDS:
107+
writeContinuousIds(docIds, count, out);
108+
break;
109+
case DELTA_BPV_16:
110+
writeDelta16(docIds, count, min, out);
111+
break;
112+
case BPV_21:
113+
write21(docIds, count, min, out);
114+
break;
115+
case BPV_24:
116+
write24(docIds, count, min, out);
117+
break;
118+
case BPV_32:
119+
write32(docIds, count, min, out);
120+
break;
121+
default:
122+
throw new IOException("Unsupported number of bits per value: " + encoding);
123+
}
124+
}
125+
126+
private static void writeContinuousIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
127+
out.writeVInt(docIds.apply(0));
128+
}
129+
130+
private void writeDelta16(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
131+
for (int i = 0; i < count; i++) {
132+
scratch[i] = docIds.apply(i) - min;
133+
}
134+
out.writeVInt(min);
135+
final int halfLen = count >> 1;
136+
for (int i = 0; i < halfLen; ++i) {
137+
scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
138+
}
139+
for (int i = 0; i < halfLen; i++) {
140+
out.writeInt(scratch[i]);
141+
}
142+
if ((count & 1) == 1) {
143+
out.writeShort((short) scratch[count - 1]);
144+
}
145+
}
146+
147+
private void write21(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
148+
final int oneThird = floorToMultipleOf16(count / 3);
149+
final int numInts = oneThird * 2;
150+
for (int i = 0; i < numInts; i++) {
151+
scratch[i] = docIds.apply(i) << 11;
152+
}
153+
for (int i = 0; i < oneThird; i++) {
154+
final int longIdx = i + numInts;
155+
scratch[i] |= docIds.apply(longIdx) & 0x7FF;
156+
scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
157+
}
158+
for (int i = 0; i < numInts; i++) {
159+
out.writeInt(scratch[i]);
160+
}
161+
int i = oneThird * 3;
162+
for (; i < count - 2; i += 3) {
163+
out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
164+
}
165+
for (; i < count; ++i) {
166+
out.writeShort((short) docIds.apply(i));
167+
out.writeByte((byte) (docIds.apply(i) >>> 16));
168+
}
169+
}
170+
171+
private void write24(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
172+
173+
// encode the docs in the format that can be vectorized decoded.
174+
final int quarter = count >> 2;
175+
final int numInts = quarter * 3;
176+
for (int i = 0; i < numInts; i++) {
177+
scratch[i] = docIds.apply(i) << 8;
178+
}
179+
for (int i = 0; i < quarter; i++) {
180+
final int longIdx = i + numInts;
181+
scratch[i] |= docIds.apply(longIdx) & 0xFF;
182+
scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
183+
scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
184+
}
185+
for (int i = 0; i < numInts; i++) {
186+
out.writeInt(scratch[i]);
187+
}
188+
for (int i = quarter << 2; i < count; ++i) {
189+
out.writeShort((short) docIds.apply(i));
190+
out.writeByte((byte) (docIds.apply(i) >>> 16));
191+
}
192+
}
193+
194+
private void write32(IntToIntFunction docIds, int count, int min, DataOutput out) throws IOException {
195+
for (int i = 0; i < count; i++) {
196+
out.writeInt(docIds.apply(i));
197+
}
198+
}
199+
200+
private static byte blockEncoding(IntToIntFunction docIds, int count) {
201+
// docs can be sorted either when all docs in a block have the same value
202+
// or when a segment is sorted
203+
boolean strictlySorted = true;
204+
int min = docIds.apply(0);
205+
int max = min;
206+
for (int i = 1; i < count; ++i) {
207+
int last = docIds.apply(i - 1);
208+
int current = docIds.apply(i);
209+
if (last >= current) {
210+
strictlySorted = false;
211+
}
212+
min = Math.min(min, current);
213+
max = Math.max(max, current);
214+
}
215+
216+
int min2max = max - min + 1;
217+
if (strictlySorted && min2max == count) {
218+
return CONTINUOUS_IDS;
219+
}
220+
if (min2max <= 0xFFFF) {
221+
return DELTA_BPV_16;
222+
} else {
223+
if (max <= 0x1FFFFF) {
224+
return BPV_21;
225+
} else if (max <= 0xFFFFFF) {
226+
return BPV_24;
227+
} else {
228+
return BPV_32;
229+
}
230+
}
231+
}
232+
66233
void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
67234
if (count == 0) {
68235
return;
@@ -89,91 +256,35 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
89256
if (strictlySorted && min2max == count) {
90257
// continuous ids, typically happens when segment is sorted
91258
out.writeByte(CONTINUOUS_IDS);
92-
out.writeVInt(docIds.apply(0));
259+
writeContinuousIds(docIds, count, out);
93260
return;
94261
}
95262

96263
if (min2max <= 0xFFFF) {
97264
out.writeByte(DELTA_BPV_16);
98-
for (int i = 0; i < count; i++) {
99-
scratch[i] = docIds.apply(i) - min;
100-
}
101-
out.writeVInt(min);
102-
final int halfLen = count >> 1;
103-
for (int i = 0; i < halfLen; ++i) {
104-
scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
105-
}
106-
for (int i = 0; i < halfLen; i++) {
107-
out.writeInt(scratch[i]);
108-
}
109-
if ((count & 1) == 1) {
110-
out.writeShort((short) scratch[count - 1]);
111-
}
265+
writeDelta16(docIds, count, min, out);
112266
} else {
113267
if (max <= 0x1FFFFF) {
114268
out.writeByte(BPV_21);
115-
final int oneThird = floorToMultipleOf16(count / 3);
116-
final int numInts = oneThird * 2;
117-
for (int i = 0; i < numInts; i++) {
118-
scratch[i] = docIds.apply(i) << 11;
119-
}
120-
for (int i = 0; i < oneThird; i++) {
121-
final int longIdx = i + numInts;
122-
scratch[i] |= docIds.apply(longIdx) & 0x7FF;
123-
scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
124-
}
125-
for (int i = 0; i < numInts; i++) {
126-
out.writeInt(scratch[i]);
127-
}
128-
int i = oneThird * 3;
129-
for (; i < count - 2; i += 3) {
130-
out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
131-
}
132-
for (; i < count; ++i) {
133-
out.writeShort((short) docIds.apply(i));
134-
out.writeByte((byte) (docIds.apply(i) >>> 16));
135-
}
269+
write21(docIds, count, min, out);
136270
} else if (max <= 0xFFFFFF) {
137271
out.writeByte(BPV_24);
138-
139-
// encode the docs in the format that can be vectorized decoded.
140-
final int quarter = count >> 2;
141-
final int numInts = quarter * 3;
142-
for (int i = 0; i < numInts; i++) {
143-
scratch[i] = docIds.apply(i) << 8;
144-
}
145-
for (int i = 0; i < quarter; i++) {
146-
final int longIdx = i + numInts;
147-
scratch[i] |= docIds.apply(longIdx) & 0xFF;
148-
scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF;
149-
scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16;
150-
}
151-
for (int i = 0; i < numInts; i++) {
152-
out.writeInt(scratch[i]);
153-
}
154-
for (int i = quarter << 2; i < count; ++i) {
155-
out.writeShort((short) docIds.apply(i));
156-
out.writeByte((byte) (docIds.apply(i) >>> 16));
157-
}
272+
write24(docIds, count, min, out);
158273
} else {
159274
out.writeByte(BPV_32);
160-
for (int i = 0; i < count; i++) {
161-
out.writeInt(docIds.apply(i));
162-
}
275+
write32(docIds, count, min, out);
163276
}
164277
}
165278
}
166279

167-
/** Read {@code count} integers into {@code docIDs}. */
168-
void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
280+
void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOException {
169281
if (count == 0) {
170282
return;
171283
}
172284
if (count > scratch.length) {
173285
scratch = new int[count];
174286
}
175-
final int bpv = in.readByte();
176-
switch (bpv) {
287+
switch (encoding) {
177288
case CONTINUOUS_IDS:
178289
readContinuousIds(in, count, docIDs);
179290
break;
@@ -190,8 +301,20 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
190301
readInts32(in, count, docIDs);
191302
break;
192303
default:
193-
throw new IOException("Unsupported number of bits per value: " + bpv);
304+
throw new IOException("Unsupported number of bits per value: " + encoding);
305+
}
306+
}
307+
308+
/** Read {@code count} integers into {@code docIDs}. */
309+
void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
310+
if (count == 0) {
311+
return;
194312
}
313+
if (count > scratch.length) {
314+
scratch = new int[count];
315+
}
316+
final int bpv = in.readByte();
317+
readInts(in, count, (byte) bpv, docIDs);
195318
}
196319

197320
private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException {

server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ public void testContinuousIds() throws Exception {
124124
}
125125

126126
private void test(Directory dir, int[] ints) throws Exception {
127+
if (random().nextBoolean()) {
128+
testSingleBlock(dir, ints);
129+
} else {
130+
testMultiBlock(dir, ints);
131+
}
132+
}
133+
134+
private void testSingleBlock(Directory dir, int[] ints) throws Exception {
127135
final long len;
128136
// It is hard to get BPV24-encoded docs in TextLuceneXXPointsFormat, test bwc here as well.
129137
DocIdsWriter docIdsWriter = new DocIdsWriter();
@@ -143,6 +151,52 @@ private void test(Directory dir, int[] ints) throws Exception {
143151
dir.deleteFile("tmp");
144152
}
145153

154+
private void testMultiBlock(Directory dir, int[] ints) throws Exception {
155+
final long len;
156+
final int blockSize = 16 + random().nextInt(100);
157+
DocIdsWriter docIdsWriter = new DocIdsWriter();
158+
try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) {
159+
byte encoding = docIdsWriter.calculateBlockEncoding(i -> ints[i], ints.length, blockSize);
160+
out.writeByte(encoding);
161+
int limit = ints.length - blockSize + 1;
162+
int i = 0;
163+
for (; i < limit; i += blockSize) {
164+
int offset = i;
165+
docIdsWriter.writeDocIds(d -> ints[d + offset], blockSize, encoding, out);
166+
}
167+
// handle tail
168+
if (i < ints.length) {
169+
int offset = i;
170+
docIdsWriter.writeDocIds(d -> ints[d + offset], ints.length - i, encoding, out);
171+
}
172+
len = out.getFilePointer();
173+
if (random().nextBoolean()) {
174+
out.writeLong(0); // garbage
175+
}
176+
}
177+
try (IndexInput in = dir.openInput("tmp", IOContext.READONCE)) {
178+
int[] read = new int[ints.length];
179+
int[] block = new int[blockSize];
180+
int limit = ints.length - blockSize + 1;
181+
byte encoding = in.readByte();
182+
int i = 0;
183+
for (; i < limit; i += blockSize) {
184+
int offset = i;
185+
docIdsWriter.readInts(in, blockSize, encoding, block);
186+
System.arraycopy(block, 0, read, offset, blockSize);
187+
}
188+
// handle tail
189+
if (i < ints.length) {
190+
int offset = i;
191+
docIdsWriter.readInts(in, ints.length - i, encoding, block);
192+
System.arraycopy(block, 0, read, offset, ints.length - i);
193+
}
194+
assertArrayEquals(ints, read);
195+
assertEquals(len, in.getFilePointer());
196+
}
197+
dir.deleteFile("tmp");
198+
}
199+
146200
// This simple test tickles a JVM C2 JIT crash on JDK's less than 21.0.1
147201
// Crashes only when run with HotSpot C2.
148202
// Regardless of whether C2 is enabled or not, the test should never fail.

0 commit comments

Comments
 (0)