Skip to content

Commit a4aca14

Browse files
committed
Take into account normalization
1 parent 0402ad9 commit a4aca14

File tree

4 files changed

+158
-102
lines changed

4 files changed

+158
-102
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder;
3131
import org.elasticsearch.index.mapper.BlockLoader.IntBuilder;
3232
import org.elasticsearch.index.mapper.BlockLoader.LongBuilder;
33+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
3334
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
3435
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
3536
import org.elasticsearch.search.fetch.StoredFieldsSpec;
3637

3738
import java.io.IOException;
3839

40+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
41+
3942
/**
4043
* A reader that supports reading doc-values from a Lucene segment in Block fashion.
4144
*/
@@ -513,12 +516,12 @@ public String toString() {
513516
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
514517
private final String fieldName;
515518
private final int dimensions;
516-
private final ElementType elementType;
519+
private final DenseVectorFieldMapper.DenseVectorFieldType fieldType;
517520

518-
public DenseVectorBlockLoader(String fieldName, int dimensions, ElementType elementType) {
521+
public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) {
519522
this.fieldName = fieldName;
520523
this.dimensions = dimensions;
521-
this.elementType = elementType;
524+
this.fieldType = fieldType;
522525
}
523526

524527
@Override
@@ -528,10 +531,17 @@ public Builder builder(BlockFactory factory, int expectedCount) {
528531

529532
@Override
530533
public AllReader reader(LeafReaderContext context) throws IOException {
531-
switch (elementType) {
534+
switch (fieldType.getElementType()) {
532535
case FLOAT -> {
533536
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
534537
if (floatVectorValues != null) {
538+
if (fieldType.isNormalized()) {
539+
return new FloatDenseVectorNormalizedValuesBlockReader(
540+
floatVectorValues,
541+
dimensions,
542+
context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX)
543+
);
544+
}
535545
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
536546
}
537547
}
@@ -596,6 +606,7 @@ public int docId() {
596606
}
597607

598608
private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
609+
599610
FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
600611
super(floatVectorValues, dimensions);
601612
}
@@ -615,6 +626,37 @@ public String toString() {
615626
}
616627
}
617628

629+
private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
630+
private final NumericDocValues magnitudeDocValues;
631+
632+
FloatDenseVectorNormalizedValuesBlockReader(
633+
FloatVectorValues floatVectorValues,
634+
int dimensions,
635+
NumericDocValues magnitudeDocValues
636+
) {
637+
super(floatVectorValues, dimensions);
638+
this.magnitudeDocValues = magnitudeDocValues;
639+
}
640+
641+
@Override
642+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
643+
float[] floats = vectorValues.vectorValue(iterator.index());
644+
assert floats.length == dimensions
645+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
646+
647+
assert magnitudeDocValues.advanceExact(iterator.docID());
648+
float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue());
649+
for (float aFloat : floats) {
650+
builder.appendFloat(aFloat * magnitude);
651+
}
652+
}
653+
654+
@Override
655+
public String toString() {
656+
return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
657+
}
658+
}
659+
618660
private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
619661
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
620662
super(floatVectorValues, dimensions);

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
734734
this,
735735
denseVectorFieldType.dims,
736736
denseVectorFieldType.indexed,
737-
denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
738-
&& denseVectorFieldType.indexed
739-
&& denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) {
740-
@Override
741-
public CacheHelper getCoreCacheHelper() {
742-
return r.getCoreCacheHelper();
743-
}
737+
denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) {
738+
@Override
739+
public CacheHelper getCoreCacheHelper() {
740+
return r.getCoreCacheHelper();
741+
}
744742

745-
@Override
746-
public CacheHelper getReaderCacheHelper() {
747-
return r.getReaderCacheHelper();
748-
}
743+
@Override
744+
public CacheHelper getReaderCacheHelper() {
745+
return r.getReaderCacheHelper();
746+
}
749747

750-
@Override
751-
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
752-
FloatVectorValues values = in.getFloatVectorValues(fieldName);
753-
if (values == null) {
754-
return null;
755-
}
756-
return new DenormalizedCosineFloatVectorValues(
757-
values,
758-
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
759-
);
748+
@Override
749+
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
750+
FloatVectorValues values = in.getFloatVectorValues(fieldName);
751+
if (values == null) {
752+
return null;
760753
}
761-
} : r -> r
754+
return new DenormalizedCosineFloatVectorValues(
755+
values,
756+
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
757+
);
758+
}
759+
} : r -> r
762760
);
763761
}
764762

@@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie
820818
fieldMapper.checkDimensionMatches(index, context);
821819
checkVectorBounds(vector);
822820
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
823-
if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE)
824-
&& fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE)
825-
&& isNotUnitVector(squaredMagnitude)) {
821+
if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) {
826822
float length = (float) Math.sqrt(squaredMagnitude);
827823
for (int i = 0; i < vector.length; i++) {
828824
vector[i] /= length;
@@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
24912487
return knnQuery;
24922488
}
24932489

2490+
public boolean isNormalized() {
2491+
return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity);
2492+
}
2493+
24942494
private Query createExactKnnBitQuery(byte[] queryVector) {
24952495
elementType.checkDimensions(dims, queryVector.length);
24962496
return new DenseVectorQuery.Bytes(queryVector, name());
@@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) {
25112511
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
25122512
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
25132513
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
2514-
if (similarity == VectorSimilarity.COSINE
2515-
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
2516-
&& isNotUnitVector(squaredMagnitude)) {
2514+
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
25172515
float length = (float) Math.sqrt(squaredMagnitude);
25182516
queryVector = Arrays.copyOf(queryVector, queryVector.length);
25192517
for (int i = 0; i < queryVector.length; i++) {
@@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery(
27032701
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
27042702
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
27052703
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
2706-
if (similarity == VectorSimilarity.COSINE
2707-
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
2708-
&& isNotUnitVector(squaredMagnitude)) {
2704+
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
27092705
float length = (float) Math.sqrt(squaredMagnitude);
27102706
queryVector = Arrays.copyOf(queryVector, queryVector.length);
27112707
for (int i = 0; i < queryVector.length; i++) {
@@ -2795,7 +2791,7 @@ int getVectorDimensions() {
27952791
return dims;
27962792
}
27972793

2798-
ElementType getElementType() {
2794+
public ElementType getElementType() {
27992795
return elementType;
28002796
}
28012797

@@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
28162812
}
28172813

28182814
if (indexed) {
2819-
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, elementType);
2815+
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this);
28202816
}
28212817

28222818
if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) {

0 commit comments

Comments
 (0)