Skip to content

Commit 89c58cf

Browse files
authored
Create a version of bbq_hnsw that supports on_disk_rescore (#135931)
1 parent a3c2dd6 commit 89c58cf

19 files changed

+702
-285
lines changed

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@
464464
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
465465
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
466466
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
467+
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
467468
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
468469

469470
provides org.apache.lucene.codecs.Codec

server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,10 @@ protected DirectIOCapableFlatVectorsFormat(String name) {
1919
super(name);
2020
}
2121

22+
@Override
23+
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
24+
return fieldsReader(state, false);
25+
}
26+
2227
public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
2328
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
13+
14+
import java.io.IOException;
15+
import java.util.Collection;
16+
import java.util.Collections;
17+
import java.util.HashMap;
18+
import java.util.IdentityHashMap;
19+
import java.util.Map;
20+
21+
/**
22+
* Keeps track of field-specific raw vector readers for vector reads
23+
*/
24+
public class GenericFlatVectorReaders {
25+
26+
public interface Field {
27+
String rawVectorFormatName();
28+
29+
boolean useDirectIOReads();
30+
}
31+
32+
@FunctionalInterface
33+
public interface LoadFlatVectorsReader {
34+
FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException;
35+
}
36+
37+
private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) {
38+
private FlatVectorsReaderKey(Field field) {
39+
this(field.rawVectorFormatName(), field.useDirectIOReads());
40+
}
41+
42+
@Override
43+
public String toString() {
44+
return formatName + (useDirectIO ? " with Direct IO" : "");
45+
}
46+
}
47+
48+
private final Map<FlatVectorsReaderKey, FlatVectorsReader> readers = new HashMap<>();
49+
private final Map<Integer, FlatVectorsReader> readersForFields = new HashMap<>();
50+
51+
public void loadField(int fieldNumber, Field field, LoadFlatVectorsReader loadReader) throws IOException {
52+
FlatVectorsReaderKey key = new FlatVectorsReaderKey(field);
53+
FlatVectorsReader reader = readers.get(key);
54+
if (reader == null) {
55+
reader = loadReader.getReader(field.rawVectorFormatName(), field.useDirectIOReads());
56+
if (reader == null) {
57+
throw new IllegalStateException("Cannot find flat vector format: " + field.rawVectorFormatName());
58+
}
59+
readers.put(key, reader);
60+
}
61+
readersForFields.put(fieldNumber, reader);
62+
}
63+
64+
public FlatVectorsReader getReaderForField(int fieldNumber) {
65+
FlatVectorsReader reader = readersForFields.get(fieldNumber);
66+
if (reader == null) {
67+
throw new IllegalArgumentException("Invalid field number [" + fieldNumber + "]");
68+
}
69+
return reader;
70+
}
71+
72+
public Collection<FlatVectorsReader> allReaders() {
73+
return Collections.unmodifiableCollection(readers.values());
74+
}
75+
76+
public GenericFlatVectorReaders getMergeInstance() throws IOException {
77+
GenericFlatVectorReaders mergeReaders = new GenericFlatVectorReaders();
78+
79+
// link the original instance with the merge instance
80+
Map<FlatVectorsReader, FlatVectorsReader> mergeInstances = new IdentityHashMap<>();
81+
for (var reader : readers.entrySet()) {
82+
FlatVectorsReader mergeInstance = reader.getValue().getMergeInstance();
83+
mergeInstances.put(reader.getValue(), mergeInstance);
84+
mergeReaders.readers.put(reader.getKey(), mergeInstance);
85+
}
86+
// link up the fields to the merge readers
87+
for (var field : readersForFields.entrySet()) {
88+
mergeReaders.readersForFields.put(field.getKey(), mergeInstances.get(field.getValue()));
89+
}
90+
return mergeReaders;
91+
}
92+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.store.IndexInput;
1717
import org.apache.lucene.util.Bits;
1818
import org.apache.lucene.util.VectorUtil;
19+
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
1920
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2021
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
2122
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
@@ -38,7 +39,7 @@
3839
*/
3940
public class ES920DiskBBQVectorsReader extends IVFVectorsReader {
4041

41-
ES920DiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
42+
ES920DiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader) throws IOException {
4243
super(state, getFormatReader);
4344
}
4445

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030
import org.apache.lucene.store.IndexInput;
3131
import org.apache.lucene.util.Bits;
3232
import org.elasticsearch.core.IOUtils;
33+
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
3334
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3435

3536
import java.io.Closeable;
3637
import java.io.IOException;
3738
import java.util.ArrayList;
3839
import java.util.Collections;
39-
import java.util.HashMap;
4040
import java.util.List;
4141
import java.util.Map;
4242

@@ -49,33 +49,18 @@
4949
*/
5050
public abstract class IVFVectorsReader extends KnnVectorsReader {
5151

52-
private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) {
53-
private FlatVectorsReaderKey(FieldEntry entry) {
54-
this(entry.rawVectorFormatName, entry.useDirectIOReads);
55-
}
56-
57-
@Override
58-
public String toString() {
59-
return formatName + (useDirectIO ? " with Direct IO" : "");
60-
}
61-
}
62-
6352
private final IndexInput ivfCentroids, ivfClusters;
6453
private final SegmentReadState state;
6554
private final FieldInfos fieldInfos;
6655
protected final IntObjectHashMap<FieldEntry> fields;
67-
private final Map<FlatVectorsReaderKey, FlatVectorsReader> rawVectorReaders;
68-
69-
@FunctionalInterface
70-
public interface GetFormatReader {
71-
FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException;
72-
}
56+
private final GenericFlatVectorReaders genericReaders;
7357

7458
@SuppressWarnings("this-escape")
75-
protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
59+
protected IVFVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader loadReader) throws IOException {
7660
this.state = state;
7761
this.fieldInfos = state.fieldInfos;
7862
this.fields = new IntObjectHashMap<>();
63+
this.genericReaders = new GenericFlatVectorReaders();
7964
String meta = IndexFileNames.segmentFileName(
8065
state.segmentInfo.name,
8166
state.segmentSuffix,
@@ -86,7 +71,6 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
8671
boolean success = false;
8772
try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
8873
Throwable priorE = null;
89-
Map<FlatVectorsReaderKey, FlatVectorsReader> readers = null;
9074
try {
9175
versionMeta = CodecUtil.checkIndexHeader(
9276
ivfMeta,
@@ -96,13 +80,12 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
9680
state.segmentInfo.getId(),
9781
state.segmentSuffix
9882
);
99-
readers = readFields(ivfMeta, getFormatReader, versionMeta);
83+
readFields(ivfMeta, versionMeta, genericReaders, loadReader);
10084
} catch (Throwable exception) {
10185
priorE = exception;
10286
} finally {
10387
CodecUtil.checkFooter(ivfMeta, priorE);
10488
}
105-
this.rawVectorReaders = readers;
10689
ivfCentroids = openDataInput(
10790
state,
10891
versionMeta,
@@ -169,30 +152,23 @@ private static IndexInput openDataInput(
169152
}
170153
}
171154

172-
private Map<FlatVectorsReaderKey, FlatVectorsReader> readFields(ChecksumIndexInput meta, GetFormatReader loadReader, int versionMeta)
173-
throws IOException {
174-
Map<FlatVectorsReaderKey, FlatVectorsReader> readers = new HashMap<>();
155+
private void readFields(
156+
ChecksumIndexInput meta,
157+
int versionMeta,
158+
GenericFlatVectorReaders genericFields,
159+
GenericFlatVectorReaders.LoadFlatVectorsReader loadReader
160+
) throws IOException {
175161
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
176162
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
177163
if (info == null) {
178164
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
179165
}
180166

181167
FieldEntry fieldEntry = readField(meta, info, versionMeta);
182-
FlatVectorsReaderKey key = new FlatVectorsReaderKey(fieldEntry);
183-
184-
FlatVectorsReader reader = readers.get(key);
185-
if (reader == null) {
186-
reader = loadReader.getReader(fieldEntry.rawVectorFormatName, fieldEntry.useDirectIOReads);
187-
if (reader == null) {
188-
throw new IllegalStateException("Cannot find flat vector format: " + fieldEntry.rawVectorFormatName);
189-
}
190-
readers.put(key, reader);
191-
}
168+
genericFields.loadField(fieldNumber, fieldEntry, loadReader);
192169

193170
fields.put(info.number, fieldEntry);
194171
}
195-
return readers;
196172
}
197173

198174
private FieldEntry readField(IndexInput input, FieldInfo info, int versionMeta) throws IOException {
@@ -256,29 +232,17 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep
256232

257233
@Override
258234
public final void checkIntegrity() throws IOException {
259-
for (var reader : rawVectorReaders.values()) {
235+
for (var reader : genericReaders.allReaders()) {
260236
reader.checkIntegrity();
261237
}
262238
CodecUtil.checksumEntireFile(ivfCentroids);
263239
CodecUtil.checksumEntireFile(ivfClusters);
264240
}
265241

266-
private FieldEntry getFieldEntryOrThrow(String field) {
267-
final FieldInfo info = fieldInfos.fieldInfo(field);
268-
final FieldEntry entry;
269-
if (info == null || (entry = fields.get(info.number)) == null) {
270-
throw new IllegalArgumentException("field=\"" + field + "\" not found");
271-
}
272-
return entry;
273-
}
274-
275242
private FlatVectorsReader getReaderForField(String field) {
276-
var readerKey = new FlatVectorsReaderKey(getFieldEntryOrThrow(field));
277-
FlatVectorsReader reader = rawVectorReaders.get(readerKey);
278-
if (reader == null) throw new IllegalArgumentException(
279-
"Could not find raw vector format [" + readerKey + "] for field [" + field + "]"
280-
);
281-
return reader;
243+
FieldInfo info = fieldInfos.fieldInfo(field);
244+
if (info == null) throw new IllegalArgumentException("Could not find field [" + field + "]");
245+
return genericReaders.getReaderForField(info.number);
282246
}
283247

284248
@Override
@@ -399,7 +363,7 @@ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
399363

400364
@Override
401365
public void close() throws IOException {
402-
List<Closeable> closeables = new ArrayList<>(rawVectorReaders.values());
366+
List<Closeable> closeables = new ArrayList<>(genericReaders.allReaders());
403367
Collections.addAll(closeables, ivfCentroids, ivfClusters);
404368
IOUtils.close(closeables);
405369
}
@@ -416,7 +380,7 @@ protected record FieldEntry(
416380
long postingListLength,
417381
float[] globalCentroid,
418382
float globalCentroidDp
419-
) {
383+
) implements GenericFlatVectorReaders.Field {
420384
IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
421385
return centroidFile.slice("centroids", centroidOffset, centroidLength);
422386
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.store.IndexInput;
1717
import org.apache.lucene.util.Bits;
1818
import org.apache.lucene.util.VectorUtil;
19+
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
1920
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2021
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
2122
import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter;
@@ -40,7 +41,8 @@
4041
*/
4142
public class ESNextDiskBBQVectorsReader extends IVFVectorsReader {
4243

43-
public ESNextDiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
44+
public ESNextDiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader)
45+
throws IOException {
4446
super(state, getFormatReader);
4547
}
4648

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter;
3333

3434
import java.io.IOException;
35+
import java.util.Map;
3536

3637
/**
3738
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
@@ -86,19 +87,33 @@
8687
* <li>The sparse vector information, if required, mapping vector ordinal to doc ID
8788
* </ul>
8889
*/
89-
public class ES93BinaryQuantizedVectorsFormat extends DirectIOCapableFlatVectorsFormat {
90+
public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat {
9091

9192
public static final String NAME = "ES93BinaryQuantizedVectorsFormat";
9293

93-
private final DirectIOCapableLucene99FlatVectorsFormat rawVectorFormat;
94+
private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
95+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
96+
);
97+
98+
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
99+
rawVectorFormat.getName(),
100+
rawVectorFormat
101+
);
94102

95103
private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
96104
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
97105
);
98106

107+
private final boolean useDirectIO;
108+
99109
public ES93BinaryQuantizedVectorsFormat() {
100110
super(NAME);
101-
rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
111+
this.useDirectIO = false;
112+
}
113+
114+
public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO) {
115+
super(NAME);
116+
this.useDirectIO = useDirectIO;
102117
}
103118

104119
@Override
@@ -107,17 +122,27 @@ protected FlatVectorsScorer flatVectorsScorer() {
107122
}
108123

109124
@Override
110-
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
111-
return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state);
125+
protected boolean useDirectIOReads() {
126+
return useDirectIO;
112127
}
113128

114129
@Override
115-
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
116-
return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
130+
protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() {
131+
return rawVectorFormat;
117132
}
118133

119134
@Override
120-
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
121-
return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state, useDirectIO), scorer);
135+
protected Map<String, DirectIOCapableFlatVectorsFormat> supportedReadFlatVectorsFormats() {
136+
return supportedFormats;
137+
}
138+
139+
@Override
140+
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
141+
return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state);
142+
}
143+
144+
@Override
145+
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
146+
return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer);
122147
}
123148
}

0 commit comments

Comments
 (0)