Skip to content

Commit 602bfbd

Browse files
authored
Extend Lucene104ScalarQuantization format to support quant asymmetry (#15271)
* Extend Lucene104ScalarQuantization format to support quant asymmetry * fixing discrete dimension calculations * iter removing debug annotation * iter * iter * fixing bugs * addressing PR comments * changes * removing unnecessary comment
1 parent bace292 commit 602bfbd

10 files changed

+389
-69
lines changed

lucene/CHANGES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ New Features
156156
`Lucene104HnswScalarQuantizedVectorsFormat` replaces the now legacy `Lucene99HnswScalarQuantizedVectorsFormat`
157157
(Trevor McCulloch)
158158

159+
* GITHUB#15271: Extend `Lucene104ScalarQuantizedVectorsFormat` and `Lucene104HnswScalarQuantizedVectorsFormat` to
160+
allow asymmetric quantization. The initially supported configuration is single-bit document vectors with 4-bit queries. This is a replacement
161+
for the now legacy `Lucene102HnswBinaryQuantizedVectorsFormat` and `Lucene102BinaryQuantizedVectorsFormat`.
162+
(Ben Trent)
163+
159164
Improvements
160165
---------------------
161166
* GITHUB#15148: Add support for uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
8686
/**
8787
* Constructs a format using the given graph construction parameters and scalar quantization.
8888
*
89+
* @param encoding the quantization encoding used to encode the vectors
8990
* @param maxConn the maximum number of connections to a node in the HNSW graph
9091
* @param beamWidth the size of the queue maintained during graph construction.
9192
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,15 @@ public RandomVectorScorer getRandomVectorScorer(
6464
if (vectorValues instanceof QuantizedByteVectorValues qv) {
6565
checkDimensions(target.length, qv.dimension());
6666
OptimizedScalarQuantizer quantizer = qv.getQuantizer();
67-
byte[] targetQuantized =
68-
new byte
69-
[OptimizedScalarQuantizer.discretize(
70-
target.length, qv.getScalarEncoding().getDimensionsPerByte())];
67+
Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = qv.getScalarEncoding();
68+
byte[] scratch = new byte[scalarEncoding.getDiscreteDimensions(qv.dimension())];
69+
final byte[] targetQuantized;
70+
if (scalarEncoding.isAsymmetric() == false) {
71+
targetQuantized = scratch;
72+
} else {
73+
// This is asymmetric quantization, we will pack the vector
74+
targetQuantized = new byte[scalarEncoding.getQueryPackedLength(scratch.length)];
75+
}
7176
// We make a copy as the quantization process mutates the input
7277
float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
7378
if (similarityFunction == COSINE) {
@@ -76,7 +81,12 @@ public RandomVectorScorer getRandomVectorScorer(
7681
target = copy;
7782
var targetCorrectiveTerms =
7883
quantizer.scalarQuantize(
79-
target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid());
84+
target, scratch, scalarEncoding.getQueryBits(), qv.getCentroid());
85+
// for single bit query nibble, we need to transpose the nibbles for fast scoring comparisons
86+
if (scalarEncoding
87+
== Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) {
88+
OptimizedScalarQuantizer.transposeHalfByte(scratch, targetQuantized);
89+
}
8090
return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
8191
@Override
8292
public float score(int node) throws IOException {
@@ -96,13 +106,68 @@ public RandomVectorScorer getRandomVectorScorer(
96106
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
97107
}
98108

109+
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
110+
VectorSimilarityFunction similarityFunction,
111+
QuantizedByteVectorValues scoringVectors,
112+
QuantizedByteVectorValues targetVectors) {
113+
return new AsymmetricQuantizedRandomVectorScorerSupplier(
114+
scoringVectors, targetVectors, similarityFunction);
115+
}
116+
99117
@Override
100118
public String toString() {
101119
return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate="
102120
+ nonQuantizedDelegate
103121
+ ")";
104122
}
105123

124+
static class AsymmetricQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
125+
private final QuantizedByteVectorValues queryVectors;
126+
private final QuantizedByteVectorValues targetVectors;
127+
private final VectorSimilarityFunction similarityFunction;
128+
129+
AsymmetricQuantizedRandomVectorScorerSupplier(
130+
QuantizedByteVectorValues queryVectors,
131+
QuantizedByteVectorValues targetVectors,
132+
VectorSimilarityFunction similarityFunction) {
133+
assert targetVectors.getScalarEncoding().isAsymmetric();
134+
this.queryVectors = queryVectors;
135+
this.targetVectors = targetVectors;
136+
this.similarityFunction = similarityFunction;
137+
}
138+
139+
@Override
140+
public UpdateableRandomVectorScorer scorer() throws IOException {
141+
final QuantizedByteVectorValues targetVectors = this.targetVectors.copy();
142+
final QuantizedByteVectorValues queryVectors = this.queryVectors.copy();
143+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetVectors) {
144+
private OptimizedScalarQuantizer.QuantizationResult queryCorrections = null;
145+
private byte[] vector = null;
146+
147+
@Override
148+
public void setScoringOrdinal(int node) throws IOException {
149+
vector = queryVectors.vectorValue(node);
150+
queryCorrections = queryVectors.getCorrectiveTerms(node);
151+
}
152+
153+
@Override
154+
public float score(int node) throws IOException {
155+
if (vector == null || queryCorrections == null) {
156+
throw new IllegalStateException("setScoringOrdinal was not called");
157+
}
158+
159+
return quantizedScore(vector, queryCorrections, targetVectors, node, similarityFunction);
160+
}
161+
};
162+
}
163+
164+
@Override
165+
public RandomVectorScorerSupplier copy() throws IOException {
166+
return new AsymmetricQuantizedRandomVectorScorerSupplier(
167+
queryVectors.copy(), targetVectors.copy(), similarityFunction);
168+
}
169+
}
170+
106171
private static final class ScalarQuantizedVectorScorerSupplier
107172
implements RandomVectorScorerSupplier {
108173
private final QuantizedByteVectorValues targetValues;
@@ -111,6 +176,7 @@ private static final class ScalarQuantizedVectorScorerSupplier
111176

112177
public ScalarQuantizedVectorScorerSupplier(
113178
QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException {
179+
assert values.getScalarEncoding().isAsymmetric() == false;
114180
this.targetValues = values.copy();
115181
this.values = values;
116182
this.similarity = similarity;
@@ -131,14 +197,17 @@ public float score(int node) throws IOException {
131197
public void setScoringOrdinal(int node) throws IOException {
132198
var rawTargetVector = targetValues.vectorValue(node);
133199
switch (values.getScalarEncoding()) {
134-
case UNSIGNED_BYTE -> targetVector = rawTargetVector;
135-
case SEVEN_BIT -> targetVector = rawTargetVector;
200+
case UNSIGNED_BYTE, SEVEN_BIT -> targetVector = rawTargetVector;
136201
case PACKED_NIBBLE -> {
137202
if (targetVector == null) {
138203
targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)];
139204
}
140205
OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
141206
}
207+
case SINGLE_BIT_QUERY_NIBBLE -> {
208+
throw new IllegalStateException(
209+
"SINGLE_BIT_QUERY_NIBBLE encoding is not supported for symmetric quantization");
210+
}
142211
}
143212
targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
144213
}
@@ -177,16 +246,19 @@ private static float quantizedScore(
177246
case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
178247
case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
179248
case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc);
249+
case SINGLE_BIT_QUERY_NIBBLE ->
250+
VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
180251
};
181252
OptimizedScalarQuantizer.QuantizationResult indexCorrections =
182253
targetVectors.getCorrectiveTerms(targetOrd);
254+
float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1];
183255
float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
184256
float x1 = indexCorrections.quantizedComponentSum();
185257
float ax = indexCorrections.lowerInterval();
186258
// Here we must scale according to the bits
187259
float lx = (indexCorrections.upperInterval() - ax) * scale;
188260
float ay = queryCorrections.lowerInterval();
189-
float ly = (queryCorrections.upperInterval() - ay) * scale;
261+
float ly = (queryCorrections.upperInterval() - ay) * queryScale;
190262
float y1 = queryCorrections.quantizedComponentSum();
191263
float score =
192264
ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,25 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
118118
*/
119119
public enum ScalarEncoding {
120120
/** Each dimension is quantized to 8 bits and treated as an unsigned value. */
121-
UNSIGNED_BYTE(0, (byte) 8, 1),
121+
UNSIGNED_BYTE(0, (byte) 8, 8),
122122
/** Each dimension is quantized to 4 bits, and two values are packed into each output byte. */
123-
PACKED_NIBBLE(1, (byte) 4, 2),
123+
PACKED_NIBBLE(1, (byte) 4, 4),
124124
/**
125125
* Each dimension is quantized to 7 bits and treated as a signed value.
126126
*
127127
* <p>This is intended for backwards compatibility with older iterations of scalar quantization.
128128
* This setting will produce an index the same size as {@link #UNSIGNED_BYTE} but will produce
129129
* less accurate vector comparisons.
130130
*/
131-
SEVEN_BIT(2, (byte) 7, 1);
131+
SEVEN_BIT(2, (byte) 7, 8),
132+
/**
133+
* Each dimension is quantized to a single bit and packed into bytes. During query time, the
134+
* query vector is quantized to 4 bits per dimension.
135+
*
136+
* <p>This is the most space efficient encoding, and will produce an index 8x smaller than
137+
* {@link #UNSIGNED_BYTE}. However, this comes at the cost of accuracy.
138+
*/
139+
SINGLE_BIT_QUERY_NIBBLE(3, (byte) 1, 1, (byte) 4, 4);
132140

133141
public static ScalarEncoding fromNumBits(int bits) {
134142
for (ScalarEncoding encoding : values()) {
@@ -142,13 +150,27 @@ public static ScalarEncoding fromNumBits(int bits) {
142150
/** The number used to identify this encoding on the wire, rather than relying on ordinal. */
143151
private final int wireNumber;
144152

145-
private final byte bits;
146-
private final int dimsPerByte;
153+
private final byte bits, queryBits;
154+
private final int bitsPerDim, queryBitsPerDim;
147155

148-
ScalarEncoding(int wireNumber, byte bits, int dimsPerByte) {
156+
ScalarEncoding(int wireNumber, byte bits, int bitsPerDim) {
149157
this.wireNumber = wireNumber;
150158
this.bits = bits;
151-
this.dimsPerByte = dimsPerByte;
159+
this.queryBits = bits;
160+
this.bitsPerDim = bitsPerDim;
161+
this.queryBitsPerDim = bitsPerDim;
162+
}
163+
164+
ScalarEncoding(int wireNumber, byte bits, int bitsPerDim, byte queryBits, int queryBitsPerDim) {
165+
this.wireNumber = wireNumber;
166+
this.bits = bits;
167+
this.queryBits = queryBits;
168+
this.bitsPerDim = bitsPerDim;
169+
this.queryBitsPerDim = queryBitsPerDim;
170+
}
171+
172+
boolean isAsymmetric() {
173+
return bits != queryBits;
152174
}
153175

154176
int getWireNumber() {
@@ -160,14 +182,48 @@ public byte getBits() {
160182
return bits;
161183
}
162184

185+
public byte getQueryBits() {
186+
return queryBits;
187+
}
188+
189+
/** Return the number of dimensions rounded up to fit into whole bytes. */
190+
public int getDiscreteDimensions(int dimensions) {
191+
if (queryBits == bits) {
192+
int totalBits = dimensions * bitsPerDim;
193+
return (totalBits + 7) / 8 * 8 / bitsPerDim;
194+
}
195+
int queryDiscretized = (dimensions * queryBitsPerDim + 7) / 8 * 8 / queryBitsPerDim;
196+
int docDiscretized = (dimensions * bitsPerDim + 7) / 8 * 8 / bitsPerDim;
197+
int maxDiscretized = Math.max(queryDiscretized, docDiscretized);
198+
assert maxDiscretized % (8.0 / queryBitsPerDim) == 0
199+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
200+
assert maxDiscretized % (8.0 / bitsPerDim) == 0
201+
: "bad discretized=" + maxDiscretized + " for dim=" + dimensions;
202+
return maxDiscretized;
203+
}
204+
163205
/** Return the number of bits used to store each dimension of a document vector. */
164-
public int getDimensionsPerByte() {
165-
return this.dimsPerByte;
206+
public int getDocBitsPerDim() {
207+
return this.bitsPerDim;
208+
}
209+
210+
public int getQueryBitsPerDim() {
211+
return this.queryBitsPerDim;
166212
}
167213

168214
/** Return the number of bytes required to store a packed vector of the given dimensions. */
169-
public int getPackedLength(int dimensions) {
170-
return (dimensions + this.dimsPerByte - 1) / this.dimsPerByte;
215+
public int getDocPackedLength(int dimensions) {
216+
int discretized = getDiscreteDimensions(dimensions);
217+
// how many bytes do we need to store the quantized vector?
218+
int totalBits = discretized * bitsPerDim;
219+
return (totalBits + 7) / 8;
220+
}
221+
222+
public int getQueryPackedLength(int dimensions) {
223+
int discretized = getDiscreteDimensions(dimensions);
224+
// how many bytes do we need to store the quantized vector?
225+
int totalBits = discretized * queryBitsPerDim;
226+
return (totalBits + 7) / 8;
171227
}
172228

173229
/** Returns the encoding for the given wire number, or empty if unknown. */
@@ -186,7 +242,7 @@ public Lucene104ScalarQuantizedVectorsFormat() {
186242
this(ScalarEncoding.UNSIGNED_BYTE);
187243
}
188244

189-
/** Creates a new instance with the chosen encoding. */
245+
/** Creates a new instance with the chosen quantization encoding. */
190246
public Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding encoding) {
191247
super(NAME);
192248
this.encoding = encoding;

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
141141

142142
long numQuantizedVectorBytes =
143143
Math.multiplyExact(
144-
(fieldEntry.scalarEncoding.getPackedLength(dimension)
144+
(fieldEntry.scalarEncoding.getDocPackedLength(dimension)
145145
+ (Float.BYTES * 3)
146146
+ Integer.BYTES),
147147
(long) fieldEntry.size);

0 commit comments

Comments
 (0)