diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 6a0b2eac47e85..017e713fe09fe 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -38,6 +38,7 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; /** @@ -540,6 +541,11 @@ public AllReader reader(LeafReaderContext context) throws IOException { case FLOAT -> { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + NumericDocValues magnitudeDocValues = context.reader() + .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); + return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); + } return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } } @@ -584,6 +590,9 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build } private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException { + assert vectorValues.dimension() == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + vectorValues.dimension(); + if (iterator.docID() > doc) { builder.appendNull(); } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { @@ -611,8 +620,6 @@ private static class FloatDenseVectorValuesBlockReader extends DenseVectorValues protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { float[] floats = vectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for 
vector value; expected " + dimensions + " but got " + floats.length; for (float aFloat : floats) { builder.appendFloat(aFloat); } @@ -624,6 +631,38 @@ public String toString() { } } + private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader { + private final NumericDocValues magnitudeDocValues; + + FloatDenseVectorNormalizedValuesBlockReader( + FloatVectorValues floatVectorValues, + int dimensions, + NumericDocValues magnitudeDocValues + ) { + super(floatVectorValues, dimensions); + this.magnitudeDocValues = magnitudeDocValues; + } + + @Override + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float magnitude = 1.0f; + // magnitudeDocValues is null when all vectors are normalized (no doc values are present). A doc whose vector was already + // normalized has no stored magnitude either; in both cases magnitude stays 1.0f and the values are appended unchanged + if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { + magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + } + float[] floats = vectorValues.vectorValue(iterator.index()); + for (float aFloat : floats) { + builder.appendFloat(aFloat * magnitude); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; + } + } + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { super(floatVectorValues, dimensions); @@ -631,8 +670,6 @@ private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesB protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { byte[] bytes = vectorValues.vectorValue(iterator.index()); - assert bytes.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; for (byte aFloat : bytes) { builder.appendFloat(aFloat); } diff --git 
a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 8f9e613d2acec..ad482bfa1b60c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -24,11 +24,13 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -36,17 +38,16 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( - "int8_hnsw", - "hnsw", - "int4_hnsw", - "bbq_hnsw", - "int8_flat", - "int4_flat", - "bbq_flat", - "flat" - ); - public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat"); + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values()) + .filter(DenseVectorFieldMapper.VectorIndexType::isEnabled) + .map(v -> v.getName().toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); + + public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values()) + .filter(t -> t.isEnabled() && t.isQuantized() == false) + .map(v -> v.getName().toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); + public static final float DELTA = 1e-7F; private final ElementType elementType; @@ -57,15 +58,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { @ParametersFactory public 
static Iterable parameters() throws Exception { List params = new ArrayList<>(); - List similarities = List.of( - DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT, - DenseVectorFieldMapper.VectorSimilarity.L2_NORM, - DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT - ); for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { - // Test all similarities for element types - for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) { + // Test all similarities + for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) { params.add(new Object[] { elementType, similarity, true, false }); } @@ -74,6 +70,7 @@ public static Iterable parameters() throws Exception { // No indexing, synthetic source params.add(new Object[] { elementType, null, false, true }); } + return params; }