2222import org .apache .lucene .store .DataOutput ;
2323import org .apache .lucene .store .IndexInput ;
2424import org .apache .lucene .util .IntsRef ;
25- import org .apache .lucene .util .LongsRef ;
2625import org .apache .lucene .util .hnsw .IntToIntFunction ;
2726
2827import java .io .IOException ;
@@ -43,7 +42,6 @@ final class DocIdsWriter {
4342 private static final byte BPV_32 = (byte ) 32 ;
4443
4544 private int [] scratch = new int [0 ];
46- private final LongsRef scratchLongs = new LongsRef ();
4745
4846 /**
4947 * IntsRef to be used to iterate over the scratch buffer. A single instance is reused to avoid
@@ -63,6 +61,175 @@ final class DocIdsWriter {
6361
6462 DocIdsWriter () {}
6563
64+ /**
65+ * Calculate the best encoding that will be used to write blocks of doc ids of blockSize.
66+ * The encoding choice is universal for all the blocks, which means that the encoding is only as
67+ * efficient as the worst block.
68+ * @param docIds function to access the doc ids
69+ * @param count number of doc ids
70+ * @param blockSize the block size
71+ * @return the byte encoding to use for the blocks
72+ */
73+ byte calculateBlockEncoding (IntToIntFunction docIds , int count , int blockSize ) {
74+ if (count == 0 ) {
75+ return CONTINUOUS_IDS ;
76+ }
77+ byte encoding = CONTINUOUS_IDS ;
78+ int iterationLimit = count - blockSize + 1 ;
79+ int i = 0 ;
80+ for (; i < iterationLimit ; i += blockSize ) {
81+ int offset = i ;
82+ encoding = (byte ) Math .max (encoding , blockEncoding (d -> docIds .apply (offset + d ), blockSize ));
83+ }
84+ // check the tail
85+ if (i == count ) {
86+ return encoding ;
87+ }
88+ int offset = i ;
89+ encoding = (byte ) Math .max (encoding , blockEncoding (d -> docIds .apply (offset + d ), count - i ));
90+ return encoding ;
91+ }
92+
93+ void writeDocIds (IntToIntFunction docIds , int count , byte encoding , DataOutput out ) throws IOException {
94+ if (count == 0 ) {
95+ return ;
96+ }
97+ if (count > scratch .length ) {
98+ scratch = new int [count ];
99+ }
100+ int min = docIds .apply (0 );
101+ for (int i = 1 ; i < count ; ++i ) {
102+ int current = docIds .apply (i );
103+ min = Math .min (min , current );
104+ }
105+ switch (encoding ) {
106+ case CONTINUOUS_IDS :
107+ writeContinuousIds (docIds , count , out );
108+ break ;
109+ case DELTA_BPV_16 :
110+ writeDelta16 (docIds , count , min , out );
111+ break ;
112+ case BPV_21 :
113+ write21 (docIds , count , min , out );
114+ break ;
115+ case BPV_24 :
116+ write24 (docIds , count , min , out );
117+ break ;
118+ case BPV_32 :
119+ write32 (docIds , count , min , out );
120+ break ;
121+ default :
122+ throw new IOException ("Unsupported number of bits per value: " + encoding );
123+ }
124+ }
125+
126+ private static void writeContinuousIds (IntToIntFunction docIds , int count , DataOutput out ) throws IOException {
127+ out .writeVInt (docIds .apply (0 ));
128+ }
129+
130+ private void writeDelta16 (IntToIntFunction docIds , int count , int min , DataOutput out ) throws IOException {
131+ for (int i = 0 ; i < count ; i ++) {
132+ scratch [i ] = docIds .apply (i ) - min ;
133+ }
134+ out .writeVInt (min );
135+ final int halfLen = count >> 1 ;
136+ for (int i = 0 ; i < halfLen ; ++i ) {
137+ scratch [i ] = scratch [halfLen + i ] | (scratch [i ] << 16 );
138+ }
139+ for (int i = 0 ; i < halfLen ; i ++) {
140+ out .writeInt (scratch [i ]);
141+ }
142+ if ((count & 1 ) == 1 ) {
143+ out .writeShort ((short ) scratch [count - 1 ]);
144+ }
145+ }
146+
147+ private void write21 (IntToIntFunction docIds , int count , int min , DataOutput out ) throws IOException {
148+ final int oneThird = floorToMultipleOf16 (count / 3 );
149+ final int numInts = oneThird * 2 ;
150+ for (int i = 0 ; i < numInts ; i ++) {
151+ scratch [i ] = docIds .apply (i ) << 11 ;
152+ }
153+ for (int i = 0 ; i < oneThird ; i ++) {
154+ final int longIdx = i + numInts ;
155+ scratch [i ] |= docIds .apply (longIdx ) & 0x7FF ;
156+ scratch [i + oneThird ] |= (docIds .apply (longIdx ) >>> 11 ) & 0x7FF ;
157+ }
158+ for (int i = 0 ; i < numInts ; i ++) {
159+ out .writeInt (scratch [i ]);
160+ }
161+ int i = oneThird * 3 ;
162+ for (; i < count - 2 ; i += 3 ) {
163+ out .writeLong (((long ) docIds .apply (i )) | (((long ) docIds .apply (i + 1 )) << 21 ) | (((long ) docIds .apply (i + 2 )) << 42 ));
164+ }
165+ for (; i < count ; ++i ) {
166+ out .writeShort ((short ) docIds .apply (i ));
167+ out .writeByte ((byte ) (docIds .apply (i ) >>> 16 ));
168+ }
169+ }
170+
171+ private void write24 (IntToIntFunction docIds , int count , int min , DataOutput out ) throws IOException {
172+
173+ // encode the docs in the format that can be vectorized decoded.
174+ final int quarter = count >> 2 ;
175+ final int numInts = quarter * 3 ;
176+ for (int i = 0 ; i < numInts ; i ++) {
177+ scratch [i ] = docIds .apply (i ) << 8 ;
178+ }
179+ for (int i = 0 ; i < quarter ; i ++) {
180+ final int longIdx = i + numInts ;
181+ scratch [i ] |= docIds .apply (longIdx ) & 0xFF ;
182+ scratch [i + quarter ] |= (docIds .apply (longIdx ) >>> 8 ) & 0xFF ;
183+ scratch [i + quarter * 2 ] |= docIds .apply (longIdx ) >>> 16 ;
184+ }
185+ for (int i = 0 ; i < numInts ; i ++) {
186+ out .writeInt (scratch [i ]);
187+ }
188+ for (int i = quarter << 2 ; i < count ; ++i ) {
189+ out .writeShort ((short ) docIds .apply (i ));
190+ out .writeByte ((byte ) (docIds .apply (i ) >>> 16 ));
191+ }
192+ }
193+
194+ private void write32 (IntToIntFunction docIds , int count , int min , DataOutput out ) throws IOException {
195+ for (int i = 0 ; i < count ; i ++) {
196+ out .writeInt (docIds .apply (i ));
197+ }
198+ }
199+
200+ private static byte blockEncoding (IntToIntFunction docIds , int count ) {
201+ // docs can be sorted either when all docs in a block have the same value
202+ // or when a segment is sorted
203+ boolean strictlySorted = true ;
204+ int min = docIds .apply (0 );
205+ int max = min ;
206+ for (int i = 1 ; i < count ; ++i ) {
207+ int last = docIds .apply (i - 1 );
208+ int current = docIds .apply (i );
209+ if (last >= current ) {
210+ strictlySorted = false ;
211+ }
212+ min = Math .min (min , current );
213+ max = Math .max (max , current );
214+ }
215+
216+ int min2max = max - min + 1 ;
217+ if (strictlySorted && min2max == count ) {
218+ return CONTINUOUS_IDS ;
219+ }
220+ if (min2max <= 0xFFFF ) {
221+ return DELTA_BPV_16 ;
222+ } else {
223+ if (max <= 0x1FFFFF ) {
224+ return BPV_21 ;
225+ } else if (max <= 0xFFFFFF ) {
226+ return BPV_24 ;
227+ } else {
228+ return BPV_32 ;
229+ }
230+ }
231+ }
232+
66233 void writeDocIds (IntToIntFunction docIds , int count , DataOutput out ) throws IOException {
67234 if (count == 0 ) {
68235 return ;
@@ -89,91 +256,35 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx
89256 if (strictlySorted && min2max == count ) {
90257 // continuous ids, typically happens when segment is sorted
91258 out .writeByte (CONTINUOUS_IDS );
92- out . writeVInt (docIds . apply ( 0 ) );
259+ writeContinuousIds (docIds , count , out );
93260 return ;
94261 }
95262
96263 if (min2max <= 0xFFFF ) {
97264 out .writeByte (DELTA_BPV_16 );
98- for (int i = 0 ; i < count ; i ++) {
99- scratch [i ] = docIds .apply (i ) - min ;
100- }
101- out .writeVInt (min );
102- final int halfLen = count >> 1 ;
103- for (int i = 0 ; i < halfLen ; ++i ) {
104- scratch [i ] = scratch [halfLen + i ] | (scratch [i ] << 16 );
105- }
106- for (int i = 0 ; i < halfLen ; i ++) {
107- out .writeInt (scratch [i ]);
108- }
109- if ((count & 1 ) == 1 ) {
110- out .writeShort ((short ) scratch [count - 1 ]);
111- }
265+ writeDelta16 (docIds , count , min , out );
112266 } else {
113267 if (max <= 0x1FFFFF ) {
114268 out .writeByte (BPV_21 );
115- final int oneThird = floorToMultipleOf16 (count / 3 );
116- final int numInts = oneThird * 2 ;
117- for (int i = 0 ; i < numInts ; i ++) {
118- scratch [i ] = docIds .apply (i ) << 11 ;
119- }
120- for (int i = 0 ; i < oneThird ; i ++) {
121- final int longIdx = i + numInts ;
122- scratch [i ] |= docIds .apply (longIdx ) & 0x7FF ;
123- scratch [i + oneThird ] |= (docIds .apply (longIdx ) >>> 11 ) & 0x7FF ;
124- }
125- for (int i = 0 ; i < numInts ; i ++) {
126- out .writeInt (scratch [i ]);
127- }
128- int i = oneThird * 3 ;
129- for (; i < count - 2 ; i += 3 ) {
130- out .writeLong (((long ) docIds .apply (i )) | (((long ) docIds .apply (i + 1 )) << 21 ) | (((long ) docIds .apply (i + 2 )) << 42 ));
131- }
132- for (; i < count ; ++i ) {
133- out .writeShort ((short ) docIds .apply (i ));
134- out .writeByte ((byte ) (docIds .apply (i ) >>> 16 ));
135- }
269+ write21 (docIds , count , min , out );
136270 } else if (max <= 0xFFFFFF ) {
137271 out .writeByte (BPV_24 );
138-
139- // encode the docs in the format that can be vectorized decoded.
140- final int quarter = count >> 2 ;
141- final int numInts = quarter * 3 ;
142- for (int i = 0 ; i < numInts ; i ++) {
143- scratch [i ] = docIds .apply (i ) << 8 ;
144- }
145- for (int i = 0 ; i < quarter ; i ++) {
146- final int longIdx = i + numInts ;
147- scratch [i ] |= docIds .apply (longIdx ) & 0xFF ;
148- scratch [i + quarter ] |= (docIds .apply (longIdx ) >>> 8 ) & 0xFF ;
149- scratch [i + quarter * 2 ] |= docIds .apply (longIdx ) >>> 16 ;
150- }
151- for (int i = 0 ; i < numInts ; i ++) {
152- out .writeInt (scratch [i ]);
153- }
154- for (int i = quarter << 2 ; i < count ; ++i ) {
155- out .writeShort ((short ) docIds .apply (i ));
156- out .writeByte ((byte ) (docIds .apply (i ) >>> 16 ));
157- }
272+ write24 (docIds , count , min , out );
158273 } else {
159274 out .writeByte (BPV_32 );
160- for (int i = 0 ; i < count ; i ++) {
161- out .writeInt (docIds .apply (i ));
162- }
275+ write32 (docIds , count , min , out );
163276 }
164277 }
165278 }
166279
167- /** Read {@code count} integers into {@code docIDs}. */
168- void readInts (IndexInput in , int count , int [] docIDs ) throws IOException {
280+ void readInts (IndexInput in , int count , byte encoding , int [] docIDs ) throws IOException {
169281 if (count == 0 ) {
170282 return ;
171283 }
172284 if (count > scratch .length ) {
173285 scratch = new int [count ];
174286 }
175- final int bpv = in .readByte ();
176- switch (bpv ) {
287+ switch (encoding ) {
177288 case CONTINUOUS_IDS :
178289 readContinuousIds (in , count , docIDs );
179290 break ;
@@ -190,8 +301,20 @@ void readInts(IndexInput in, int count, int[] docIDs) throws IOException {
190301 readInts32 (in , count , docIDs );
191302 break ;
192303 default :
193- throw new IOException ("Unsupported number of bits per value: " + bpv );
304+ throw new IOException ("Unsupported number of bits per value: " + encoding );
305+ }
306+ }
307+
308+ /** Read {@code count} integers into {@code docIDs}. */
309+ void readInts (IndexInput in , int count , int [] docIDs ) throws IOException {
310+ if (count == 0 ) {
311+ return ;
194312 }
313+ if (count > scratch .length ) {
314+ scratch = new int [count ];
315+ }
316+ final int bpv = in .readByte ();
317+ readInts (in , count , (byte ) bpv , docIDs );
195318 }
196319
197320 private static void readContinuousIds (IndexInput in , int count , int [] docIDs ) throws IOException {
0 commit comments