Skip to content

Commit 00e4d12

Browse files
authored
Store the raw format name for DiskBBQ in field metadata (#134812)
Store the flat vector format in fieldentry info for loading at read time. At the moment, there's just one supported flat format, but more can be added relatively easily
1 parent 0c73f24 commit 00e4d12

File tree

5 files changed

+77
-15
lines changed

5 files changed

+77
-15
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
import org.apache.lucene.codecs.KnnVectorsWriter;
1515
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1616
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
17+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1718
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
1819
import org.apache.lucene.index.SegmentReadState;
1920
import org.apache.lucene.index.SegmentWriteState;
21+
import org.elasticsearch.common.util.Maps;
2022
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2123

2224
import java.io.IOException;
25+
import java.util.Collections;
26+
import java.util.Map;
2327

2428
/**
2529
* Codec format for Inverted File Vector indexes. This index expects to break the dimensional space
@@ -59,6 +63,7 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
5963
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
6064
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
6165
);
66+
private static final Map<String, FlatVectorsFormat> supportedFormats = Map.of(rawVectorFormat.getName(), rawVectorFormat);
6267

6368
// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
6469
// useful when searching with 'efSearch' type parameters instead of requiring a specific ratio.
@@ -106,12 +111,23 @@ public ES920DiskBBQVectorsFormat() {
106111

107112
@Override
108113
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
109-
return new ES920DiskBBQVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster, centroidsPerParentCluster);
114+
return new ES920DiskBBQVectorsWriter(
115+
rawVectorFormat.getName(),
116+
state,
117+
rawVectorFormat.fieldsWriter(state),
118+
vectorPerCluster,
119+
centroidsPerParentCluster
120+
);
110121
}
111122

112123
@Override
113124
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
114-
return new ES920DiskBBQVectorsReader(state, rawVectorFormat.fieldsReader(state));
125+
Map<String, FlatVectorsReader> readers = Maps.newHashMapWithExpectedSize(supportedFormats.size());
126+
for (var fe : supportedFormats.entrySet()) {
127+
readers.put(fe.getKey(), fe.getValue().fieldsReader(state));
128+
}
129+
130+
return new ES920DiskBBQVectorsReader(state, Collections.unmodifiableMap(readers));
115131
}
116132

117133
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
*/
4141
public class ES920DiskBBQVectorsReader extends IVFVectorsReader implements OffHeapStats {
4242

43-
public ES920DiskBBQVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
43+
public ES920DiskBBQVectorsReader(SegmentReadState state, Map<String, FlatVectorsReader> rawVectorsReader) throws IOException {
4444
super(state, rawVectorsReader);
4545
}
4646

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ public class ES920DiskBBQVectorsWriter extends IVFVectorsWriter {
5151
private final int centroidsPerParentCluster;
5252

5353
public ES920DiskBBQVectorsWriter(
54+
String rawVectorFormatName,
5455
SegmentWriteState state,
5556
FlatVectorsWriter rawVectorDelegate,
5657
int vectorPerCluster,
5758
int centroidsPerParentCluster
5859
) throws IOException {
59-
super(state, rawVectorDelegate);
60+
super(state, rawVectorFormatName, rawVectorDelegate);
6061
this.vectorPerCluster = vectorPerCluster;
6162
this.centroidsPerParentCluster = centroidsPerParentCluster;
6263
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
import org.elasticsearch.core.IOUtils;
3333
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3434

35+
import java.io.Closeable;
3536
import java.io.IOException;
37+
import java.util.ArrayList;
38+
import java.util.Collections;
39+
import java.util.List;
40+
import java.util.Map;
3641

3742
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
3843
import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DYNAMIC_VISIT_RATIO;
@@ -46,14 +51,14 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
4651
private final SegmentReadState state;
4752
private final FieldInfos fieldInfos;
4853
protected final IntObjectHashMap<FieldEntry> fields;
49-
private final FlatVectorsReader rawVectorsReader;
54+
private final Map<String, FlatVectorsReader> rawVectorReaders;
5055

5156
@SuppressWarnings("this-escape")
52-
protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
57+
protected IVFVectorsReader(SegmentReadState state, Map<String, FlatVectorsReader> rawVectorReaders) throws IOException {
5358
this.state = state;
5459
this.fieldInfos = state.fieldInfos;
55-
this.rawVectorsReader = rawVectorsReader;
5660
this.fields = new IntObjectHashMap<>();
61+
this.rawVectorReaders = rawVectorReaders;
5762
String meta = IndexFileNames.segmentFileName(
5863
state.segmentInfo.name,
5964
state.segmentSuffix,
@@ -156,6 +161,7 @@ private void readFields(ChecksumIndexInput meta) throws IOException {
156161
}
157162

158163
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
164+
final String rawVectorFormat = input.readString();
159165
final VectorEncoding vectorEncoding = readVectorEncoding(input);
160166
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
161167
if (similarityFunction != info.getVectorSimilarityFunction()) {
@@ -182,6 +188,7 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio
182188
globalCentroidDp = Float.intBitsToFloat(input.readInt());
183189
}
184190
return new FieldEntry(
191+
rawVectorFormat,
185192
similarityFunction,
186193
vectorEncoding,
187194
numCentroids,
@@ -212,26 +219,46 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep
212219

213220
@Override
214221
public final void checkIntegrity() throws IOException {
215-
rawVectorsReader.checkIntegrity();
222+
for (var reader : rawVectorReaders.values()) {
223+
reader.checkIntegrity();
224+
}
216225
CodecUtil.checksumEntireFile(ivfCentroids);
217226
CodecUtil.checksumEntireFile(ivfClusters);
218227
}
219228

229+
private FieldEntry getFieldEntryOrThrow(String field) {
230+
final FieldInfo info = fieldInfos.fieldInfo(field);
231+
final FieldEntry entry;
232+
if (info == null || (entry = fields.get(info.number)) == null) {
233+
throw new IllegalArgumentException("field=\"" + field + "\" not found");
234+
}
235+
return entry;
236+
}
237+
238+
private FlatVectorsReader getReaderForField(String field) {
239+
var formatName = getFieldEntryOrThrow(field).rawVectorFormatName;
240+
FlatVectorsReader reader = rawVectorReaders.get(formatName);
241+
if (reader == null) throw new IllegalArgumentException(
242+
"Could not find raw vector format [" + formatName + "] for field [" + field + "]"
243+
);
244+
return reader;
245+
}
246+
220247
@Override
221248
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
222-
return rawVectorsReader.getFloatVectorValues(field);
249+
return getReaderForField(field).getFloatVectorValues(field);
223250
}
224251

225252
@Override
226253
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
227-
return rawVectorsReader.getByteVectorValues(field);
254+
return getReaderForField(field).getByteVectorValues(field);
228255
}
229256

230257
@Override
231258
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
232259
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
233260
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
234-
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
261+
getReaderForField(field).search(field, target, knnCollector, acceptDocs);
235262
return;
236263
}
237264
if (fieldInfo.getVectorDimension() != target.length) {
@@ -243,7 +270,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
243270
if (acceptDocs instanceof BitSet bitSet) {
244271
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
245272
}
246-
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
273+
int numVectors = getReaderForField(field).getFloatVectorValues(field).size();
247274
float visitRatio = DYNAMIC_VISIT_RATIO;
248275
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
249276
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -309,7 +336,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
309336
@Override
310337
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
311338
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
312-
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
339+
final ByteVectorValues values = getReaderForField(field).getByteVectorValues(field);
313340
for (int i = 0; i < values.size(); i++) {
314341
final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i));
315342
knnCollector.collect(values.ordToDoc(i), score);
@@ -321,10 +348,13 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
321348

322349
@Override
323350
public void close() throws IOException {
324-
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
351+
List<Closeable> closeables = new ArrayList<>(rawVectorReaders.values());
352+
Collections.addAll(closeables, ivfCentroids, ivfClusters);
353+
IOUtils.close(closeables);
325354
}
326355

327356
protected record FieldEntry(
357+
String rawVectorFormatName,
328358
VectorSimilarityFunction similarityFunction,
329359
VectorEncoding vectorEncoding,
330360
int numCentroids,

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,13 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
5151
private final List<FieldWriter> fieldWriters = new ArrayList<>();
5252
private final IndexOutput ivfCentroids, ivfClusters;
5353
private final IndexOutput ivfMeta;
54+
private final String rawVectorFormatName;
5455
private final FlatVectorsWriter rawVectorDelegate;
5556

5657
@SuppressWarnings("this-escape")
57-
protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
58+
protected IVFVectorsWriter(SegmentWriteState state, String rawVectorFormatName, FlatVectorsWriter rawVectorDelegate)
59+
throws IOException {
60+
this.rawVectorFormatName = rawVectorFormatName;
5861
this.rawVectorDelegate = rawVectorDelegate;
5962
final String metaFileName = IndexFileNames.segmentFileName(
6063
state.segmentInfo.name,
@@ -116,6 +119,9 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
116119
@SuppressWarnings("unchecked")
117120
final FlatFieldVectorsWriter<float[]> floatWriter = (FlatFieldVectorsWriter<float[]>) rawVectorDelegate;
118121
fieldWriters.add(new FieldWriter(fieldInfo, floatWriter));
122+
} else {
123+
// we simply write information that the field is present but we don't do anything with it.
124+
fieldWriters.add(new FieldWriter(fieldInfo, null));
119125
}
120126
return rawVectorDelegate;
121127
}
@@ -165,6 +171,11 @@ abstract CentroidSupplier createCentroidSupplier(
165171
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
166172
rawVectorDelegate.flush(maxDoc, sortMap);
167173
for (FieldWriter fieldWriter : fieldWriters) {
174+
if (fieldWriter.delegate == null) {
175+
// field is not float, we just write meta information
176+
writeMeta(fieldWriter.fieldInfo, 0, 0, 0, 0, 0, null);
177+
continue;
178+
}
168179
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
169180
// build a float vector values with random access
170181
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
@@ -248,6 +259,9 @@ public int ordToDoc(int ord) {
248259
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
249260
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
250261
mergeOneFieldIVF(fieldInfo, mergeState);
262+
} else {
263+
// we simply write information that the field is present but we don't do anything with it.
264+
writeMeta(fieldInfo, 0, 0, 0, 0, 0, null);
251265
}
252266
// we merge the vectors at the end so we only have two copies of the vectors on disk at the same time.
253267
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
@@ -476,6 +490,7 @@ private void writeMeta(
476490
float[] globalCentroid
477491
) throws IOException {
478492
ivfMeta.writeInt(field.number);
493+
ivfMeta.writeString(rawVectorFormatName);
479494
ivfMeta.writeInt(field.getVectorEncoding().ordinal());
480495
ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
481496
ivfMeta.writeInt(numCentroids);

0 commit comments

Comments
 (0)