Skip to content

Commit 5108d96

Browse files
committed
Use bit scale for both index and queries
1 parent 81b5409 commit 5108d96

File tree

2 files changed

+28 -14 lines changed

2 files changed

+28 -14 lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910BinaryFlatVectorsScorer.java

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,11 @@
3939
public class ES910BinaryFlatVectorsScorer implements FlatVectorsScorer {
4040
private final FlatVectorsScorer nonQuantizedDelegate;
4141
private final byte queryBits;
42+
private final byte indexBits;
4243

43-
public ES910BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate, byte queryBits) {
44+
public ES910BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate, byte indexBits, byte queryBits) {
4445
this.nonQuantizedDelegate = nonQuantizedDelegate;
46+
this.indexBits = indexBits;
4547
this.queryBits = queryBits;
4648
}
4749

@@ -85,15 +87,20 @@ public float score(int i) throws IOException {
8587
queryCorrections,
8688
binarizedVectors.vectorValue(i),
8789
binarizedVectors.getCorrectiveTerms(i),
88-
getBitsScale()
90+
getIndexBitsScale(),
91+
getQueryBitsScale()
8992
);
9093
}
9194
};
9295
}
9396
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
9497
}
9598

96-
private float getBitsScale() {
99+
private float getIndexBitsScale() {
100+
return 1f / ((1 << indexBits) - 1);
101+
}
102+
103+
private float getQueryBitsScale() {
97104
return 1f / ((1 << queryBits) - 1);
98105
}
99106

@@ -111,7 +118,7 @@ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
111118
ES910BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
112119
BinarizedByteVectorValues targetVectors
113120
) {
114-
return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction, queryBits);
121+
return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction, indexBits, queryBits);
115122
}
116123

117124
@Override
@@ -124,28 +131,31 @@ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSu
124131
private final ES910BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
125132
private final BinarizedByteVectorValues targetVectors;
126133
private final VectorSimilarityFunction similarityFunction;
134+
private final byte indexBits;
127135
private final byte queryBits;
128136

129137
BinarizedRandomVectorScorerSupplier(
130138
ES910BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
131139
BinarizedByteVectorValues targetVectors,
132140
VectorSimilarityFunction similarityFunction,
141+
byte indexBits,
133142
byte queryBits
134143
) {
135144
this.queryVectors = queryVectors;
136145
this.targetVectors = targetVectors;
137146
this.similarityFunction = similarityFunction;
147+
this.indexBits = indexBits;
138148
this.queryBits = queryBits;
139149
}
140150

141151
@Override
142152
public BinarizedRandomVectorScorer scorer() throws IOException {
143-
return new BinarizedRandomVectorScorer(queryVectors.copy(), targetVectors.copy(), similarityFunction, queryBits);
153+
return new BinarizedRandomVectorScorer(queryVectors.copy(), targetVectors.copy(), similarityFunction, indexBits, queryBits);
144154
}
145155

146156
@Override
147157
public RandomVectorScorerSupplier copy() throws IOException {
148-
return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction, queryBits);
158+
return new BinarizedRandomVectorScorerSupplier(queryVectors, targetVectors, similarityFunction, indexBits, queryBits);
149159
}
150160
}
151161

@@ -157,20 +167,23 @@ public static class BinarizedRandomVectorScorer extends UpdateableRandomVectorSc
157167
private final byte[] quantizedQuery;
158168
private OptimizedScalarQuantizer.QuantizationResult queryCorrections = null;
159169
private int currentOrdinal = -1;
160-
private final float bitScale;
170+
private final float queryBitsScale;
171+
private final float indexBitsScale;
161172

162173
BinarizedRandomVectorScorer(
163174
ES910BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
164175
BinarizedByteVectorValues targetVectors,
165176
VectorSimilarityFunction similarityFunction,
177+
byte indexBits,
166178
byte queryBits
167179
) {
168180
super(targetVectors);
169181
this.queryVectors = queryVectors;
170182
this.quantizedQuery = new byte[queryVectors.dimension()];
171183
this.targetVectors = targetVectors;
172184
this.similarityFunction = similarityFunction;
173-
bitScale = 1.0F / (float) ((1 << queryBits) - 1);
185+
this.indexBitsScale = 1.0F / (float) ((1 << indexBits) - 1);
186+
this.queryBitsScale = 1.0F / (float) ((1 << queryBits) - 1);
174187
}
175188

176189
@Override
@@ -186,7 +199,8 @@ public float score(int targetOrd) throws IOException {
186199
queryCorrections,
187200
targetVectors.vectorValue(targetOrd),
188201
targetVectors.getCorrectiveTerms(targetOrd),
189-
bitScale
202+
indexBitsScale,
203+
queryBitsScale
190204
);
191205
}
192206

@@ -209,15 +223,15 @@ private static float quantizedScore(
209223
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
210224
byte[] d,
211225
OptimizedScalarQuantizer.QuantizationResult indexCorrections,
212-
float bitsScale
226+
float indexBitsScale,
227+
float queryBitsScale
213228
) {
214229
float qcDist = VectorUtil.dotProduct(q, d);
215230
float x1 = indexCorrections.quantizedComponentSum();
216231
float ax = indexCorrections.lowerInterval();
217-
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
218-
float lx = indexCorrections.upperInterval() - ax;
232+
float lx = (indexCorrections.upperInterval() - ax) * indexBitsScale;
219233
float ay = queryCorrections.lowerInterval();
220-
float ly = (queryCorrections.upperInterval() - ay) * bitsScale;
234+
float ly = (queryCorrections.upperInterval() - ay) * queryBitsScale;
221235
float y1 = queryCorrections.quantizedComponentSum();
222236
float score = ax * ay * dims + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
223237
// For euclidean, we need to invert the score and apply the additional correction, which is

server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910BinaryQuantizedVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ public ES910BinaryQuantizedVectorsFormat(byte indexBits, byte queryBits) {
122122
// don't have the possibility of doing a PerFieldMapperCodec yet on KnnSearcher
123123
DEFAULT_QUERY_BITS = queryBits;
124124
DEFAULT_INDEX_BITS = indexBits;
125-
this.scorer = new ES910BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), queryBits);
125+
this.scorer = new ES910BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), indexBits, queryBits);
126126
}
127127

128128
@Override

0 commit comments

Comments
 (0)