From a62c94327a135d4f6844bdf663941ed6cb6528d5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 17:01:06 +0200 Subject: [PATCH 01/21] Add byte vector support --- .../index/mapper/BlockDocValuesReader.java | 178 ++++++++++++++---- .../vectors/DenseVectorFieldMapper.java | 8 +- .../mapper/vectors/VectorEncoderDecoder.java | 20 ++ 3 files changed, 162 insertions(+), 44 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index f95e35a5d0845..7065adb0a816d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.mapper; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; @@ -29,6 +30,7 @@ import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; @@ -511,10 +513,12 @@ public String toString() { public static class DenseVectorBlockLoader extends DocValuesBlockLoader { private final String fieldName; private final int dimensions; + private final ElementType elementType; - public DenseVectorBlockLoader(String fieldName, int dimensions) { + public DenseVectorBlockLoader(String fieldName, int dimensions, ElementType elementType) { this.fieldName = fieldName; this.dimensions = dimensions; + this.elementType = elementType; } @Override @@ -524,22 +528,34 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - return new DenseVectorValuesBlockReader(floatVectorValues, dimensions); + switch (elementType) { + case FLOAT -> { + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); + } + } + case BYTE -> { + ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); + if (byteVectorValues != null) { + return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); + } + } } + return new ConstantNullsReader(); } } - private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - private final FloatVectorValues floatVectorValues; - private final KnnVectorValues.DocIndexIterator iterator; - private final int dimensions; + private abstract static class DenseVectorValuesBlockReader extends BlockDocValuesReader { + + protected final T vectorValues; + protected final KnnVectorValues.DocIndexIterator iterator; + protected final int dimensions; - DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { - this.floatVectorValues = floatVectorValues; - iterator = floatVectorValues.iterator(); + DenseVectorValuesBlockReader(T vectorValues, int dimensions) { + this.vectorValues = vectorValues; + iterator = vectorValues.iterator(); this.dimensions = dimensions; } @@ -566,26 +582,58 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException { if (iterator.advance(doc) == doc) { builder.beginPositionEntry(); - float[] floats = floatVectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - for (float aFloat : floats) { - builder.appendFloat(aFloat); - } + appendDoc(builder); builder.endPositionEntry(); } else { builder.appendNull(); } } + protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException; + @Override public int docId() { return iterator.docID(); } + } + + private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + for (float aFloat : floats) { + builder.appendFloat(aFloat); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader"; + } + } + + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + byte[] bytes = vectorValues.vectorValue(iterator.index()); + assert bytes.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; + for (byte aFloat : bytes) { + builder.appendFloat(aFloat); + } + } @Override public String toString() { - return "BlockDocValuesReader.FloatVectorValuesBlockReader"; + return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; } } @@ -877,11 +925,13 @@ public static class DenseVectorFromBinaryBlockLoader extends DocValuesBlockLoade private final String fieldName; private final int dims; private final IndexVersion indexVersion; + private final ElementType elementType; - public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion) { + public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion, ElementType elementType) { this.fieldName = fieldName; this.dims = dims; this.indexVersion = indexVersion; + this.elementType = elementType; } @Override @@ -895,23 +945,40 @@ public AllReader reader(LeafReaderContext context) throws IOException { if (docValues == null) { return new ConstantNullsReader(); } - return new DenseVectorFromBinary(docValues, dims, indexVersion); + switch (elementType) { + case FLOAT: + return new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BYTE: + return new ByteDenseVectorFromBinary(docValues, dims, indexVersion); + default: + throw new IllegalArgumentException("Unknown element type [" + elementType + "]"); + } } } - private static class DenseVectorFromBinary extends BlockDocValuesReader { - private final BinaryDocValues docValues; - private final IndexVersion indexVersion; - private final int dimensions; - private final float[] scratch; + // Abstract base for dense vector readers + private abstract static class AbstractDenseVectorFromBinary extends BlockDocValuesReader { + protected final BinaryDocValues docValues; + protected final IndexVersion indexVersion; + protected final int dimensions; + protected final T scratch; + protected int docID = -1; - private int docID = -1; - - DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + AbstractDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion, T scratch) { this.docValues = docValues; - this.scratch = new float[dims]; this.indexVersion = indexVersion; this.dimensions = dims; + this.scratch = scratch; + } + + @Override + public int docId() { + return docID; + } + + @Override + public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException { + read(docId, (BlockLoader.FloatBuilder) builder); } @Override @@ -928,36 +995,67 @@ public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throw } } - @Override - public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException { - read(docId, (BlockLoader.FloatBuilder) builder); - } - private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException { this.docID = doc; - if (false == docValues.advanceExact(doc)) { + if (docValues.advanceExact(doc) == false) { builder.appendNull(); return; } BytesRef bytesRef = docValues.binaryValue(); assert bytesRef.length > 0; - VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); + decodeDenseVector(bytesRef, scratch); builder.beginPositionEntry(); + writeScratchToBuilder(scratch, builder); + builder.endPositionEntry(); + } + + protected abstract void decodeDenseVector(BytesRef bytesRef, T scratch); + + protected abstract void writeScratchToBuilder(T scratch, BlockLoader.FloatBuilder builder); + } + + private static class FloatDenseVectorFromBinary extends AbstractDenseVectorFromBinary { + FloatDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new float[dims]); + } + + @Override + protected void writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder builder) { for (float value : scratch) { builder.appendFloat(value); } - builder.endPositionEntry(); } @Override - public int docId() { - return docID; + protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); + } + + @Override + public String toString() { + return "FloatDenseVectorFromBinary.Bytes"; + } + } + + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { + ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new byte[dims]); } @Override public String toString() { - return "DenseVectorFromBinary.Bytes"; + return "ByteDenseVectorFromBinary.Bytes"; + } + + protected void writeScratchToBuilder(byte[] scratch, BlockLoader.FloatBuilder builder) { + for (byte value : scratch) { + builder.appendFloat(value); + } + } + + protected void decodeDenseVector(BytesRef bytesRef, byte[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 9019edc435eaf..11afe50da9e70 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2805,8 +2805,8 @@ public DenseVectorIndexOptions getIndexOptions() { @Override public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { - if (elementType != ElementType.FLOAT) { - // Just float dense vector support for now + if (elementType == ElementType.BIT) { + // Just float and byte dense vector support for now return null; } @@ -2816,11 +2816,11 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } if (indexed) { - return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); + return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, elementType); } if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { - return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated); + return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated, elementType); } BlockSourceReader.LeafIteratorLookup lookup = BlockSourceReader.lookupMatchingAll(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 54b369ab1f377..9dec4a4f2dd61 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -84,6 +84,26 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB } } + /** + * Decodes a BytesRef into the provided array of bytes + * @param vectorBR - dense vector encoded in BytesRef + * @param vector - array of bytes where the decoded vector should be stored + */ + public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, byte[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); + } + if (indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) { + ByteBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN); + fb.get(vector); + } else { + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + for (int dim = 0; dim < vector.length; dim++) { + vector[dim] = byteBuffer.get(dim * vectorBR.offset); + } + } + } + public static float[] getMultiMagnitudes(BytesRef magnitudes) { assert magnitudes.length % Float.BYTES == 0; float[] multiMagnitudes = new float[magnitudes.length / Float.BYTES]; From 9317fcd3cf4b4adbac3ba68d6b9cdb30a79c514d Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 17:01:22 +0200 Subject: [PATCH 02/21] Add byte vector support - tests --- .../main/resources/mapping-dense_vector.json | 12 +++- .../xpack/esql/DenseVectorFieldTypeIT.java | 72 ++++++++++++++----- .../xpack/esql/action/EsqlCapabilities.java | 7 +- .../xpack/esql/analysis/AnalyzerTests.java | 26 ++++--- 4 files changed, 85 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json index 9c7d34f0f15e4..3e21d490583bb 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json @@ -3,7 +3,7 @@ "id": { "type": "long" }, - "vector": { + "float_vector": { "type": "dense_vector", "similarity": "l2_norm", "index_options": { @@ -11,6 +11,16 @@ "m": 16, "ef_construction": 100 } + }, + "byte_vector": { + "type": "dense_vector", + "similarity": "l2_norm", + "element_type": "byte", + "index_options": { + "type": "hnsw", + "m": 16, + "ef_construction": 100 + } } } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index a130b026cd88a..5aacba6ac09c2 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -21,8 +24,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -32,7 +37,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - private static final Set DENSE_VECTOR_INDEX_TYPES = Set.of( + private static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( "int8_hnsw", "hnsw", "int4_hnsw", @@ -42,32 +47,48 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "bbq_flat", "flat" ); + private static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of( + "hnsw", + "flat" + ); + private final ElementType elementType; + private final boolean synthetic; private final String indexType; private final boolean index; - private final boolean synthetic; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); // Indexed field types - for (String indexType : DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { indexType, true, false }); + for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] {ElementType.FLOAT, indexType, true, false }); + } + for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] {ElementType.BYTE, indexType, true, false }); + } + for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { + // No indexing + params.add(new Object[]{elementType, null, false, false}); + // No indexing, synthetic source + params.add(new Object[]{elementType, null, false, true}); } - // No indexing - params.add(new Object[] { null, false, false }); - // No indexing, synthetic source - params.add(new Object[] { null, false, true }); return params; } - public DenseVectorFieldTypeIT(@Name("indexType") String indexType, @Name("index") boolean index, @Name("synthetic") boolean synthetic) { + public DenseVectorFieldTypeIT( + @Name("elementType") ElementType elementType, + @Name("indexType") String indexType, + @Name("index") boolean index, + @Name("synthetic") boolean synthetic + ) { + this.elementType = elementType; this.indexType = indexType; this.index = index; this.synthetic = synthetic; } - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); public void testRetrieveFieldType() { var query = """ @@ -93,11 +114,11 @@ public void testRetrieveTopNDenseVectorFieldData() { indexedVectors.forEach((id, vector) -> { var values = valuesList.get(id); assertEquals(id, values.get(0)); - List vectors = (List) values.get(1); + List vectors = (List) values.get(1); assertNotNull(vectors); assertEquals(vector.size(), vectors.size()); for (int i = 0; i < vector.size(); i++) { - assertEquals(vector.get(i), vectors.get(i), 0F); + assertEquals(vector.get(i).floatValue(), vectors.get(i).floatValue(), 0F); } }); } @@ -114,15 +135,14 @@ public void testRetrieveDenseVectorFieldData() { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); valuesList.forEach(value -> { - ; assertEquals(2, value.size()); Integer id = (Integer) value.get(0); - List vector = (List) value.get(1); + List vector = (List) value.get(1); assertNotNull(vector); - List expectedVector = indexedVectors.get(id); + List expectedVector = indexedVectors.get(id); assertNotNull(expectedVector); for (int i = 0; i < vector.size(); i++) { - assertEquals(expectedVector.get(i), vector.get(i), 0F); + assertEquals(expectedVector.get(i).floatValue(), vector.get(i).floatValue(), 0F); } }); } @@ -167,9 +187,18 @@ public void setup() throws IOException { int numDocs = randomIntBetween(10, 100); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); + switch (elementType) { + case FLOAT: + // Normalized values to avoid normalizing the comparison + vector.add(randomFloatBetween(-1F, 1F, true)); + break; + case BYTE: + vector.add(randomByte()); + break; + default: throw new IllegalArgumentException("Unexpected element type: " + elementType); + } } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); @@ -188,9 +217,14 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") + .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field("similarity", "l2_norm"); + mapping.field( + "similarity", + // Let's not use others to avoid vector normalization + randomFrom("l2_norm", "max_inner_product") + ); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 5491ba58887f7..d4bbd749b9597 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1309,7 +1309,12 @@ public enum Cap { /** * Support correct counting of skipped shards. */ - CORRECT_SKIPPED_SHARDS_COUNT; + CORRECT_SKIPPED_SHARDS_COUNT, + + /** + * Byte elements dense vector field type support. + */ + DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 595b7699bdbd7..bd483ff5b8c6a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2338,11 +2338,15 @@ public void testImplicitCasting() { public void testDenseVectorImplicitCastingKnn() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); - var plan = analyze(""" - from test | where knn(vector, [0.342, 0.164, 0.234], 10) - """, "mapping-dense_vector.json"); + checkDenseVectorCastingKnn("float_vector"); + } + + private static void checkDenseVectorCastingKnn(String fieldName) { + var plan = analyze(String.format(Locale.ROOT, """ + from test | where knn(%s, [0.342, 0.164, 0.234], 10) + """, fieldName), "mapping-dense_vector.json"); var limit = as(plan, Limit.class); var filter = as(limit.child(), Filter.class); @@ -2355,19 +2359,19 @@ public void testDenseVectorImplicitCastingKnn() { public void testDenseVectorImplicitCastingSimilarityFunctions() { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( - "v_dot_product(vector, [0.342, 0.164, 0.234])", + "v_dot_product(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } } @@ -2382,7 +2386,7 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity assertEquals("similarity", alias.name()); var similarity = as(alias.child(), VectorSimilarityFunction.class); var left = as(similarity.left(), FieldAttribute.class); - assertEquals("vector", left.name()); + assertThat(List.of("float_vector", "byte_vector"), hasItem(left.name())); var right = as(similarity.right(), Literal.class); assertThat(right.dataType(), is(DENSE_VECTOR)); assertThat(right.value(), equalTo(expectedElems)); From 4c59ba175dd36822cc3a7ab98accc59305d89cfd Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 18:19:02 +0200 Subject: [PATCH 03/21] First test version, with all index types - fails --- .../xpack/esql/DenseVectorFieldTypeIT.java | 8 +- .../xpack/esql/plugin/KnnFunctionIT.java | 107 ++++++++++++++---- 2 files changed, 87 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 5aacba6ac09c2..6ebaaa5200815 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,9 +13,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -24,7 +22,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -37,7 +34,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - private static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( "int8_hnsw", "hnsw", "int4_hnsw", @@ -47,7 +44,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "bbq_flat", "flat" ); - private static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of( + public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of( "hnsw", "flat" ); @@ -191,7 +188,6 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { switch (elementType) { case FLOAT: - // Normalized values to avoid normalizing the comparison vector.add(randomFloatBetween(-1F, 1F, true)); break; case BYTE: diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..91429cbe92f2f 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -7,11 +7,15 @@ package org.elasticsearch.xpack.esql.plugin; +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -29,18 +33,66 @@ import java.util.Map; import static org.elasticsearch.index.IndexMode.LOOKUP; +import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; +import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; public class KnnFunctionIT extends AbstractEsqlIntegTestCase { - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); private int numDocs; private int numDims; + private final DenseVectorFieldMapper.ElementType elementType; + private final String indexType; + private final boolean synthetic; + private final boolean index; + + + @ParametersFactory + public static Iterable parameters() throws Exception { + List params = new ArrayList<>(); + // Indexed field types + for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] {DenseVectorFieldMapper.ElementType.FLOAT, indexType, true, false }); + } +// params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, "flat", true, false }); + for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, indexType, true, false }); + } + for (DenseVectorFieldMapper.ElementType elementType : List.of( + DenseVectorFieldMapper.ElementType.BYTE, + DenseVectorFieldMapper.ElementType.FLOAT + )) { + // No indexing + params.add(new Object[]{elementType, null, false, false}); + // No indexing, synthetic source + params.add(new Object[]{elementType, null, false, true}); + } + + // Remove flat index types, as knn does not do a top k for flat + params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); + return params; + } + + public KnnFunctionIT( + @Name("elementType") DenseVectorFieldMapper.ElementType elementType, + @Name("indexType") String indexType, + @Name("index") boolean index, + @Name("synthetic") boolean synthetic + ) { + this.elementType = elementType; + this.indexType = indexType; + this.index = index; + this.synthetic = synthetic; + } + public void testKnnDefaults() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score @@ -61,9 +113,9 @@ public void testKnnDefaults() { assertEquals(i, row.getFirst()); @SuppressWarnings("unchecked") // Vectors should be the same - List floats = (List) row.get(1); + List floats = (List) row.get(1); for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j).floatValue(), 0f); } var score = (Double) row.get(2); assertNotNull(score); @@ -74,7 +126,7 @@ public void testKnnDefaults() { public void testKnnOptions() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score @@ -94,12 +146,12 @@ public void testKnnOptions() { public void testKnnNonPushedDown() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR id > 10 + | WHERE knn(vector, %s, 5) OR (id > 10 AND id <= 15) | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -109,8 +161,8 @@ public void testKnnNonPushedDown() { assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); - // K = 5, 1 more for every id > 10 - assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + // K = 5, and 5 from the disjunction + assertEquals(10, valuesList.size()); } } @@ -121,7 +173,7 @@ public void testKnnWithPrefilters() { // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) AND id > 5 + | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 | KEEP id, floats, _score, vector | SORT _score DESC | LIMIT 5 @@ -144,7 +196,7 @@ public void testKnnWithLookupJoin() { var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR id > 10 + | WHERE KNN(lookup_vector, %s, 5) OR (id > 5 AND id <= 10) """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -169,32 +221,43 @@ public void setup() throws IOException { .startObject("id") .field("type", "integer") .endObject() - .startObject("vector") - .field("type", "dense_vector") - .field("similarity", "l2_norm") - .endObject() .startObject("floats") .field("type", "float") .endObject() - .endObject() - .endObject(); + .startObject("vector") + .field("type", "dense_vector"); + if (index) { + mapping.field( + "similarity", + // Let's not use others to avoid vector normalization + randomFrom("l2_norm", "max_inner_product") + ); + } + if (indexType != null) { + mapping.startObject("index_options").field("type", indexType).endObject(); + } + mapping.endObject().endObject().endObject(); Settings.Builder settingsBuilder = Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1); + if (synthetic) { + settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC); + } var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(15, 25); - numDims = randomIntBetween(3, 10); + numDocs = randomIntBetween(20, 35); + numDims = 64 + randomIntBetween(1, 10) * 2; IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - float value = 0.0f; + byte value = 0; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(value++); + vector.add(value); } + value++; docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); indexedVectors.put(i, vector); } From 475a28533b1e152f03b7583f435c6a8aa440dfc9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 18:20:12 +0200 Subject: [PATCH 04/21] Knn tests for non-flat, indexed types --- .../xpack/esql/plugin/KnnFunctionIT.java | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 91429cbe92f2f..c5902e71d7be5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -55,23 +55,12 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); - // Indexed field types for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] {DenseVectorFieldMapper.ElementType.FLOAT, indexType, true, false }); } -// params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, "flat", true, false }); for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, indexType, true, false }); } - for (DenseVectorFieldMapper.ElementType elementType : List.of( - DenseVectorFieldMapper.ElementType.BYTE, - DenseVectorFieldMapper.ElementType.FLOAT - )) { - // No indexing - params.add(new Object[]{elementType, null, false, false}); - // No indexing, synthetic source - params.add(new Object[]{elementType, null, false, true}); - } // Remove flat index types, as knn does not do a top k for flat params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); From ecc563f880664ced3b4a7960a3d36fae597522f4 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 18:48:41 +0200 Subject: [PATCH 05/21] Add CSV tests --- .../src/main/resources/data/dense_vector.csv | 10 ++-- .../main/resources/dense_vector-byte.csv-spec | 47 +++++++++++++++++++ .../src/main/resources/dense_vector.csv-spec | 11 ++--- 3 files changed, 57 insertions(+), 11 deletions(-) create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv index d24c9f8543b53..dc5b58e354cf5 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv @@ -1,5 +1,5 @@ -id:l, vector:dense_vector -0, [1.0, 2.0, 3.0] -1, [4.0, 5.0, 6.0] -2, [9.0, 8.0, 7.0] -3, [0.054, 0.032, 0.012] +id:l, float_vector:dense_vector, byte_vector:dense_vector +0, [1.0, 2.0, 3.0], [10, 20, 30] +1, [4.0, 5.0, 6.0], [40, 50, 60] +2, [9.0, 8.0, 7.0], [90, 80, 70] +3, [0.054, 0.032, 0.012], [100, 110, 120] diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec new file mode 100644 index 0000000000000..f3fc819f9867e --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec @@ -0,0 +1,47 @@ +retrieveByteVectorData +required_capability: dense_vector_field_type_byte_elements + +FROM dense_vector +| KEEP id, byte_vector +| SORT id +; + +id:l | byte_vector:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; + +denseByteVectorWithEval +required_capability: dense_vector_field_type_byte_elements + +FROM dense_vector +| EVAL v = byte_vector +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; + +denseByteVectorWithRenameAndDrop +required_capability: dense_vector_field_type_byte_elements + +FROM dense_vector +| EVAL v = byte_vector +| RENAME v AS new_vector +| DROP float_vector, byte_vector +| SORT id +; + +id:l | new_vector:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec index 74ef532313055..077565b8b8997 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec @@ -1,13 +1,12 @@ - retrieveDenseVectorData required_capability: dense_vector_field_type FROM dense_vector -| KEEP id, vector +| KEEP id, float_vector | SORT id ; -id:l | vector:dense_vector +id:l | float_vector:dense_vector 0 | [1.0, 2.0, 3.0] 1 | [4.0, 5.0, 6.0] 2 | [9.0, 8.0, 7.0] @@ -18,7 +17,7 @@ denseVectorWithEval required_capability: dense_vector_field_type FROM dense_vector -| EVAL v = vector +| EVAL v = float_vector | KEEP id, v | SORT id ; @@ -34,9 +33,9 @@ denseVectorWithRenameAndDrop required_capability: dense_vector_field_type FROM dense_vector -| EVAL v = vector +| EVAL v = float_vector | RENAME v AS new_vector -| DROP vector +| DROP float_vector, byte_vector | SORT id ; From 84d3c50c297cd9201da9cc4b169fd05fb455ff52 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 24 Jul 2025 19:25:13 +0200 Subject: [PATCH 06/21] Fix tests after merging --- .../xpack/esql/DenseVectorFieldTypeIT.java | 7 +- .../xpack/esql/plugin/KnnFunctionIT.java | 72 +++++++++---------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 1a918eb4af431..e212478ed8603 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -207,12 +207,13 @@ public void setup() throws IOException { case BYTE: vector.add(randomByte()); break; - default: throw new IllegalArgumentException("Unexpected element type: " + elementType); + default: + throw new IllegalArgumentException("Unexpected element type: " + elementType); } } + docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); + indexedVectors.put(i, vector); } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); - indexedVectors.put(i, vector); } indexRandom(true, docs); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index c5902e71d7be5..083435a869f97 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -33,12 +33,11 @@ import java.util.Map; import static org.elasticsearch.index.IndexMode.LOOKUP; -import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; -import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class KnnFunctionIT extends AbstractEsqlIntegTestCase { @@ -48,18 +47,16 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { private final DenseVectorFieldMapper.ElementType elementType; private final String indexType; - private final boolean synthetic; - private final boolean index; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {DenseVectorFieldMapper.ElementType.FLOAT, indexType, true, false }); + params.add(new Object[] {DenseVectorFieldMapper.ElementType.FLOAT, indexType }); } for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, indexType, true, false }); + params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, indexType }); } // Remove flat index types, as knn does not do a top k for flat @@ -69,14 +66,10 @@ public static Iterable parameters() throws Exception { public KnnFunctionIT( @Name("elementType") DenseVectorFieldMapper.ElementType elementType, - @Name("indexType") String indexType, - @Name("index") boolean index, - @Name("synthetic") boolean synthetic + @Name("indexType") String indexType ) { this.elementType = elementType; this.indexType = indexType; - this.index = index; - this.synthetic = synthetic; } public void testKnnDefaults() { @@ -96,15 +89,17 @@ public void testKnnDefaults() { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); - for (int i = 0; i < valuesList.size(); i++) { - List row = valuesList.get(i); - // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases - assertEquals(i, row.getFirst()); + double previousScore = Float.MAX_VALUE; + for (List row : valuesList) { + // Vectors should be in score order + double currentScore = (Double) row.get(2); + assertThat(currentScore, lessThanOrEqualTo(previousScore)); + previousScore = currentScore; @SuppressWarnings("unchecked") // Vectors should be the same List floats = (List) row.get(1); for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j).floatValue(), 0f); + assertEquals(floats.get(j).floatValue(), indexedVectors.get(row.get(0)).get(j).floatValue(), 0f); } var score = (Double) row.get(2); assertNotNull(score); @@ -140,7 +135,7 @@ public void testKnnNonPushedDown() { // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR (id > 10 AND id <= 15) + | WHERE knn(vector, %s, 5) OR id > 100 | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -150,14 +145,13 @@ public void testKnnNonPushedDown() { assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); - // K = 5, and 5 from the disjunction - assertEquals(10, valuesList.size()); + assertEquals(5, valuesList.size()); } } public void testKnnWithPrefilters() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ @@ -180,12 +174,12 @@ public void testKnnWithPrefilters() { public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR (id > 5 AND id <= 10) + | WHERE KNN(lookup_vector, %s, 5) OR id > 100 """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -214,25 +208,22 @@ public void setup() throws IOException { .field("type", "float") .endObject() .startObject("vector") - .field("type", "dense_vector"); - if (index) { - mapping.field( + .field("type", "dense_vector") + .field( "similarity", // Let's not use others to avoid vector normalization randomFrom("l2_norm", "max_inner_product") - ); - } - if (indexType != null) { - mapping.startObject("index_options").field("type", indexType).endObject(); - } - mapping.endObject().endObject().endObject(); + ) + .startObject("index_options") + .field("type", indexType) + .endObject() + .endObject() + .endObject() + .endObject(); Settings.Builder settingsBuilder = Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1); - if (synthetic) { - settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC); - } var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); @@ -240,13 +231,20 @@ public void setup() throws IOException { numDocs = randomIntBetween(20, 35); numDims = 64 + randomIntBetween(1, 10) * 2; IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - byte value = 0; for (int i = 0; i < numDocs; i++) { List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(value); + switch (elementType) { + case FLOAT: + vector.add(randomFloatBetween(-1F, 1F, true)); + break; + case BYTE: + vector.add(randomByte()); + break; + default: + throw new IllegalArgumentException("Unexpected element type: " + elementType); + } } - value++; docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); indexedVectors.put(i, vector); } From e9afc06baa26c0f06764b2bccedca45e5a791df7 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 24 Jul 2025 17:39:02 +0000 Subject: [PATCH 07/21] [CI] Auto commit changes from spotless --- .../xpack/esql/DenseVectorFieldTypeIT.java | 13 +++++-------- .../xpack/esql/plugin/KnnFunctionIT.java | 10 +++------- .../xpack/esql/analysis/AnalyzerTests.java | 10 ++++++++-- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index e212478ed8603..b36bbac74e137 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -44,10 +44,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "bbq_flat", "flat" ); - public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of( - "hnsw", - "flat" - ); + public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat"); private final ElementType elementType; private final boolean synthetic; @@ -59,16 +56,16 @@ public static Iterable parameters() throws Exception { List params = new ArrayList<>(); // Indexed field types for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {ElementType.FLOAT, indexType, true, false }); + params.add(new Object[] { ElementType.FLOAT, indexType, true, false }); } for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {ElementType.BYTE, indexType, true, false }); + params.add(new Object[] { ElementType.BYTE, indexType, true, false }); } for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { // No indexing - params.add(new Object[]{elementType, null, false, false}); + params.add(new Object[] { elementType, null, false, false }); // No indexing, synthetic source - params.add(new Object[]{elementType, null, false, true}); + params.add(new Object[] { elementType, null, false, true }); } return params; } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 083435a869f97..b7a0df6b74daa 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -48,15 +48,14 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase { private final DenseVectorFieldMapper.ElementType elementType; private final String indexType; - @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {DenseVectorFieldMapper.ElementType.FLOAT, indexType }); + params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); } for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] {DenseVectorFieldMapper.ElementType.BYTE, indexType }); + params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); } // Remove flat index types, as knn does not do a top k for flat @@ -64,10 +63,7 @@ public static Iterable parameters() throws Exception { return params; } - public KnnFunctionIT( - @Name("elementType") DenseVectorFieldMapper.ElementType elementType, - @Name("indexType") String indexType - ) { + public KnnFunctionIT(@Name("elementType") DenseVectorFieldMapper.ElementType elementType, @Name("indexType") String indexType) { this.elementType = elementType; this.indexType = indexType; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index bd483ff5b8c6a..9bd8eb10da421 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2359,7 +2359,10 @@ private static void checkDenseVectorCastingKnn(String fieldName) { public void testDenseVectorImplicitCastingSimilarityFunctions() { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_cosine(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { @@ -2370,7 +2373,10 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() { checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_l1_norm(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } } From a4aca145b101e45be363bd186e6d61e583643450 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 14:00:04 +0200 Subject: [PATCH 08/21] Take into account normalization --- .../index/mapper/BlockDocValuesReader.java | 50 ++++++++- .../vectors/DenseVectorFieldMapper.java | 62 +++++------ .../xpack/esql/DenseVectorFieldTypeIT.java | 102 ++++++++++-------- .../xpack/esql/plugin/KnnFunctionIT.java | 46 ++++---- 4 files changed, 158 insertions(+), 102 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 65fbfee94b3fa..5d95d3e0ac7f3 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -30,12 +30,15 @@ import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; + /** * A reader that supports reading doc-values from a Lucene segment in Block fashion. */ @@ -513,12 +516,12 @@ public String toString() { public static class DenseVectorBlockLoader extends DocValuesBlockLoader { private final String fieldName; private final int dimensions; - private final ElementType elementType; + private final DenseVectorFieldMapper.DenseVectorFieldType fieldType; - public DenseVectorBlockLoader(String fieldName, int dimensions, ElementType elementType) { + public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) { this.fieldName = fieldName; this.dimensions = dimensions; - this.elementType = elementType; + this.fieldType = fieldType; } @Override @@ -528,10 +531,17 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - switch (elementType) { + switch (fieldType.getElementType()) { case FLOAT -> { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + return new FloatDenseVectorNormalizedValuesBlockReader( + floatVectorValues, + dimensions, + context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } } @@ -596,6 +606,7 @@ public int docId() { } private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { super(floatVectorValues, dimensions); } @@ -615,6 +626,37 @@ public String toString() { } } + private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader { + private final NumericDocValues magnitudeDocValues; + + FloatDenseVectorNormalizedValuesBlockReader( + FloatVectorValues floatVectorValues, + int dimensions, + NumericDocValues magnitudeDocValues + ) { + super(floatVectorValues, dimensions); + this.magnitudeDocValues = magnitudeDocValues; + } + + @Override + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + + assert magnitudeDocValues.advanceExact(iterator.docID()); + float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + for (float aFloat : floats) { + builder.appendFloat(aFloat * magnitude); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; + } + } + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { super(floatVectorValues, dimensions); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 2920cb55a4aad..4edd6475b890d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp this, denseVectorFieldType.dims, denseVectorFieldType.indexed, - denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && denseVectorFieldType.indexed - && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) { - @Override - public CacheHelper getCoreCacheHelper() { - return r.getCoreCacheHelper(); - } + denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } - @Override - public CacheHelper getReaderCacheHelper() { - return r.getReaderCacheHelper(); - } + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } - @Override - public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { - FloatVectorValues values = in.getFloatVectorValues(fieldName); - if (values == null) { - return null; - } - return new DenormalizedCosineFloatVectorValues( - values, - in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; } - } : r -> r + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r ); } @@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie fieldMapper.checkDimensionMatches(index, context); checkVectorBounds(vector); checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); - if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) - && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); for (int i = 0; i < vector.length; i++) { vector[i] /= length; @@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) return knnQuery; } + public boolean isNormalized() { + return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity); + } + private Query createExactKnnBitQuery(byte[] queryVector) { elementType.checkDimensions(dims, queryVector.length); return new DenseVectorQuery.Bytes(queryVector, name()); @@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery( if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2795,7 +2791,7 @@ int getVectorDimensions() { return dims; } - ElementType getElementType() { + public ElementType getElementType() { return elementType; } @@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } if (indexed) { - return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, elementType); + return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this); } if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index d0cff91975afc..98c0d10f5bec2 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,7 +13,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.script.field.vectors.DenseVector; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -27,6 +29,8 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -45,8 +49,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "flat" ); public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat"); + public static final float DELTA = 1e-7F; private final ElementType elementType; + private final DenseVectorFieldMapper.VectorSimilarity similarity; private final boolean synthetic; private final String indexType; private final boolean index; @@ -55,29 +61,31 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { public static Iterable parameters() throws Exception { List params = new ArrayList<>(); // Indexed field types - for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { ElementType.FLOAT, indexType, true, false }); - } - for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { ElementType.BYTE, indexType, true, false }); - } - for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { - // No indexing - params.add(new Object[] { elementType, null, false, false }); - // No indexing, synthetic source - params.add(new Object[] { elementType, null, false, true }); - } + Supplier elementTypeProvider = () -> randomFrom(ElementType.FLOAT, ElementType.BYTE); + Function indexTypeProvider = e -> e == ElementType.FLOAT + ? randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES) + : randomFrom(NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES); + Supplier vectorSimilarityProvider = () -> randomFrom( + DenseVectorFieldMapper.VectorSimilarity.values() + ); + params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); + // No indexing + params.add(new Object[] { elementTypeProvider, null, null, false, false }); + // No indexing, synthetic source + params.add(new Object[] { elementTypeProvider, null, null, false, true }); return params; } public DenseVectorFieldTypeIT( - @Name("elementType") ElementType elementType, - @Name("indexType") String indexType, + @Name("elementType") Supplier elementTypeProvider, + @Name("indexType") Function indexTypeProvider, + @Name("similarity") Supplier similarityProvider, @Name("index") boolean index, @Name("synthetic") boolean synthetic ) { - this.elementType = elementType; - this.indexType = indexType; + this.elementType = elementTypeProvider.get(); + this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType); + this.similarity = similarityProvider == null ? null : similarityProvider.get(); this.index = index; this.synthetic = synthetic; } @@ -105,17 +113,17 @@ public void testRetrieveTopNDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); - indexedVectors.forEach((id, vector) -> { + indexedVectors.forEach((id, expectedVector) -> { var values = valuesList.get(id); assertEquals(id, values.get(0)); - List vectors = (List) values.get(1); - if (vector == null) { - assertNull(vectors); + List actualVector = (List) values.get(1); + if (expectedVector == null) { + assertNull(actualVector); } else { - assertNotNull(vectors); - assertEquals(vector.size(), vectors.size()); - for (int i = 0; i < vector.size(); i++) { - assertEquals(vector.get(i).floatValue(), vectors.get(i).floatValue(), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < expectedVector.size(); i++) { + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); } } }); @@ -129,23 +137,31 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; + indexedVectors.forEach((i, v) -> { + System.out.println("ID: " + i + ", Vector: " + v); + }); + try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); + // print all values for debugging valuesList.forEach(value -> { assertEquals(2, value.size()); Integer id = (Integer) value.get(0); List expectedVector = indexedVectors.get(id); - List vector = (List) value.get(1); + List actualVector = (List) value.get(1); if (expectedVector == null) { - assertNull(vector); + assertNull(actualVector); } else { - assertNotNull(vector); - assertEquals(expectedVector.size(), vector.size()); - assertNotNull(vector); - assertNotNull(expectedVector); - for (int i = 0; i < vector.size(); i++) { - assertEquals(expectedVector.get(i).floatValue(), vector.get(i).floatValue(), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < actualVector.size(); i++) { + assertEquals( + "Actual: " + actualVector + "; expected: " + expectedVector, + expectedVector.get(i).floatValue(), + actualVector.get(i).floatValue(), + DELTA + ); } } }); @@ -198,14 +214,17 @@ public void setup() throws IOException { } else { for (int j = 0; j < numDims; j++) { switch (elementType) { - case FLOAT: - vector.add(randomFloatBetween(-1F, 1F, true)); - break; - case BYTE: - vector.add(randomByte()); - break; - default: - throw new IllegalArgumentException("Unexpected element type: " + elementType); + case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); + case BYTE -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f)); + default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); + } + } + if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { + // Normalize the vector + float magnitude = DenseVector.getMagnitude(vector); + switch (elementType) { + case FLOAT -> vector.replaceAll(number -> number.floatValue() / magnitude); + case BYTE -> vector.replaceAll(number -> (byte) (number.byteValue() / magnitude)); } } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); @@ -231,8 +250,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException { if (index) { mapping.field( "similarity", - // Let's not use others to avoid vector normalization - randomFrom("l2_norm", "max_inner_product") + similarity.name().toLowerCase(Locale.ROOT) ); } if (indexType != null) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index b7a0df6b74daa..d44a9b458b082 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -75,29 +75,32 @@ public void testKnnDefaults() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 10) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); double previousScore = Float.MAX_VALUE; for (List row : valuesList) { // Vectors should be in score order - double currentScore = (Double) row.get(2); + double currentScore = (Double) row.get(1); assertThat(currentScore, lessThanOrEqualTo(previousScore)); previousScore = currentScore; @SuppressWarnings("unchecked") // Vectors should be the same - List floats = (List) row.get(1); - for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(row.get(0)).get(j).floatValue(), 0f); + List actualVector = (List) row.get(2); + List expectedVector = indexedVectors.get(row.get(0)); + for (int j = 0; j < actualVector.size(); j++) { + float expected = expectedVector.get(j).floatValue(); + float actual = actualVector.get(j).floatValue(); + assertEquals(expected, actual, 0f); } - var score = (Double) row.get(2); + var score = (Double) row.get(1); assertNotNull(score); assertTrue(score > 0.0); } @@ -111,13 +114,13 @@ public void testKnnOptions() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); @@ -132,13 +135,13 @@ public void testKnnNonPushedDown() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) OR id > 100 - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); @@ -153,14 +156,14 @@ public void testKnnWithPrefilters() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); // K = 5, 1 more for every id > 10 @@ -200,9 +203,6 @@ public void setup() throws IOException { .startObject("id") .field("type", "integer") .endObject() - .startObject("floats") - .field("type", "float") - .endObject() .startObject("vector") .field("type", "dense_vector") .field( @@ -232,16 +232,16 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { switch (elementType) { case FLOAT: - vector.add(randomFloatBetween(-1F, 1F, true)); + vector.add(randomFloatBetween(0F, 1F, true)); break; case BYTE: - vector.add(randomByte()); + vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127)); break; default: throw new IllegalArgumentException("Unexpected element type: " + elementType); } } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = prepareIndex("test").setId(String.valueOf(i)).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); } From 08a3c5c5983aa7d83a057aa572d7014197ce0258 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 12 Aug 2025 12:10:11 +0000 Subject: [PATCH 09/21] [CI] Auto commit changes from spotless --- .../elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 98c0d10f5bec2..774ab711fc90b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -137,9 +137,7 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); + indexedVectors.forEach((i, v) -> { System.out.println("ID: " + i + ", Vector: " + v); }); try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); @@ -248,10 +246,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field( - "similarity", - similarity.name().toLowerCase(Locale.ROOT) - ); + mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); From 763fe6392701e5612784c98307febc75974eea43 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 14:00:04 +0200 Subject: [PATCH 10/21] Take into account normalization for dense vector support --- .../index/mapper/BlockDocValuesReader.java | 100 ++++++++++++++- .../vectors/DenseVectorFieldMapper.java | 62 +++++----- .../xpack/esql/DenseVectorFieldTypeIT.java | 100 ++++++++++----- .../xpack/esql/plugin/KnnFunctionIT.java | 117 ++++++++++++------ 4 files changed, 270 insertions(+), 109 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 809bad5145fe6..3ad251af5ef47 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -29,11 +29,15 @@ import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; + /** * A reader that supports reading doc-values from a Lucene segment in Block fashion. */ @@ -511,10 +515,12 @@ public String toString() { public static class DenseVectorBlockLoader extends DocValuesBlockLoader { private final String fieldName; private final int dimensions; + private final DenseVectorFieldMapper.DenseVectorFieldType fieldType; - public DenseVectorBlockLoader(String fieldName, int dimensions) { + public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) { this.fieldName = fieldName; this.dimensions = dimensions; + this.fieldType = fieldType; } @Override @@ -524,9 +530,26 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - return new DenseVectorValuesBlockReader(floatVectorValues, dimensions); + switch (fieldType.getElementType()) { + case FLOAT -> { + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + return new FloatDenseVectorNormalizedValuesBlockReader( + floatVectorValues, + dimensions, + context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); + } + } + case BYTE -> { + ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); + if (byteVectorValues != null) { + return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); + } + } } return new ConstantNullsReader(); } @@ -580,10 +603,77 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException public int docId() { return iterator.docID(); } + } + + private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + + FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + for (float aFloat : floats) { + builder.appendFloat(aFloat); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader"; + } + } + + private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader { + private final NumericDocValues magnitudeDocValues; + + FloatDenseVectorNormalizedValuesBlockReader( + FloatVectorValues floatVectorValues, + int dimensions, + NumericDocValues magnitudeDocValues + ) { + super(floatVectorValues, dimensions); + this.magnitudeDocValues = magnitudeDocValues; + } + + @Override + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + + assert magnitudeDocValues.advanceExact(iterator.docID()); + float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + for (float aFloat : floats) { + builder.appendFloat(aFloat * magnitude); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; + } + } + + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + byte[] bytes = vectorValues.vectorValue(iterator.index()); + assert bytes.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; + for (byte aFloat : bytes) { + builder.appendFloat(aFloat); + } + } @Override public String toString() { - return "BlockDocValuesReader.FloatVectorValuesBlockReader"; + return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c9c14d027ebfd..2a3655a1100dd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp this, denseVectorFieldType.dims, denseVectorFieldType.indexed, - denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && denseVectorFieldType.indexed - && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) { - @Override - public CacheHelper getCoreCacheHelper() { - return r.getCoreCacheHelper(); - } + denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } - @Override - public CacheHelper getReaderCacheHelper() { - return r.getReaderCacheHelper(); - } + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } - @Override - public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { - FloatVectorValues values = in.getFloatVectorValues(fieldName); - if (values == null) { - return null; - } - return new DenormalizedCosineFloatVectorValues( - values, - in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; } - } : r -> r + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r ); } @@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie fieldMapper.checkDimensionMatches(index, context); checkVectorBounds(vector); checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); - if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) - && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); for (int i = 0; i < vector.length; i++) { vector[i] /= length; @@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) return knnQuery; } + public boolean isNormalized() { + return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity); + } + private Query createExactKnnBitQuery(byte[] queryVector) { elementType.checkDimensions(dims, queryVector.length); return new DenseVectorQuery.Bytes(queryVector, name()); @@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery( if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2795,7 +2791,7 @@ int getVectorDimensions() { return dims; } - ElementType getElementType() { + public ElementType getElementType() { return elementType; } @@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } if (indexed) { - return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); + return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this); } if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 0673856cbcc3b..87d2583e2285b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.script.field.vectors.DenseVector; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -23,8 +26,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -32,7 +38,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - private static final Set DENSE_VECTOR_INDEX_TYPES = Set.of( + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( "int8_hnsw", "hnsw", "int4_hnsw", @@ -43,31 +49,46 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "flat" ); + public static final float DELTA = 1e-7F; + + private final ElementType elementType; + private final DenseVectorFieldMapper.VectorSimilarity similarity; + private final boolean synthetic; private final String indexType; private final boolean index; - private final boolean synthetic; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); // Indexed field types - for (String indexType : DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { indexType, true, false }); - } + Supplier elementTypeProvider = () -> ElementType.FLOAT; + Function indexTypeProvider = e -> randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES); + Supplier vectorSimilarityProvider = () -> randomFrom( + DenseVectorFieldMapper.VectorSimilarity.values() + ); + params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); // No indexing - params.add(new Object[] { null, false, false }); + params.add(new Object[] { elementTypeProvider, null, null, false, false }); // No indexing, synthetic source - params.add(new Object[] { null, false, true }); + params.add(new Object[] { elementTypeProvider, null, null, false, true }); return params; } - public DenseVectorFieldTypeIT(@Name("indexType") String indexType, @Name("index") boolean index, @Name("synthetic") boolean synthetic) { - this.indexType = indexType; + public DenseVectorFieldTypeIT( + @Name("elementType") Supplier elementTypeProvider, + @Name("indexType") Function indexTypeProvider, + @Name("similarity") Supplier similarityProvider, + @Name("index") boolean index, + @Name("synthetic") boolean synthetic + ) { + this.elementType = elementTypeProvider.get(); + this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType); + this.similarity = similarityProvider == null ? null : similarityProvider.get(); this.index = index; this.synthetic = synthetic; } - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); public void testRetrieveFieldType() { var query = """ @@ -90,17 +111,17 @@ public void testRetrieveTopNDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); - indexedVectors.forEach((id, vector) -> { + indexedVectors.forEach((id, expectedVector) -> { var values = valuesList.get(id); assertEquals(id, values.get(0)); - List vectors = (List) values.get(1); - if (vector == null) { - assertNull(vectors); + List actualVector = (List) values.get(1); + if (expectedVector == null) { + assertNull(actualVector); } else { - assertNotNull(vectors); - assertEquals(vector.size(), vectors.size()); - for (int i = 0; i < vector.size(); i++) { - assertEquals(vector.get(i), vectors.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < expectedVector.size(); i++) { + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); } } }); @@ -114,24 +135,31 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; + indexedVectors.forEach((i, v) -> { + System.out.println("ID: " + i + ", Vector: " + v); + }); + try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); + // print all values for debugging valuesList.forEach(value -> { - ; assertEquals(2, value.size()); Integer id = (Integer) value.get(0); - List expectedVector = indexedVectors.get(id); - List vector = (List) value.get(1); + List expectedVector = indexedVectors.get(id); + List actualVector = (List) value.get(1); if (expectedVector == null) { - assertNull(vector); + assertNull(actualVector); } else { - assertNotNull(vector); - assertEquals(expectedVector.size(), vector.size()); - assertNotNull(vector); - assertNotNull(expectedVector); - for (int i = 0; i < vector.size(); i++) { - assertEquals(expectedVector.get(i), vector.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < actualVector.size(); i++) { + assertEquals( + "Actual: " + actualVector + "; expected: " + expectedVector, + expectedVector.get(i).floatValue(), + actualVector.get(i).floatValue(), + DELTA + ); } } }); @@ -177,13 +205,19 @@ public void setup() throws IOException { int numDocs = randomIntBetween(10, 100); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); if (rarely()) { docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i)); indexedVectors.put(i, null); } else { for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); + vector.add(randomFloatBetween(0F, 1F, true)); + } + if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { + // Normalize the vector + float magnitude = DenseVector.getMagnitude(vector); + vector.replaceAll(number -> number.floatValue() / magnitude); + } } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); @@ -203,9 +237,13 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") + .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field("similarity", "l2_norm"); + mapping.field( + "similarity", + similarity.name().toLowerCase(Locale.ROOT) + ); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..c2f8662ac502b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -7,11 +7,15 @@ package org.elasticsearch.xpack.esql.plugin; +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -30,42 +34,73 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class KnnFunctionIT extends AbstractEsqlIntegTestCase { - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); private int numDocs; private int numDims; + private final DenseVectorFieldMapper.ElementType elementType; + private final String indexType; + + @ParametersFactory + public static Iterable parameters() throws Exception { + List params = new ArrayList<>(); + for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); + } + for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); + } + + // Remove flat index types, as knn does not do a top k for flat + params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); + return params; + } + + public KnnFunctionIT(@Name("elementType") DenseVectorFieldMapper.ElementType elementType, @Name("indexType") String indexType) { + this.elementType = elementType; + this.indexType = indexType; + } + public void testKnnDefaults() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 10) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); - for (int i = 0; i < valuesList.size(); i++) { - List row = valuesList.get(i); - // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases - assertEquals(i, row.getFirst()); + double previousScore = Float.MAX_VALUE; + for (List row : valuesList) { + // Vectors should be in score order + double currentScore = (Double) row.get(1); + assertThat(currentScore, lessThanOrEqualTo(previousScore)); + previousScore = currentScore; @SuppressWarnings("unchecked") // Vectors should be the same - List floats = (List) row.get(1); - for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + List actualVector = (List) row.get(2); + List expectedVector = indexedVectors.get(row.get(0)); + for (int j = 0; j < actualVector.size(); j++) { + float expected = expectedVector.get(j).floatValue(); + float actual = actualVector.get(j).floatValue(); + assertEquals(expected, actual, 0f); } - var score = (Double) row.get(2); + var score = (Double) row.get(1); assertNotNull(score); assertTrue(score > 0.0); } @@ -74,18 +109,18 @@ public void testKnnDefaults() { public void testKnnOptions() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); @@ -94,42 +129,41 @@ public void testKnnOptions() { public void testKnnNonPushedDown() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR id > 10 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) OR id > 100 + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); - // K = 5, 1 more for every id > 10 - assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + assertEquals(5, valuesList.size()); } } public void testKnnWithPrefilters() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) AND id > 5 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 + | KEEP id, _score, vector | SORT _score DESC | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); // K = 5, 1 more for every id > 10 @@ -139,12 +173,12 @@ public void testKnnWithPrefilters() { public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR id > 10 + | WHERE KNN(lookup_vector, %s, 5) OR id > 100 """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -171,10 +205,14 @@ public void setup() throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") - .field("similarity", "l2_norm") + .field( + "similarity", + // Let's not use others to avoid vector normalization + randomFrom("l2_norm", "max_inner_product") + ) + .startObject("index_options") + .field("type", indexType) .endObject() - .startObject("floats") - .field("type", "float") .endObject() .endObject() .endObject(); @@ -186,16 +224,15 @@ public void setup() throws IOException { var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(15, 25); - numDims = randomIntBetween(3, 10); + numDocs = randomIntBetween(20, 35); + numDims = 64 + randomIntBetween(1, 10) * 2; IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - float value = 0.0f; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(value++); + vector.add(randomFloatBetween(0F, 1F, true)); } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = prepareIndex("test").setId(String.valueOf(i)).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); } From 80b48cf03b5c11f005c6f19134782a9546c67a9b Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 14:20:02 +0200 Subject: [PATCH 11/21] Fix cherry pick --- .../index/mapper/BlockDocValuesReader.java | 74 ++++++------------- .../xpack/esql/DenseVectorFieldTypeIT.java | 3 +- .../xpack/esql/plugin/KnnFunctionIT.java | 4 - 3 files changed, 22 insertions(+), 59 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 3ad251af5ef47..5ae2162192e28 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -30,7 +30,6 @@ import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; @@ -530,39 +529,31 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - switch (fieldType.getElementType()) { - case FLOAT -> { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - if (fieldType.isNormalized()) { - return new FloatDenseVectorNormalizedValuesBlockReader( - floatVectorValues, - dimensions, - context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); - } - return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); - } - } - case BYTE -> { - ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); - if (byteVectorValues != null) { - return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); - } + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + return new FloatDenseVectorNormalizedValuesBlockReader( + floatVectorValues, + dimensions, + context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } + return new ConstantNullsReader(); } } - private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - private final FloatVectorValues floatVectorValues; - private final KnnVectorValues.DocIndexIterator iterator; - private final int dimensions; + private abstract static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { - this.floatVectorValues = floatVectorValues; - iterator = floatVectorValues.iterator(); + protected final T vectorValues; + protected final KnnVectorValues.DocIndexIterator iterator; + protected final int dimensions; + + DenseVectorValuesBlockReader(T vectorValues, int dimensions) { + this.vectorValues = vectorValues; + iterator = vectorValues.iterator(); this.dimensions = dimensions; } @@ -587,18 +578,15 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException builder.appendNull(); } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { builder.beginPositionEntry(); - float[] floats = floatVectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - for (float aFloat : floats) { - builder.appendFloat(aFloat); - } + appendDoc(builder); builder.endPositionEntry(); } else { builder.appendNull(); } } + protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException; + @Override public int docId() { return iterator.docID(); @@ -657,26 +645,6 @@ public String toString() { } } - private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { - ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { - super(floatVectorValues, dimensions); - } - - protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { - byte[] bytes = vectorValues.vectorValue(iterator.index()); - assert bytes.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; - for (byte aFloat : bytes) { - builder.appendFloat(aFloat); - } - } - - @Override - public String toString() { - return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; - } - } - public static class BytesRefsFromOrdsBlockLoader extends DocValuesBlockLoader { private final String fieldName; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 87d2583e2285b..97addbecd4564 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -216,8 +216,7 @@ public void setup() throws IOException { if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); - vector.replaceAll(number -> number.floatValue() / magnitude); - } + vector.replaceAll(number -> number.floatValue() / magnitude); } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index c2f8662ac502b..4d630d95b264b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -35,7 +35,6 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; -import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -54,9 +53,6 @@ public static Iterable parameters() throws Exception { for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); } - for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); - } // Remove flat index types, as knn does not do a top k for flat params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); From 40edca3f893b232dfb6170093775ed0438e6cdc4 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 12 Aug 2025 12:35:32 +0000 Subject: [PATCH 12/21] [CI] Auto commit changes from spotless --- .../elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 97addbecd4564..00f890b816aea 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -135,9 +135,7 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); + indexedVectors.forEach((i, v) -> { System.out.println("ID: " + i + ", Vector: " + v); }); try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); @@ -239,10 +237,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field( - "similarity", - similarity.name().toLowerCase(Locale.ROOT) - ); + mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); From 8bd7f7914e2f01b2d3dc56a9221f070cd229fe9b Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 15:21:55 +0200 Subject: [PATCH 13/21] Remove debugging code --- .../org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 97addbecd4564..cd6a00d613de1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -135,10 +135,6 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); - try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); From 7d2625cee55b5b05257fb61eaa4543742dd6eb14 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 16:52:38 +0200 Subject: [PATCH 14/21] Check that we may not have magnitudes at all, or for normalized vectors --- .../index/mapper/BlockDocValuesReader.java | 16 +++++++++------- .../xpack/esql/DenseVectorFieldTypeIT.java | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 5ae2162192e28..6d869c4d394f8 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -532,11 +532,9 @@ public AllReader reader(LeafReaderContext context) throws IOException { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { - return new FloatDenseVectorNormalizedValuesBlockReader( - floatVectorValues, - dimensions, - context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + NumericDocValues magnitudeDocValues = context.reader() + .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); + return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); } return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } @@ -632,8 +630,12 @@ protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { assert floats.length == dimensions : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - assert magnitudeDocValues.advanceExact(iterator.docID()); - float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + float magnitude = 1.0f; + // If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a + // stored magnitude for all docs + if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { + magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + } for (float aFloat : floats) { builder.appendFloat(aFloat * magnitude); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 566f130dd9cc5..346fe51daebb6 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -209,7 +209,7 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { vector.add(randomFloatBetween(0F, 1F, true)); } - if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { + if ((similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) || rarely()) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); vector.replaceAll(number -> number.floatValue() / magnitude); From 5bcac4934f6e72db2eccc55add36cba2d12c8e77 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 10:11:19 +0200 Subject: [PATCH 15/21] Fix merge --- .../index/mapper/BlockDocValuesReader.java | 5 ++--- .../xpack/esql/DenseVectorFieldTypeIT.java | 11 ++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index d5d13476821c7..fe27f285e3e03 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -38,6 +38,7 @@ import java.io.IOException; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; /** * A reader that supports reading doc-values from a Lucene segment in Block fashion. @@ -540,9 +541,8 @@ public AllReader reader(LeafReaderContext context) throws IOException { .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); } - } - return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } case BYTE -> { ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); @@ -657,7 +657,6 @@ protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { @Override public String toString() { return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; - return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 2fc2b4215eceb..bd9d14aadd352 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -219,13 +219,14 @@ public void setup() throws IOException { default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); } } - if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely()) { + if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); - switch (elementType) { - case FLOAT -> vector.replaceAll(number -> number.floatValue() / magnitude); - case BYTE -> vector.replaceAll(number -> (byte) (number.byteValue() / magnitude)); - } + vector.replaceAll(number -> number.floatValue() / magnitude); + } + if (vector.stream().allMatch(v -> v.floatValue() == 0.0f)) { + // Avoid zero vectors + vector.set(randomIntBetween(0, numDims - 1), 1.0f); } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); From 57c45b50622e72af189f3444f81d80f46dcde8f5 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 10:28:15 +0200 Subject: [PATCH 16/21] Remove cosine similarity code --- .../index/mapper/BlockDocValuesReader.java | 42 +------------------ .../xpack/esql/DenseVectorFieldTypeIT.java | 8 ++-- 2 files changed, 4 insertions(+), 46 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index fe27f285e3e03..c6e6577508820 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -536,13 +536,8 @@ public AllReader reader(LeafReaderContext context) throws IOException { case FLOAT -> { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { - if (fieldType.isNormalized()) { - NumericDocValues magnitudeDocValues = context.reader() - .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); - return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); - } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } - return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } case BYTE -> { ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); @@ -625,41 +620,6 @@ public String toString() { } } - private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader { - private final NumericDocValues magnitudeDocValues; - - FloatDenseVectorNormalizedValuesBlockReader( - FloatVectorValues floatVectorValues, - int dimensions, - NumericDocValues magnitudeDocValues - ) { - super(floatVectorValues, dimensions); - this.magnitudeDocValues = magnitudeDocValues; - } - - @Override - protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { - float[] floats = vectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - - float magnitude = 1.0f; - // If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a - // stored magnitude for all docs - if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { - magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); - } - for (float aFloat : floats) { - builder.appendFloat(aFloat * magnitude); - } - } - - @Override - public String toString() { - return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; - } - } - private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { super(floatVectorValues, dimensions); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index bd9d14aadd352..3fa523a011052 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -66,7 +66,9 @@ public static Iterable parameters() throws Exception { ? randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES) : randomFrom(NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES); Supplier vectorSimilarityProvider = () -> randomFrom( - DenseVectorFieldMapper.VectorSimilarity.values() + DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT, + DenseVectorFieldMapper.VectorSimilarity.L2_NORM, + DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT ); params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); // No indexing @@ -137,10 +139,6 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); - try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); From 6977dbe555bd4ceab73bfb577345e2231f499a44 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 13 Aug 2025 08:36:27 +0000 Subject: [PATCH 17/21] [CI] Auto commit changes from spotless --- .../org/elasticsearch/index/mapper/BlockDocValuesReader.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index c6e6577508820..c2d8eafe049e4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -37,7 +37,6 @@ import java.io.IOException; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; /** From 239d35043e83cb7f5ac36e1d021b16fe7fef62ac Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 11:31:52 +0200 Subject: [PATCH 18/21] Better parameterized test --- .../xpack/esql/DenseVectorFieldTypeIT.java | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index fdc72797659b1..847a52c2c6d52 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -29,8 +29,6 @@ import java.util.Locale; import java.util.Map; import java.util.Set; -import java.util.function.Function; -import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -54,40 +52,39 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { private final ElementType elementType; private final DenseVectorFieldMapper.VectorSimilarity similarity; private final boolean synthetic; - private final String indexType; private final boolean index; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); - // Indexed field types - Supplier elementTypeProvider = () -> randomFrom(ElementType.FLOAT, ElementType.BYTE); - Function indexTypeProvider = e -> e == ElementType.FLOAT - ? randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES) - : randomFrom(NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES); - Supplier vectorSimilarityProvider = () -> randomFrom( + List similarities = List.of( DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT, DenseVectorFieldMapper.VectorSimilarity.L2_NORM, DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT ); - params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); - // No indexing - params.add(new Object[] { elementTypeProvider, null, null, false, false }); - // No indexing, synthetic source - params.add(new Object[] { elementTypeProvider, null, null, false, true }); + + for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { + // Test all similarities for element types + for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) { + params.add(new Object[] { elementType, similarity, true, false }); + } + + // No indexing + params.add(new Object[] { elementType, null, false, false }); + // No indexing, synthetic source + params.add(new Object[] { elementType, null, false, true }); + } return params; } public DenseVectorFieldTypeIT( - @Name("elementType") Supplier elementTypeProvider, - @Name("indexType") Function indexTypeProvider, - @Name("similarity") Supplier similarityProvider, + @Name("elementType") ElementType elementType, + @Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity, @Name("index") boolean index, @Name("synthetic") boolean synthetic ) { - this.elementType = elementTypeProvider.get(); - this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType); - this.similarity = similarityProvider == null ? null : similarityProvider.get(); + this.elementType = elementType; + this.similarity = similarity == null ? null : similarity; this.index = index; this.synthetic = synthetic; } @@ -244,8 +241,9 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .field("index", index); if (index) { mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); - } - if (indexType != null) { + String indexType = elementType == ElementType.FLOAT + ? randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES) + : randomFrom(NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES); mapping.startObject("index_options").field("type", indexType).endObject(); } mapping.endObject().endObject().endObject(); From 93851132b860f97071542d589f1a67a1ffeb28ac Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 11:33:57 +0200 Subject: [PATCH 19/21] Fix test --- .../org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index e82e386ef56e4..78074882eea9c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2383,8 +2383,8 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() { checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } } From 14ba3b009306196f634fd7a6375e2eda8d41e752 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 11:38:29 +0200 Subject: [PATCH 20/21] Fix test --- .../org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 847a52c2c6d52..8f9e613d2acec 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -84,7 +84,7 @@ public DenseVectorFieldTypeIT( @Name("synthetic") boolean synthetic ) { this.elementType = elementType; - this.similarity = similarity == null ? null : similarity; + this.similarity = similarity; this.index = index; this.synthetic = synthetic; } From 64ca5632cc2de135ee5d62a4730265887fbae323 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 13 Aug 2025 09:47:46 +0000 Subject: [PATCH 21/21] [CI] Auto commit changes from spotless --- .../org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 78074882eea9c..e040067458408 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2383,7 +2383,10 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() { checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_l2_norm(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } }