Skip to content

Commit 763fe63

Browse files
committed
Take into account normalization for dense vector support
1 parent 8d12540 commit 763fe63

File tree

4 files changed

+270
-109
lines changed

4 files changed

+270
-109
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@
2929
import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder;
3030
import org.elasticsearch.index.mapper.BlockLoader.IntBuilder;
3131
import org.elasticsearch.index.mapper.BlockLoader.LongBuilder;
32+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
33+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
3234
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
3335
import org.elasticsearch.search.fetch.StoredFieldsSpec;
3436

3537
import java.io.IOException;
3638

39+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
40+
3741
/**
3842
* A reader that supports reading doc-values from a Lucene segment in Block fashion.
3943
*/
@@ -511,10 +515,12 @@ public String toString() {
511515
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
512516
private final String fieldName;
513517
private final int dimensions;
518+
private final DenseVectorFieldMapper.DenseVectorFieldType fieldType;
514519

515-
public DenseVectorBlockLoader(String fieldName, int dimensions) {
520+
public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) {
516521
this.fieldName = fieldName;
517522
this.dimensions = dimensions;
523+
this.fieldType = fieldType;
518524
}
519525

520526
@Override
@@ -524,9 +530,26 @@ public Builder builder(BlockFactory factory, int expectedCount) {
524530

525531
@Override
526532
public AllReader reader(LeafReaderContext context) throws IOException {
527-
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
528-
if (floatVectorValues != null) {
529-
return new DenseVectorValuesBlockReader(floatVectorValues, dimensions);
533+
switch (fieldType.getElementType()) {
534+
case FLOAT -> {
535+
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
536+
if (floatVectorValues != null) {
537+
if (fieldType.isNormalized()) {
538+
return new FloatDenseVectorNormalizedValuesBlockReader(
539+
floatVectorValues,
540+
dimensions,
541+
context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX)
542+
);
543+
}
544+
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
545+
}
546+
}
547+
case BYTE -> {
548+
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName);
549+
if (byteVectorValues != null) {
550+
return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions);
551+
}
552+
}
530553
}
531554
return new ConstantNullsReader();
532555
}
@@ -580,10 +603,77 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException
580603
public int docId() {
581604
return iterator.docID();
582605
}
606+
}
607+
608+
private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
609+
610+
FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
611+
super(floatVectorValues, dimensions);
612+
}
613+
614+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
615+
float[] floats = vectorValues.vectorValue(iterator.index());
616+
assert floats.length == dimensions
617+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
618+
for (float aFloat : floats) {
619+
builder.appendFloat(aFloat);
620+
}
621+
}
622+
623+
@Override
624+
public String toString() {
625+
return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader";
626+
}
627+
}
628+
629+
private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
630+
private final NumericDocValues magnitudeDocValues;
631+
632+
FloatDenseVectorNormalizedValuesBlockReader(
633+
FloatVectorValues floatVectorValues,
634+
int dimensions,
635+
NumericDocValues magnitudeDocValues
636+
) {
637+
super(floatVectorValues, dimensions);
638+
this.magnitudeDocValues = magnitudeDocValues;
639+
}
640+
641+
@Override
642+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
643+
float[] floats = vectorValues.vectorValue(iterator.index());
644+
assert floats.length == dimensions
645+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
646+
647+
assert magnitudeDocValues.advanceExact(iterator.docID());
648+
float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue());
649+
for (float aFloat : floats) {
650+
builder.appendFloat(aFloat * magnitude);
651+
}
652+
}
653+
654+
@Override
655+
public String toString() {
656+
return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
657+
}
658+
}
659+
660+
private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
661+
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
662+
super(floatVectorValues, dimensions);
663+
}
664+
665+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
666+
byte[] bytes = vectorValues.vectorValue(iterator.index());
667+
assert bytes.length == dimensions
668+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length;
669+
for (byte aFloat : bytes) {
670+
builder.appendFloat(aFloat);
671+
}
672+
}
583673

584674
@Override
585675
public String toString() {
586-
return "BlockDocValuesReader.FloatVectorValuesBlockReader";
676+
return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader";
587677
}
588678
}
589679

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
734734
this,
735735
denseVectorFieldType.dims,
736736
denseVectorFieldType.indexed,
737-
denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
738-
&& denseVectorFieldType.indexed
739-
&& denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) {
740-
@Override
741-
public CacheHelper getCoreCacheHelper() {
742-
return r.getCoreCacheHelper();
743-
}
737+
denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) {
738+
@Override
739+
public CacheHelper getCoreCacheHelper() {
740+
return r.getCoreCacheHelper();
741+
}
744742

745-
@Override
746-
public CacheHelper getReaderCacheHelper() {
747-
return r.getReaderCacheHelper();
748-
}
743+
@Override
744+
public CacheHelper getReaderCacheHelper() {
745+
return r.getReaderCacheHelper();
746+
}
749747

750-
@Override
751-
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
752-
FloatVectorValues values = in.getFloatVectorValues(fieldName);
753-
if (values == null) {
754-
return null;
755-
}
756-
return new DenormalizedCosineFloatVectorValues(
757-
values,
758-
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
759-
);
748+
@Override
749+
public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException {
750+
FloatVectorValues values = in.getFloatVectorValues(fieldName);
751+
if (values == null) {
752+
return null;
760753
}
761-
} : r -> r
754+
return new DenormalizedCosineFloatVectorValues(
755+
values,
756+
in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX)
757+
);
758+
}
759+
} : r -> r
762760
);
763761
}
764762

@@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie
820818
fieldMapper.checkDimensionMatches(index, context);
821819
checkVectorBounds(vector);
822820
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
823-
if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE)
824-
&& fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE)
825-
&& isNotUnitVector(squaredMagnitude)) {
821+
if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) {
826822
float length = (float) Math.sqrt(squaredMagnitude);
827823
for (int i = 0; i < vector.length; i++) {
828824
vector[i] /= length;
@@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
24912487
return knnQuery;
24922488
}
24932489

2490+
public boolean isNormalized() {
2491+
return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity);
2492+
}
2493+
24942494
private Query createExactKnnBitQuery(byte[] queryVector) {
24952495
elementType.checkDimensions(dims, queryVector.length);
24962496
return new DenseVectorQuery.Bytes(queryVector, name());
@@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) {
25112511
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
25122512
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
25132513
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
2514-
if (similarity == VectorSimilarity.COSINE
2515-
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
2516-
&& isNotUnitVector(squaredMagnitude)) {
2514+
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
25172515
float length = (float) Math.sqrt(squaredMagnitude);
25182516
queryVector = Arrays.copyOf(queryVector, queryVector.length);
25192517
for (int i = 0; i < queryVector.length; i++) {
@@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery(
27032701
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
27042702
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
27052703
elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
2706-
if (similarity == VectorSimilarity.COSINE
2707-
&& indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
2708-
&& isNotUnitVector(squaredMagnitude)) {
2704+
if (isNormalized() && isNotUnitVector(squaredMagnitude)) {
27092705
float length = (float) Math.sqrt(squaredMagnitude);
27102706
queryVector = Arrays.copyOf(queryVector, queryVector.length);
27112707
for (int i = 0; i < queryVector.length; i++) {
@@ -2795,7 +2791,7 @@ int getVectorDimensions() {
27952791
return dims;
27962792
}
27972793

2798-
ElementType getElementType() {
2794+
public ElementType getElementType() {
27992795
return elementType;
28002796
}
28012797

@@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
28162812
}
28172813

28182814
if (indexed) {
2819-
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims);
2815+
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this);
28202816
}
28212817

28222818
if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) {

0 commit comments

Comments
 (0)