Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.BlockLoader.BlockFactory;
import org.elasticsearch.index.mapper.BlockLoader.BooleanBuilder;
import org.elasticsearch.index.mapper.BlockLoader.Builder;
Expand All @@ -26,6 +28,7 @@
import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder;
import org.elasticsearch.index.mapper.BlockLoader.IntBuilder;
import org.elasticsearch.index.mapper.BlockLoader.LongBuilder;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
import org.elasticsearch.search.fetch.StoredFieldsSpec;

import java.io.IOException;
Expand Down Expand Up @@ -504,6 +507,85 @@ public String toString() {
}
}

/**
 * Block loader for dense vector fields whose values are exposed per leaf through
 * {@link FloatVectorValues}.
 */
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
    private final String fieldName;
    private final int dimensions;

    /**
     * @param fieldName  name of the field to look up on each leaf reader
     * @param dimensions dimension count handed to the builders and readers created by this loader
     */
    public DenseVectorBlockLoader(String fieldName, int dimensions) {
        this.fieldName = fieldName;
        this.dimensions = dimensions;
    }

    @Override
    public Builder builder(BlockFactory factory, int expectedCount) {
        return factory.denseVectors(expectedCount, dimensions);
    }

    @Override
    public AllReader reader(LeafReaderContext context) throws IOException {
        FloatVectorValues vectors = context.reader().getFloatVectorValues(fieldName);
        if (vectors == null) {
            // The leaf exposes no float vector values for this field: use the constant-nulls reader.
            return new ConstantNullsReader();
        }
        return new DenseVectorValuesBlockReader(vectors, dimensions);
    }
}

/**
 * Reads float vectors for requested documents from a {@link FloatVectorValues} iterator and appends
 * them to a {@link BlockLoader.FloatBuilder}, one position per document. Documents the iterator has
 * no vector for are appended as null.
 */
private static class DenseVectorValuesBlockReader extends BlockDocValuesReader {
    private final FloatVectorValues floatVectorValues;
    private final int dimensions;

    /**
     * Last document passed to {@link #read(int, BlockLoader.FloatBuilder)}, or -1 if none yet.
     * Kept separately from the iterator position because the iterator moves past documents
     * that have no vector.
     */
    private int lastRequestedDoc = -1;

    DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
        this.floatVectorValues = floatVectorValues;
        this.dimensions = dimensions;
    }

    /**
     * Reads the vectors of all {@code docs} into a new block.
     *
     * @throws IllegalStateException if a document is smaller than the previously requested one
     */
    @Override
    public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException {
        try (BlockLoader.FloatBuilder builder = factory.denseVectors(docs.count(), dimensions)) {
            for (int i = 0; i < docs.count(); i++) {
                int doc = docs.get(i);
                // Compare against the last requested doc, not the iterator position: when a document
                // has no vector the iterator lands on a later document, and the next in-order
                // document without a vector must not be reported as out of order.
                if (doc < lastRequestedDoc) {
                    throw new IllegalStateException("docs within same block must be in order");
                }
                read(doc, builder);
            }
            return builder.build();
        }
    }

    @Override
    public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
        read(docId, (BlockLoader.FloatBuilder) builder);
    }

    /**
     * Appends the vector of {@code doc} as a single multi-valued position, or a null when the
     * iterator has no vector for that document.
     */
    private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
        lastRequestedDoc = doc;
        int current = floatVectorValues.docID();
        // Only advance when the iterator is strictly behind the target. Advancing to the current
        // (or an earlier) document is undefined for a doc id iterator and may skip the vector.
        if (current < doc) {
            current = floatVectorValues.advance(doc);
        }
        if (current != doc) {
            // The iterator is positioned on a later document, so this one has no vector.
            builder.appendNull();
            return;
        }
        builder.beginPositionEntry();
        float[] floats = floatVectorValues.vectorValue();
        assert floats.length == dimensions
            : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
        for (float aFloat : floats) {
            builder.appendFloat(aFloat);
        }
        builder.endPositionEntry();
    }

    /**
     * Returns the current position of the underlying vector iterator, which may be past the last
     * requested document.
     */
    @Override
    public int docId() {
        return floatVectorValues.docID();
    }

    @Override
    public String toString() {
        return "BlockDocValuesReader.FloatVectorValuesBlockReader";
    }
}

public static class BytesRefsFromOrdsBlockLoader extends DocValuesBlockLoader {
private final String fieldName;

Expand Down Expand Up @@ -752,6 +834,94 @@ public String toString() {
}
}

/**
 * Block loader for dense vectors that are read from {@link BinaryDocValues} and decoded with
 * {@link VectorEncoderDecoder}.
 */
public static class DenseVectorFromBinaryBlockLoader extends DocValuesBlockLoader {
    private final String fieldName;
    private final int dims;
    private final IndexVersion indexVersion;

    /**
     * @param fieldName    name of the binary doc values field to look up on each leaf reader
     * @param dims         dimension count handed to the builders and readers created by this loader
     * @param indexVersion index version forwarded to the decoder
     */
    public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion) {
        this.fieldName = fieldName;
        this.dims = dims;
        this.indexVersion = indexVersion;
    }

    @Override
    public Builder builder(BlockFactory factory, int expectedCount) {
        return factory.denseVectors(expectedCount, dims);
    }

    @Override
    public AllReader reader(LeafReaderContext context) throws IOException {
        BinaryDocValues binary = context.reader().getBinaryDocValues(fieldName);
        if (binary != null) {
            return new DenseVectorFromBinary(binary, dims, indexVersion);
        }
        // The leaf has no binary doc values for this field: use the constant-nulls reader.
        return new ConstantNullsReader();
    }
}

/**
 * Reads dense vectors stored as binary doc values. Each value is decoded into a reusable
 * {@code float[]} and appended to a {@link BlockLoader.FloatBuilder} as one position per document.
 */
private static class DenseVectorFromBinary extends BlockDocValuesReader {
    private final BinaryDocValues docValues;
    private final IndexVersion indexVersion;
    private final int dimensions;
    /** Decode target, reused for every document; sized to the dimension count. */
    private final float[] scratch;

    /** Last document requested from this reader, or -1 before the first read. */
    private int docID = -1;

    DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
        this.docValues = docValues;
        this.indexVersion = indexVersion;
        this.dimensions = dims;
        this.scratch = new float[dims];
    }

    /**
     * Reads the vectors of all {@code docs} into a new block.
     *
     * @throws IllegalStateException if a document is smaller than the previously requested one
     */
    @Override
    public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException {
        try (BlockLoader.FloatBuilder vectors = factory.denseVectors(docs.count(), dimensions)) {
            int position = 0;
            while (position < docs.count()) {
                int target = docs.get(position);
                if (target < docID) {
                    throw new IllegalStateException("docs within same block must be in order");
                }
                read(target, vectors);
                position++;
            }
            return vectors.build();
        }
    }

    @Override
    public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
        read(docId, (BlockLoader.FloatBuilder) builder);
    }

    /**
     * Appends the decoded vector of {@code doc}, or a null when the document has no binary value.
     */
    private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
        // Remember the request first so docId() reflects it even when the document has no value.
        this.docID = doc;
        if (docValues.advanceExact(doc)) {
            BytesRef encoded = docValues.binaryValue();
            assert encoded.length > 0;
            VectorEncoderDecoder.decodeDenseVector(indexVersion, encoded, scratch);

            builder.beginPositionEntry();
            for (int d = 0; d < scratch.length; d++) {
                builder.appendFloat(scratch[d]);
            }
            builder.endPositionEntry();
        } else {
            builder.appendNull();
        }
    }

    @Override
    public int docId() {
        return docID;
    }

    @Override
    public String toString() {
        return "DenseVectorFromBinary.Bytes";
    }
}

public static class BooleansBlockLoader extends DocValuesBlockLoader {
private final String fieldName;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,11 @@ interface BlockFactory {
*/
DoubleBuilder doubles(int expectedCount);

/**
 * Build a builder to load dense vectors without any loading constraints.
 *
 * @param expectedVectorsCount expected number of vectors to be appended; presumably a sizing
 *                             hint rather than a hard limit (TODO confirm against implementations)
 * @param dimensions           number of dimensions of the vectors being loaded
 * @return a {@link FloatBuilder} that the vector components are appended to
 */
FloatBuilder denseVectors(int expectedVectorsCount, int dimensions);

/**
* Build a builder to load ints as loaded from doc values.
* Doc values load ints in sorted order.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,49 @@ public String toString() {
}
}

/**
 * Load dense vectors from {@code _source}.
 */
public static class DenseVectorBlockLoader extends SourceBlockLoader {
    /** Dimension count passed to {@link BlockFactory#denseVectors} when creating builders. */
    private final int dimensions;

    /**
     * @param fetcher    fetcher passed to the parent loader and to the row-stride readers
     * @param lookup     leaf iterator lookup passed to the parent loader
     * @param dimensions dimension count of the vectors being loaded
     */
    public DenseVectorBlockLoader(ValueFetcher fetcher, LeafIteratorLookup lookup, int dimensions) {
        super(fetcher, lookup);
        this.dimensions = dimensions;
    }

    @Override
    protected String name() {
        return "DenseVectors";
    }

    @Override
    public Builder builder(BlockFactory factory, int expectedCount) {
        return factory.denseVectors(expectedCount, dimensions);
    }

    @Override
    public RowStrideReader rowStrideReader(LeafReaderContext context, DocIdSetIterator iter) {
        return new DenseVectors(fetcher, iter);
    }
}

/**
 * Source reader that appends each fetched value to a {@link BlockLoader.FloatBuilder} as a float.
 */
private static class DenseVectors extends BlockSourceReader {
    DenseVectors(ValueFetcher fetcher, DocIdSetIterator iter) {
        super(fetcher, iter);
    }

    /**
     * Appends a single vector component.
     *
     * @param builder must be a {@link BlockLoader.FloatBuilder}
     * @param v       must be a {@link Number}; it is narrowed with {@link Number#floatValue()}
     */
    @Override
    protected void append(BlockLoader.Builder builder, Object v) {
        BlockLoader.FloatBuilder floats = (BlockLoader.FloatBuilder) builder;
        Number component = (Number) v;
        floats.appendFloat(component.floatValue());
    }

    @Override
    public String toString() {
        return "BlockSourceReader.DenseVectors";
    }
}

/**
* Load {@code int}s from {@code _source}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,26 @@
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
import org.elasticsearch.index.mapper.BlockDocValuesReader;
import org.elasticsearch.index.mapper.BlockLoader;
import org.elasticsearch.index.mapper.BlockSourceReader;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
import org.elasticsearch.index.mapper.SourceLoader;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.lookup.Source;
import org.elasticsearch.search.vectors.DenseVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
Expand All @@ -84,10 +90,12 @@
import java.time.ZoneId;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
Expand Down Expand Up @@ -324,7 +332,8 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
indexed.getValue(),
similarity.getValue(),
indexOptions.getValue(),
meta.getValue()
meta.getValue(),
context.isSourceSynthetic()
),
builderParams(this, context),
indexOptions.getValue(),
Expand Down Expand Up @@ -2053,6 +2062,7 @@ public static final class DenseVectorFieldType extends SimpleMappedFieldType {
private final VectorSimilarity similarity;
private final IndexVersion indexVersionCreated;
private final IndexOptions indexOptions;
private final boolean isSyntheticSource;

public DenseVectorFieldType(
String name,
Expand All @@ -2062,7 +2072,8 @@ public DenseVectorFieldType(
boolean indexed,
VectorSimilarity similarity,
IndexOptions indexOptions,
Map<String, String> meta
Map<String, String> meta,
boolean isSyntheticSource
) {
super(name, indexed, false, indexed == false, TextSearchInfo.NONE, meta);
this.elementType = elementType;
Expand All @@ -2071,6 +2082,7 @@ public DenseVectorFieldType(
this.similarity = similarity;
this.indexVersionCreated = indexVersionCreated;
this.indexOptions = indexOptions;
this.isSyntheticSource = isSyntheticSource;
}

@Override
Expand Down Expand Up @@ -2329,6 +2341,44 @@ ElementType getElementType() {
public IndexOptions getIndexOptions() {
return indexOptions;
}

/**
 * Chooses how vectors of this field are loaded into blocks.
 * <ul>
 *   <li>non-float element types: no block loader ({@code null})</li>
 *   <li>indexed fields: loaded from the float vector values of the field</li>
 *   <li>fields with doc values, when doc values are the preferred extraction or source is
 *       synthetic: decoded from binary doc values</li>
 *   <li>otherwise: parsed from {@code _source}</li>
 * </ul>
 */
@Override
public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
    final BlockLoader loader;
    if (elementType != ElementType.FLOAT) {
        // Just float dense vector support for now
        loader = null;
    } else if (indexed) {
        loader = new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims);
    } else if (hasDocValues() && (blContext.fieldExtractPreference() == FieldExtractPreference.DOC_VALUES || isSyntheticSource)) {
        loader = new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated);
    } else {
        BlockSourceReader.LeafIteratorLookup matchAll = BlockSourceReader.lookupMatchingAll();
        SourceValueFetcher fetcher = sourceValueFetcher(blContext.sourcePaths(name()));
        loader = new BlockSourceReader.DenseVectorBlockLoader(fetcher, matchAll, dims);
    }
    return loader;
}

/**
 * Builds a fetcher that reads this field from {@code _source}, parsing each value as a float.
 *
 * @param sourcePaths source paths the values are extracted from
 */
private SourceValueFetcher sourceValueFetcher(Set<String> sourcePaths) {
    return new SourceValueFetcher(sourcePaths, null) {
        @Override
        protected Object parseSourceValue(Object value) {
            // Empty strings produce no value; everything else goes through the float parser.
            if (value.equals("") == false) {
                return NumberFieldMapper.NumberType.FLOAT.parse(value, false);
            }
            return null;
        }

        @Override
        public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValues) {
            List<Object> components = super.fetchValues(source, doc, ignoredValues);
            // NOTE(review): this also fires when the parent returns an empty list; confirm what it
            // returns for documents that do not contain the field.
            assert components.size() == dims : "Unexpected number of dimensions; got " + components.size() + " but expected " + dims;
            return components;
        }
    };
}
}

private final IndexOptions indexOptions;
Expand Down Expand Up @@ -2383,7 +2433,8 @@ public void parse(DocumentParserContext context) throws IOException {
fieldType().indexed,
fieldType().similarity,
fieldType().indexOptions,
fieldType().meta()
fieldType().meta(),
fieldType().isSyntheticSource
);
Mapper update = new DenseVectorFieldMapper(
leafName(),
Expand Down
Loading