Skip to content

Commit 4019ff1

Browse files
committed
Fix dibit striping for quantized vector formats (#15578)
* Fix dibit striping for quantized vector formats * tidy
1 parent 9f0ddba commit 4019ff1

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,28 @@ public enum ScalarEncoding {
145145
* between the compression of {@link #SINGLE_BIT_QUERY_NIBBLE} and the accuracy of {@link
146146
* #PACKED_NIBBLE}.
147147
*/
148-
DIBIT_QUERY_NIBBLE(4, (byte) 2, 2, (byte) 4, 4);
148+
DIBIT_QUERY_NIBBLE(4, (byte) 2, 2, (byte) 4, 4) {
149+
@Override
150+
public int getDiscreteDimensions(int dimensions) {
151+
int queryDiscretized = (dimensions * 4 + 7) / 8 * 8 / 4;
152+
// we want to force dibit packing to byte boundaries assuming single bit striping
153+
// so we discretize to the same as single bit encoding
154+
int docDiscretized = (dimensions + 7) / 8 * 8;
155+
int maxDiscretized = Math.max(queryDiscretized, docDiscretized);
156+
assert maxDiscretized % (8.0 / 4) == 0
157+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
158+
assert maxDiscretized % (8.0 / 2) == 0
159+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
160+
return maxDiscretized;
161+
}
162+
163+
@Override
164+
public int getDocPackedLength(int dimensions) {
165+
int discretized = getDiscreteDimensions(dimensions);
166+
// DIBIT should be stored as two single bits striped
167+
return 2 * ((discretized + 7) / 8);
168+
}
169+
};
149170

150171
public static ScalarEncoding fromNumBits(int bits) {
151172
for (ScalarEncoding encoding : values()) {

lucene/core/src/test/org/apache/lucene/util/quantization/TestOptimizedScalarQuantizer.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,20 @@ public void testUnpackBinary() {
194194
assertArrayEquals(scratch, unpacked);
195195
}
196196

197+
public void testPackTransposeDibit() {
198+
int dim = randomIntBetween(1, 4096);
199+
ScalarEncoding encoding = ScalarEncoding.DIBIT_QUERY_NIBBLE;
200+
byte[] scratch = new byte[encoding.getDiscreteDimensions(dim)];
201+
for (int i = 0; i < scratch.length; i++) {
202+
scratch[i] = (byte) randomIntBetween(0, 3);
203+
}
204+
byte[] packed = new byte[encoding.getDocPackedLength(scratch.length)];
205+
byte[] unpacked = new byte[scratch.length];
206+
OptimizedScalarQuantizer.transposeDibit(scratch, packed);
207+
OptimizedScalarQuantizer.untransposeDibit(packed, unpacked);
208+
assertArrayEquals(scratch, unpacked);
209+
}
210+
197211
static void assertValidQuantizedRange(byte[] quantized, byte bits) {
198212
for (byte b : quantized) {
199213
if (bits < 8) {

0 commit comments

Comments
 (0)