Skip to content

Commit b3dab7e

Browse files
committed
Fix setting the query bits / index bits statically and printing them out
1 parent 98469e9 commit b3dab7e

File tree

3 files changed

+51
-25
lines changed

3 files changed

+51
-25
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,13 @@ public static void main(String[] args) throws Exception {
187187
: new int[] { 0 };
188188
Results[] results = new Results[nProbes.length];
189189
for (int i = 0; i < nProbes.length; i++) {
190-
results[i] = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
190+
results[i] = new Results(
191+
cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT),
192+
cmdLineArgs.numDocs(),
193+
cmdLineArgs.quantizeBits(),
194+
cmdLineArgs.quantizeQueryBits(),
195+
cmdLineArgs.overSamplingFactor()
196+
);
191197
}
192198
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
193199
Codec codec = createCodec(cmdLineArgs);
@@ -247,7 +253,11 @@ public String toString() {
247253
"avg_cpu_count",
248254
"QPS",
249255
"recall",
250-
"visited" };
256+
"visited",
257+
"indexBits",
258+
"queryBits",
259+
"oversampling"
260+
};
251261

252262
// Calculate appropriate column widths based on headers and data
253263

@@ -274,8 +284,12 @@ public String toString() {
274284
String.format(Locale.ROOT, "%.2f", result.netCpuTimeMS),
275285
String.format(Locale.ROOT, "%.2f", result.avgCpuCount),
276286
String.format(Locale.ROOT, "%.2f", result.qps),
277-
String.format(Locale.ROOT, "%.2f", result.avgRecall),
278-
String.format(Locale.ROOT, "%.2f", result.averageVisited) };
287+
String.format(Locale.ROOT, "%.4f", result.avgRecall),
288+
String.format(Locale.ROOT, "%.2f", result.averageVisited),
289+
String.format(Locale.ROOT, "%d", result.indexBits),
290+
String.format(Locale.ROOT, "%d", result.queryBits),
291+
String.format(Locale.ROOT, "%.2f", result.oversampling)
292+
};
279293

280294
}
281295

@@ -341,6 +355,8 @@ private int[] calculateColumnWidths(String[] headers, String[]... data) {
341355
static class Results {
342356
final String indexType;
343357
final int numDocs;
358+
int indexBits;
359+
int queryBits;
344360
long indexTimeMS;
345361
long forceMergeTimeMS;
346362
int numSegments;
@@ -351,10 +367,14 @@ static class Results {
351367
double averageVisited;
352368
double netCpuTimeMS;
353369
double avgCpuCount;
370+
float oversampling;
354371

355-
Results(String indexType, int numDocs) {
372+
Results(String indexType, int numDocs, int indexBits, int queryBits, float oversampling) {
356373
this.indexType = indexType;
357374
this.numDocs = numDocs;
375+
this.indexBits = indexBits;
376+
this.queryBits = queryBits;
377+
this.oversampling = oversampling;
358378
}
359379
}
360380

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.apache.lucene.store.MMapDirectory;
4646
import org.elasticsearch.common.io.Channels;
4747
import org.elasticsearch.core.PathUtils;
48+
import org.elasticsearch.index.codec.vectors.es910.ES910BinaryQuantizedVectorsFormat;
4849
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
4950
import org.elasticsearch.search.profile.query.QueryProfiler;
5051
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
@@ -115,6 +116,13 @@ class KnnSearcher {
115116
this.nProbe = nProbe;
116117
this.indexType = cmdLineArgs.indexType();
117118
this.searchThreads = cmdLineArgs.searchThreads();
119+
if (cmdLineArgs.useNewFlatVectorsFormat()) {
120+
// set both the index and query quantization bits statically on the format
121+
ES910BinaryQuantizedVectorsFormat.setQuantizationBits(
122+
(byte) cmdLineArgs.quantizeBits(),
123+
(byte) cmdLineArgs.quantizeQueryBits()
124+
);
125+
}
118126
}
119127

120128
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {

server/src/main/java/org/elasticsearch/index/codec/vectors/es910/ES910BinaryQuantizedVectorsFormat.java

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.apache.lucene.index.SegmentReadState;
2828
import org.apache.lucene.index.SegmentWriteState;
2929
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
30-
import org.elasticsearch.index.codec.vectors.es818.DirectIOLucene99FlatVectorsFormat;
3130

3231
import java.io.IOException;
3332

@@ -102,38 +101,37 @@ public class ES910BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
102101
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
103102
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
104103
);
105-
private static byte DEFAULT_INDEX_BITS = (byte) 1;
106-
private static byte DEFAULT_QUERY_BITS = (byte) 4;
107104

108-
private final ES910BinaryFlatVectorsScorer scorer;
109-
110-
private final byte indexBits;
111-
private final byte queryBits;
105+
// index and query bits for quantization. They are static as we have no mapping, and thus can't create a format with specific bits.
106+
// we use static setters to change the quantization bits and scorers.
107+
private static byte INDEX_BITS = (byte) 1;
108+
private static byte QUERY_BITS = (byte) 4;
109+
private static ES910BinaryFlatVectorsScorer SCORER;
112110

113111
public ES910BinaryQuantizedVectorsFormat() {
114-
this(DEFAULT_INDEX_BITS, DEFAULT_QUERY_BITS); // Default to 4 bits for index and 2 bits for query
112+
this(INDEX_BITS, QUERY_BITS); // Use the current static bits (initially 1 bit for index and 4 bits for query)
115113
}
116114

117115
/** Creates a new instance and applies the given index and query quantization bits through the static {@link #setQuantizationBits} setter. */
118116
public ES910BinaryQuantizedVectorsFormat(byte indexBits, byte queryBits) {
119117
super(NAME);
120-
this.indexBits = indexBits;
121-
this.queryBits = queryBits;
122-
// Set the default bits for index and query vectors. I know, I know, this is a hack, but we
123-
// don't have the possibility of doing a PerFieldMapperCodec yet on KnnSearcher
124-
DEFAULT_QUERY_BITS = queryBits;
125-
DEFAULT_INDEX_BITS = indexBits;
126-
this.scorer = new ES910BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), indexBits, queryBits);
118+
setQuantizationBits(indexBits, queryBits);
119+
}
120+
121+
public static void setQuantizationBits(byte indexBits, byte queryBits) {
122+
INDEX_BITS = indexBits;
123+
QUERY_BITS = queryBits;
124+
SCORER = new ES910BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), indexBits, queryBits);
127125
}
128126

129127
@Override
130128
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
131-
return new ES910BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state, indexBits, queryBits);
129+
return new ES910BinaryQuantizedVectorsWriter(SCORER, rawVectorFormat.fieldsWriter(state), state, INDEX_BITS, QUERY_BITS);
132130
}
133131

134132
@Override
135133
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
136-
return new ES910BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
134+
return new ES910BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), SCORER);
137135
}
138136

139137
@Override
@@ -146,11 +144,11 @@ public String toString() {
146144
return "ES910BinaryQuantizedVectorsFormat(name="
147145
+ NAME
148146
+ ", flatVectorScorer="
149-
+ scorer
147+
+ SCORER
150148
+ ", indexBits="
151-
+ indexBits
149+
+ INDEX_BITS
152150
+ ", queryBits="
153-
+ queryBits
151+
+ QUERY_BITS
154152
+ ")";
155153
}
156154
}

0 commit comments

Comments
 (0)