Skip to content

Commit 3b444c0

Browse files
committed
Store the raw format name in field metadata
1 parent bd10e90 commit 3b444c0

File tree

5 files changed

+72
-15
lines changed

5 files changed

+72
-15
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
5353
public static final String CLUSTER_EXTENSION = "clivf";
5454
static final String IVF_META_EXTENSION = "mivf";
5555

56+
static final String RAW_VECTOR_FORMAT = "raw_vector_format";
57+
5658
public static final int VERSION_START = 0;
5759
public static final int VERSION_CURRENT = VERSION_START;
5860

@@ -106,12 +108,18 @@ public ES920DiskBBQVectorsFormat() {
106108

107109
@Override
108110
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
109-
return new ES920DiskBBQVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster, centroidsPerParentCluster);
111+
return new ES920DiskBBQVectorsWriter(
112+
rawVectorFormat.getName(),
113+
state,
114+
rawVectorFormat.fieldsWriter(state),
115+
vectorPerCluster,
116+
centroidsPerParentCluster
117+
);
110118
}
111119

112120
@Override
113121
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
114-
return new ES920DiskBBQVectorsReader(state, rawVectorFormat.fieldsReader(state));
122+
return new ES920DiskBBQVectorsReader(state);
115123
}
116124

117125
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

12+
import org.apache.lucene.codecs.KnnVectorsFormat;
1213
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1314
import org.apache.lucene.index.FieldInfo;
1415
import org.apache.lucene.index.SegmentReadState;
@@ -25,13 +26,15 @@
2526
import org.elasticsearch.simdvec.ESVectorUtil;
2627

2728
import java.io.IOException;
29+
import java.util.HashMap;
2830
import java.util.Map;
2931

3032
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
3133
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
3234
import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte;
3335
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
3436
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA;
37+
import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.RAW_VECTOR_FORMAT;
3538
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;
3639

3740
/**
@@ -40,8 +43,27 @@
4043
*/
4144
public class ES920DiskBBQVectorsReader extends IVFVectorsReader implements OffHeapStats {
4245

43-
public ES920DiskBBQVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
44-
super(state, rawVectorsReader);
46+
/**
 * Opens a reader for the given segment.
 *
 * <p>The raw vector reader of each vector field is resolved from the format name stored in that
 * field's {@code RAW_VECTOR_FORMAT} attribute.
 *
 * @param state the segment being read
 * @throws IOException if a raw vector reader or the parent reader cannot be opened
 */
public ES920DiskBBQVectorsReader(SegmentReadState state) throws IOException {
    super(state, loadReaders(state));
}

/**
 * Builds an immutable map from field name to the raw vector reader for that field.
 *
 * <p>One reader is opened per vector field, using the format named by the field's
 * {@code RAW_VECTOR_FORMAT} attribute. If any field fails, every reader opened so far is closed
 * before the failure is rethrown, so a partially built map never leaks open readers.
 *
 * @param state the segment being read
 * @return an unmodifiable map keyed by field name
 * @throws IllegalArgumentException if a vector field has no {@code RAW_VECTOR_FORMAT} attribute,
 *         or the named format does not produce a {@link FlatVectorsReader}
 * @throws IOException if a raw vector reader cannot be opened
 */
private static Map<String, FlatVectorsReader> loadReaders(SegmentReadState state) throws IOException {
    Map<String, FlatVectorsReader> readers = new HashMap<>();
    try {
        for (FieldInfo fi : state.fieldInfos) {
            if (fi.hasVectorValues() == false) {
                continue;
            }
            String formatName = fi.getAttribute(RAW_VECTOR_FORMAT);
            if (formatName == null) {
                // Name the field so a broken segment can be diagnosed from the message alone.
                throw new IllegalArgumentException("Field [" + fi.name + "] does not have " + RAW_VECTOR_FORMAT);
            }
            var opened = KnnVectorsFormat.forName(formatName).fieldsReader(state);
            if (opened instanceof FlatVectorsReader flatReader) {
                readers.put(fi.name, flatReader);
            } else {
                // A blind cast would throw ClassCastException and leave this reader open.
                opened.close();
                throw new IllegalArgumentException(
                    "Format [" + formatName + "] recorded for field [" + fi.name + "] does not provide a flat vectors reader"
                );
            }
        }
    } catch (IOException | RuntimeException e) {
        // Release everything opened for earlier fields; keep the original failure as the primary one.
        for (FlatVectorsReader alreadyOpened : readers.values()) {
            try {
                alreadyOpened.close();
            } catch (IOException | RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
        }
        throw e;
    }

    return Map.copyOf(readers);
}
4668

4769
CentroidIterator getPostingListPrefetchIterator(CentroidIterator centroidIterator, IndexInput postingListSlice) throws IOException {

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

12+
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
1213
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1314
import org.apache.lucene.index.FieldInfo;
1415
import org.apache.lucene.index.FloatVectorValues;
@@ -39,6 +40,8 @@
3940
import java.util.AbstractList;
4041
import java.util.Arrays;
4142

43+
import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.RAW_VECTOR_FORMAT;
44+
4245
/**
4346
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
4447
* partition the vector space, and then stores the centroids and posting list in a sequential
@@ -47,20 +50,29 @@
4750
public class ES920DiskBBQVectorsWriter extends IVFVectorsWriter {
4851
private static final Logger logger = LogManager.getLogger(ES920DiskBBQVectorsWriter.class);
4952

53+
private final String rawVectorFormatName;
5054
private final int vectorPerCluster;
5155
private final int centroidsPerParentCluster;
5256

5357
/**
 * Creates a writer for one segment.
 *
 * @param rawVectorFormatName name of the raw vector format; written into each field's attributes by {@code addField}
 * @param state the segment being written
 * @param rawVectorDelegate writer for the raw vectors, handed to the parent class
 * @param vectorPerCluster stored on this writer for later use
 * @param centroidsPerParentCluster stored on this writer for later use
 * @throws IOException if the parent writer cannot be initialised
 */
public ES920DiskBBQVectorsWriter(
    String rawVectorFormatName,
    SegmentWriteState state,
    FlatVectorsWriter rawVectorDelegate,
    int vectorPerCluster,
    int centroidsPerParentCluster
) throws IOException {
    super(state, rawVectorDelegate);
    // Plain field captures; their relative order carries no meaning.
    this.centroidsPerParentCluster = centroidsPerParentCluster;
    this.vectorPerCluster = vectorPerCluster;
    this.rawVectorFormatName = rawVectorFormatName;
}
6369

70+
@Override
71+
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
72+
fieldInfo.putAttribute(RAW_VECTOR_FORMAT, rawVectorFormatName);
73+
return super.addField(fieldInfo);
74+
}
75+
6476
@Override
6577
CentroidOffsetAndLength buildAndWritePostingsLists(
6678
FieldInfo fieldInfo,

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
import org.elasticsearch.core.IOUtils;
3333
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3434

35+
import java.io.Closeable;
3536
import java.io.IOException;
37+
import java.util.ArrayList;
38+
import java.util.Collections;
39+
import java.util.List;
40+
import java.util.Map;
3641

3742
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
3843
import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.DYNAMIC_VISIT_RATIO;
@@ -46,14 +51,14 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
4651
private final SegmentReadState state;
4752
private final FieldInfos fieldInfos;
4853
protected final IntObjectHashMap<FieldEntry> fields;
49-
private final FlatVectorsReader rawVectorsReader;
54+
private final Map<String, FlatVectorsReader> rawVectorReaders;
5055

5156
@SuppressWarnings("this-escape")
52-
protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
57+
protected IVFVectorsReader(SegmentReadState state, Map<String, FlatVectorsReader> rawVectorReaders) throws IOException {
5358
this.state = state;
5459
this.fieldInfos = state.fieldInfos;
55-
this.rawVectorsReader = rawVectorsReader;
5660
this.fields = new IntObjectHashMap<>();
61+
this.rawVectorReaders = rawVectorReaders;
5762
String meta = IndexFileNames.segmentFileName(
5863
state.segmentInfo.name,
5964
state.segmentSuffix,
@@ -212,26 +217,34 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep
212217

213218
@Override
214219
public final void checkIntegrity() throws IOException {
215-
rawVectorsReader.checkIntegrity();
220+
for (var reader : rawVectorReaders.values()) {
221+
reader.checkIntegrity();
222+
}
216223
CodecUtil.checksumEntireFile(ivfCentroids);
217224
CodecUtil.checksumEntireFile(ivfClusters);
218225
}
219226

227+
private FlatVectorsReader getReaderForField(String field) {
228+
FlatVectorsReader reader = rawVectorReaders.get(field);
229+
if (reader == null) throw new IllegalArgumentException("No recorded raw vector reader for field " + field);
230+
return reader;
231+
}
232+
220233
@Override
221234
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
222-
return rawVectorsReader.getFloatVectorValues(field);
235+
return getReaderForField(field).getFloatVectorValues(field);
223236
}
224237

225238
@Override
226239
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
227-
return rawVectorsReader.getByteVectorValues(field);
240+
return getReaderForField(field).getByteVectorValues(field);
228241
}
229242

230243
@Override
231244
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
232245
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
233246
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
234-
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
247+
getReaderForField(field).search(field, target, knnCollector, acceptDocs);
235248
return;
236249
}
237250
if (fieldInfo.getVectorDimension() != target.length) {
@@ -243,7 +256,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
243256
if (acceptDocs instanceof BitSet bitSet) {
244257
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
245258
}
246-
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
259+
int numVectors = getReaderForField(field).getFloatVectorValues(field).size();
247260
float visitRatio = DYNAMIC_VISIT_RATIO;
248261
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
249262
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -309,7 +322,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
309322
@Override
310323
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
311324
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
312-
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
325+
final ByteVectorValues values = getReaderForField(field).getByteVectorValues(field);
313326
for (int i = 0; i < values.size(); i++) {
314327
final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i));
315328
knnCollector.collect(values.ordToDoc(i), score);
@@ -321,7 +334,9 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
321334

322335
@Override
323336
public void close() throws IOException {
324-
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
337+
List<Closeable> closeables = new ArrayList<>(rawVectorReaders.values());
338+
Collections.addAll(closeables, ivfCentroids, ivfClusters);
339+
IOUtils.close(closeables);
325340
}
326341

327342
protected record FieldEntry(

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorD
107107
}
108108

109109
@Override
110-
public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
110+
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
111111
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
112112
throw new IllegalArgumentException("IVF does not support cosine similarity");
113113
}

0 commit comments

Comments
 (0)