Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder;
import org.elasticsearch.index.mapper.BlockLoader.IntBuilder;
import org.elasticsearch.index.mapper.BlockLoader.LongBuilder;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
import org.elasticsearch.search.fetch.StoredFieldsSpec;

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;

/**
* A reader that supports reading doc-values from a Lucene segment in Block fashion.
*/
Expand Down Expand Up @@ -516,10 +519,12 @@ public String toString() {
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
private final String fieldName;
private final int dimensions;
private final DenseVectorFieldMapper.DenseVectorFieldType fieldType;

public DenseVectorBlockLoader(String fieldName, int dimensions) {
public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing the DenseVectorFieldType looked better than including other specific attributes needed for normalization (fieldType.name(), IndexVersion). I opted for creating an isNormalized() method in the DenseVectorFieldType instead.

this.fieldName = fieldName;
this.dimensions = dimensions;
this.fieldType = fieldType;
}

@Override
Expand All @@ -531,20 +536,27 @@ public Builder builder(BlockFactory factory, int expectedCount) {
public AllReader reader(LeafReaderContext context) throws IOException {
    FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
    if (floatVectorValues == null) {
        // No float vectors for this field in this segment: every requested doc loads as null.
        return new ConstantNullsReader();
    }
    if (fieldType.isNormalized()) {
        // Cosine-similarity vectors are stored normalized; the original magnitude (when present) lives in a
        // companion numeric doc-values field and is multiplied back in by the normalized reader.
        NumericDocValues magnitudeDocValues = context.reader()
            .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX);
        return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues);
    }
    return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
}
}

private static class DenseVectorValuesBlockReader extends BlockDocValuesReader {
private final FloatVectorValues floatVectorValues;
private final KnnVectorValues.DocIndexIterator iterator;
private final int dimensions;
private abstract static class DenseVectorValuesBlockReader<T extends KnnVectorValues> extends BlockDocValuesReader {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using an abstract class is preparatory work for supporting the byte element type.


DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
this.floatVectorValues = floatVectorValues;
iterator = floatVectorValues.iterator();
protected final T vectorValues;
protected final KnnVectorValues.DocIndexIterator iterator;
protected final int dimensions;

DenseVectorValuesBlockReader(T vectorValues, int dimensions) {
this.vectorValues = vectorValues;
iterator = vectorValues.iterator();
this.dimensions = dimensions;
}

Expand All @@ -569,26 +581,74 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException
builder.appendNull();
} else if (iterator.docID() == doc || iterator.advance(doc) == doc) {
builder.beginPositionEntry();
float[] floats = floatVectorValues.vectorValue(iterator.index());
assert floats.length == dimensions
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
for (float aFloat : floats) {
builder.appendFloat(aFloat);
}
appendDoc(builder);
builder.endPositionEntry();
} else {
builder.appendNull();
}
}

protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException;

@Override
public int docId() {
return iterator.docID();
}
}

/**
 * Reads float dense vectors from the index and appends them exactly as stored, one position entry per doc.
 */
private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {

    FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
        super(floatVectorValues, dimensions);
    }

    @Override // was missing; sibling normalized reader annotates this override
    protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
        float[] floats = vectorValues.vectorValue(iterator.index());
        assert floats.length == dimensions
            : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
        for (float aFloat : floats) {
            builder.appendFloat(aFloat);
        }
    }

    @Override
    public String toString() {
        return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader";
    }
}

private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
private final NumericDocValues magnitudeDocValues;

FloatDenseVectorNormalizedValuesBlockReader(
FloatVectorValues floatVectorValues,
int dimensions,
NumericDocValues magnitudeDocValues
) {
super(floatVectorValues, dimensions);
this.magnitudeDocValues = magnitudeDocValues;
}

@Override
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
float[] floats = vectorValues.vectorValue(iterator.index());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we refactor checking the dimension count out, since it's shared between both block readers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there's a dimension() common method for KnnVectorValues that I wasn't aware of. Done in e98bbcb

assert floats.length == dimensions
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;

float magnitude = 1.0f;
// If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a
// stored magnitude for all docs
if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) {
magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue());
}
for (float aFloat : floats) {
builder.appendFloat(aFloat * magnitude);
}
}

@Override
public String toString() {
return "BlockDocValuesReader.FloatVectorValuesBlockReader";
return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
this,
denseVectorFieldType.dims,
denseVectorFieldType.indexed,
denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
&& denseVectorFieldType.indexed
&& denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) {
@Override
public CacheHelper getCoreCacheHelper() {
return r.getCoreCacheHelper();
}
denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used the isNormalized() method to improve readability - that messed up some indentation

@Override
public CacheHelper getCoreCacheHelper() {
return r.getCoreCacheHelper();
}

@Override
public CacheHelper getReaderCacheHelper() {
return r.getReaderCacheHelper();
}
@Override
public CacheHelper getReaderCacheHelper() {
return r.getReaderCacheHelper();
}

@Override
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
FloatVectorValues values = in.getFloatVectorValues(fieldName);
if (values == null) {
return null;
}
return new DenormalizedCosineFloatVectorValues(
values,
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
);
@Override
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
FloatVectorValues values = in.getFloatVectorValues(fieldName);
if (values == null) {
return null;
}
} : r -> r
return new DenormalizedCosineFloatVectorValues(
values,
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
);
}
} : r -> r
);
}

Expand Down Expand Up @@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie
fieldMapper.checkDimensionMatches(index, context);
checkVectorBounds(vector);
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE)
&& fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE)
&& isNotUnitVector(squaredMagnitude)) {
if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) {
float length = (float) Math.sqrt(squaredMagnitude);
for (int i = 0; i < vector.length; i++) {
vector[i] /= length;
Expand Down Expand Up @@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
return knnQuery;
}

/**
 * Whether vectors for this field are stored normalized: true for cosine similarity on indices created on or
 * after {@code NORMALIZE_COSINE}.
 */
public boolean isNormalized() {
    if (VectorSimilarity.COSINE.equals(similarity) == false) {
        return false;
    }
    return indexVersionCreated.onOrAfter(NORMALIZE_COSINE);
}

private Query createExactKnnBitQuery(byte[] queryVector) {
elementType.checkDimensions(dims, queryVector.length);
return new DenseVectorQuery.Bytes(queryVector, name());
Expand All @@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) {
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
if (similarity == VectorSimilarity.COSINE
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
&& isNotUnitVector(squaredMagnitude)) {
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
float length = (float) Math.sqrt(squaredMagnitude);
queryVector = Arrays.copyOf(queryVector, queryVector.length);
for (int i = 0; i < queryVector.length; i++) {
Expand Down Expand Up @@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery(
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
if (similarity == VectorSimilarity.COSINE
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
&& isNotUnitVector(squaredMagnitude)) {
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
float length = (float) Math.sqrt(squaredMagnitude);
queryVector = Arrays.copyOf(queryVector, queryVector.length);
for (int i = 0; i < queryVector.length; i++) {
Expand Down Expand Up @@ -2795,7 +2791,7 @@ int getVectorDimensions() {
return dims;
}

ElementType getElementType() {
public ElementType getElementType() {
return elementType;
}

Expand All @@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
}

if (indexed) {
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims);
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this);
}

if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) {
Expand Down
Loading