Skip to content
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;

provides org.apache.lucene.codecs.Codec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,10 @@ protected DirectIOCapableFlatVectorsFormat(String name) {
super(name);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return fieldsReader(state, false);
}

public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.hnsw.FlatVectorsReader;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;

/**
* Keeps track of field-specific raw vector readers for vector reads
*/
/**
 * Registry of per-field raw {@link FlatVectorsReader} instances, de-duplicated by
 * (format name, direct-IO flag) so fields that share a storage format share a reader.
 */
public class GenericFlatVectorReaders {

    /** Per-field metadata describing which raw vector format the field uses. */
    public interface Field {
        String rawVectorFormatName();

        boolean useDirectIOReads();
    }

    /** Factory that resolves a format name (and direct-IO preference) to a reader. */
    @FunctionalInterface
    public interface LoadFlatVectorsReader {
        FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException;
    }

    /** Cache key: a format name together with whether direct IO is used for reads. */
    private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) {
        private FlatVectorsReaderKey(Field field) {
            this(field.rawVectorFormatName(), field.useDirectIOReads());
        }

        @Override
        public String toString() {
            return useDirectIO ? formatName + " with Direct IO" : formatName;
        }
    }

    // One reader per distinct (format, directIO) combination.
    private final Map<FlatVectorsReaderKey, FlatVectorsReader> readers = new HashMap<>();
    // Field number -> the shared reader serving that field.
    private final Map<Integer, FlatVectorsReader> readersForFields = new HashMap<>();

    /**
     * Registers {@code field}, loading a reader for its format on first use and
     * reusing the cached one otherwise.
     *
     * @throws IllegalStateException if {@code loadReader} cannot supply a reader for the format
     */
    public void loadField(int fieldNumber, Field field, LoadFlatVectorsReader loadReader) throws IOException {
        var key = new FlatVectorsReaderKey(field);
        var cached = readers.get(key);
        if (cached == null) {
            cached = loadReader.getReader(field.rawVectorFormatName(), field.useDirectIOReads());
            if (cached == null) {
                throw new IllegalStateException("Cannot find flat vector format: " + field.rawVectorFormatName());
            }
            readers.put(key, cached);
        }
        readersForFields.put(fieldNumber, cached);
    }

    /**
     * Returns the reader previously registered for {@code fieldNumber}.
     *
     * @throws IllegalArgumentException if the field was never loaded
     */
    public FlatVectorsReader getReaderForField(int fieldNumber) {
        var reader = readersForFields.get(fieldNumber);
        if (reader == null) {
            throw new IllegalArgumentException("Invalid field number [" + fieldNumber + "]");
        }
        return reader;
    }

    /** All distinct readers, as an unmodifiable view (useful for close/checkIntegrity). */
    public Collection<FlatVectorsReader> allReaders() {
        return Collections.unmodifiableCollection(readers.values());
    }

    /**
     * Builds a new instance whose readers are the merge instances of this one's,
     * preserving the field-to-reader sharing structure.
     */
    public GenericFlatVectorReaders getMergeInstance() throws IOException {
        var merged = new GenericFlatVectorReaders();

        // Map each original reader (by identity) to its merge instance so shared
        // readers stay shared in the merged view.
        var originalToMerge = new IdentityHashMap<FlatVectorsReader, FlatVectorsReader>();
        for (Map.Entry<FlatVectorsReaderKey, FlatVectorsReader> entry : readers.entrySet()) {
            var mergeReader = entry.getValue().getMergeInstance();
            originalToMerge.put(entry.getValue(), mergeReader);
            merged.readers.put(entry.getKey(), mergeReader);
        }
        // Re-point every field at the merge instance of its original reader.
        readersForFields.forEach((fieldNumber, original) -> merged.readersForFields.put(fieldNumber, originalToMerge.get(original)));
        return merged;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
Expand All @@ -38,7 +39,7 @@
*/
public class ES920DiskBBQVectorsReader extends IVFVectorsReader {

ES920DiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
ES920DiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader) throws IOException {
super(state, getFormatReader);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand All @@ -49,33 +49,18 @@
*/
public abstract class IVFVectorsReader extends KnnVectorsReader {

private record FlatVectorsReaderKey(String formatName, boolean useDirectIO) {
private FlatVectorsReaderKey(FieldEntry entry) {
this(entry.rawVectorFormatName, entry.useDirectIOReads);
}

@Override
public String toString() {
return formatName + (useDirectIO ? " with Direct IO" : "");
}
}

private final IndexInput ivfCentroids, ivfClusters;
private final SegmentReadState state;
private final FieldInfos fieldInfos;
protected final IntObjectHashMap<FieldEntry> fields;
private final Map<FlatVectorsReaderKey, FlatVectorsReader> rawVectorReaders;

@FunctionalInterface
public interface GetFormatReader {
FlatVectorsReader getReader(String formatName, boolean useDirectIO) throws IOException;
}
private final GenericFlatVectorReaders genericReaders;

@SuppressWarnings("this-escape")
protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
protected IVFVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader loadReader) throws IOException {
this.state = state;
this.fieldInfos = state.fieldInfos;
this.fields = new IntObjectHashMap<>();
this.genericReaders = new GenericFlatVectorReaders();
String meta = IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Expand All @@ -86,7 +71,6 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
boolean success = false;
try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
Throwable priorE = null;
Map<FlatVectorsReaderKey, FlatVectorsReader> readers = null;
try {
versionMeta = CodecUtil.checkIndexHeader(
ivfMeta,
Expand All @@ -96,13 +80,12 @@ protected IVFVectorsReader(SegmentReadState state, GetFormatReader getFormatRead
state.segmentInfo.getId(),
state.segmentSuffix
);
readers = readFields(ivfMeta, getFormatReader, versionMeta);
readFields(ivfMeta, versionMeta, genericReaders, loadReader);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(ivfMeta, priorE);
}
this.rawVectorReaders = readers;
ivfCentroids = openDataInput(
state,
versionMeta,
Expand Down Expand Up @@ -169,30 +152,23 @@ private static IndexInput openDataInput(
}
}

private Map<FlatVectorsReaderKey, FlatVectorsReader> readFields(ChecksumIndexInput meta, GetFormatReader loadReader, int versionMeta)
throws IOException {
Map<FlatVectorsReaderKey, FlatVectorsReader> readers = new HashMap<>();
private void readFields(
ChecksumIndexInput meta,
int versionMeta,
GenericFlatVectorReaders genericFields,
GenericFlatVectorReaders.LoadFlatVectorsReader loadReader
) throws IOException {
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}

FieldEntry fieldEntry = readField(meta, info, versionMeta);
FlatVectorsReaderKey key = new FlatVectorsReaderKey(fieldEntry);

FlatVectorsReader reader = readers.get(key);
if (reader == null) {
reader = loadReader.getReader(fieldEntry.rawVectorFormatName, fieldEntry.useDirectIOReads);
if (reader == null) {
throw new IllegalStateException("Cannot find flat vector format: " + fieldEntry.rawVectorFormatName);
}
readers.put(key, reader);
}
genericFields.loadField(fieldNumber, fieldEntry, loadReader);

fields.put(info.number, fieldEntry);
}
return readers;
}

private FieldEntry readField(IndexInput input, FieldInfo info, int versionMeta) throws IOException {
Expand Down Expand Up @@ -256,29 +232,17 @@ private static VectorEncoding readVectorEncoding(DataInput input) throws IOExcep

@Override
public final void checkIntegrity() throws IOException {
for (var reader : rawVectorReaders.values()) {
for (var reader : genericReaders.allReaders()) {
reader.checkIntegrity();
}
CodecUtil.checksumEntireFile(ivfCentroids);
CodecUtil.checksumEntireFile(ivfClusters);
}

private FieldEntry getFieldEntryOrThrow(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry entry;
if (info == null || (entry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return entry;
}

private FlatVectorsReader getReaderForField(String field) {
var readerKey = new FlatVectorsReaderKey(getFieldEntryOrThrow(field));
FlatVectorsReader reader = rawVectorReaders.get(readerKey);
if (reader == null) throw new IllegalArgumentException(
"Could not find raw vector format [" + readerKey + "] for field [" + field + "]"
);
return reader;
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) throw new IllegalArgumentException("Could not find field [" + field + "]");
return genericReaders.getReaderForField(info.number);
}

@Override
Expand Down Expand Up @@ -399,7 +363,7 @@ public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {

@Override
public void close() throws IOException {
List<Closeable> closeables = new ArrayList<>(rawVectorReaders.values());
List<Closeable> closeables = new ArrayList<>(genericReaders.allReaders());
Collections.addAll(closeables, ivfCentroids, ivfClusters);
IOUtils.close(closeables);
}
Expand All @@ -416,7 +380,7 @@ protected record FieldEntry(
long postingListLength,
float[] globalCentroid,
float globalCentroidDp
) {
) implements GenericFlatVectorReaders.Field {
IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
return centroidFile.slice("centroids", centroidOffset, centroidLength);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter;
Expand All @@ -40,7 +41,8 @@
*/
public class ESNextDiskBBQVectorsReader extends IVFVectorsReader {

public ESNextDiskBBQVectorsReader(SegmentReadState state, GetFormatReader getFormatReader) throws IOException {
public ESNextDiskBBQVectorsReader(SegmentReadState state, GenericFlatVectorReaders.LoadFlatVectorsReader getFormatReader)
throws IOException {
super(state, getFormatReader);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter;

import java.io.IOException;
import java.util.Map;

/**
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
Expand Down Expand Up @@ -86,19 +87,33 @@
* <li>The sparse vector information, if required, mapping vector ordinal to doc ID
* </ul>
*/
public class ES93BinaryQuantizedVectorsFormat extends DirectIOCapableFlatVectorsFormat {
public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat {

public static final String NAME = "ES93BinaryQuantizedVectorsFormat";

private final DirectIOCapableLucene99FlatVectorsFormat rawVectorFormat;
private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);

private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
rawVectorFormat.getName(),
rawVectorFormat
);

private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);

private final boolean useDirectIO;

public ES93BinaryQuantizedVectorsFormat() {
super(NAME);
rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
this.useDirectIO = false;
}

public ES93BinaryQuantizedVectorsFormat(boolean useDirectIO) {
super(NAME);
this.useDirectIO = useDirectIO;
}

@Override
Expand All @@ -107,17 +122,27 @@ protected FlatVectorsScorer flatVectorsScorer() {
}

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state);
protected boolean useDirectIOReads() {
return useDirectIO;
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
protected DirectIOCapableFlatVectorsFormat writeFlatVectorsFormat() {
return rawVectorFormat;
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state, useDirectIO), scorer);
protected Map<String, DirectIOCapableFlatVectorsFormat> supportedReadFlatVectorsFormats() {
return supportedFormats;
}

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer);
}
}
Loading