diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 9e0574985f4db..6a0b2eac47e85 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.mapper; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; @@ -30,11 +31,15 @@ import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; + /** * A reader that supports reading doc-values from a Lucene segment in Block fashion. 
*/ @@ -516,10 +521,12 @@ public String toString() { public static class DenseVectorBlockLoader extends DocValuesBlockLoader { private final String fieldName; private final int dimensions; + private final DenseVectorFieldMapper.DenseVectorFieldType fieldType; - public DenseVectorBlockLoader(String fieldName, int dimensions) { + public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) { this.fieldName = fieldName; this.dimensions = dimensions; + this.fieldType = fieldType; } @Override @@ -529,22 +536,34 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - return new DenseVectorValuesBlockReader(floatVectorValues, dimensions); + switch (fieldType.getElementType()) { + case FLOAT -> { + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); + } + } + case BYTE -> { + ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); + if (byteVectorValues != null) { + return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); + } + } } + return new ConstantNullsReader(); } } - private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - private final FloatVectorValues floatVectorValues; - private final KnnVectorValues.DocIndexIterator iterator; - private final int dimensions; + private abstract static class DenseVectorValuesBlockReader extends BlockDocValuesReader { + + protected final T vectorValues; + protected final KnnVectorValues.DocIndexIterator iterator; + protected final int dimensions; - DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { - 
this.floatVectorValues = floatVectorValues; - iterator = floatVectorValues.iterator(); + DenseVectorValuesBlockReader(T vectorValues, int dimensions) { + this.vectorValues = vectorValues; + iterator = vectorValues.iterator(); this.dimensions = dimensions; } @@ -569,26 +588,59 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException builder.appendNull(); } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { builder.beginPositionEntry(); - float[] floats = floatVectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - for (float aFloat : floats) { - builder.appendFloat(aFloat); - } + appendDoc(builder); builder.endPositionEntry(); } else { builder.appendNull(); } } + protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException; + @Override public int docId() { return iterator.docID(); } + } + + private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + + FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + for (float aFloat : floats) { + builder.appendFloat(aFloat); + } + } @Override public String toString() { - return "BlockDocValuesReader.FloatVectorValuesBlockReader"; + return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader"; + } + } + + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + ByteDenseVectorValuesBlockReader(ByteVectorValues byteVectorValues, int dimensions) { + super(byteVectorValues, dimensions); + } + + 
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + byte[] bytes = vectorValues.vectorValue(iterator.index()); + assert bytes.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; + for (byte aByte : bytes) { + builder.appendFloat(aByte); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; } } @@ -880,11 +932,13 @@ public static class DenseVectorFromBinaryBlockLoader extends DocValuesBlockLoade private final String fieldName; private final int dims; private final IndexVersion indexVersion; + private final ElementType elementType; - public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion) { + public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion, ElementType elementType) { this.fieldName = fieldName; this.dims = dims; this.indexVersion = indexVersion; + this.elementType = elementType; } @Override @@ -898,23 +952,40 @@ public AllReader reader(LeafReaderContext context) throws IOException { if (docValues == null) { return new ConstantNullsReader(); } - return new DenseVectorFromBinary(docValues, dims, indexVersion); + switch (elementType) { + case FLOAT: + return new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BYTE: + return new ByteDenseVectorFromBinary(docValues, dims, indexVersion); + default: + throw new IllegalArgumentException("Unknown element type [" + elementType + "]"); + } } } - private static class DenseVectorFromBinary extends BlockDocValuesReader { - private final BinaryDocValues docValues; - private final IndexVersion indexVersion; - private final int dimensions; - private final float[] scratch; - - private int docID = -1; + // Base for readers that decode dense vectors from BinaryDocValues into a reusable scratch array of type T + private abstract static class AbstractDenseVectorFromBinary extends BlockDocValuesReader { + protected final BinaryDocValues 
docValues; + protected final IndexVersion indexVersion; + protected final int dimensions; + protected final T scratch; + protected int docID = -1; - DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + AbstractDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion, T scratch) { this.docValues = docValues; - this.scratch = new float[dims]; this.indexVersion = indexVersion; this.dimensions = dims; + this.scratch = scratch; + } + + @Override + public int docId() { + return docID; + } + + @Override + public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException { + read(docId, (BlockLoader.FloatBuilder) builder); } @Override @@ -931,36 +1002,67 @@ public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throw } } - @Override - public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException { - read(docId, (BlockLoader.FloatBuilder) builder); - } - private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException { this.docID = doc; - if (false == docValues.advanceExact(doc)) { + if (docValues.advanceExact(doc) == false) { builder.appendNull(); return; } BytesRef bytesRef = docValues.binaryValue(); assert bytesRef.length > 0; - VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); + decodeDenseVector(bytesRef, scratch); builder.beginPositionEntry(); + writeScratchToBuilder(scratch, builder); + builder.endPositionEntry(); + } + + protected abstract void decodeDenseVector(BytesRef bytesRef, T scratch); + + protected abstract void writeScratchToBuilder(T scratch, BlockLoader.FloatBuilder builder); + } + + private static class FloatDenseVectorFromBinary extends AbstractDenseVectorFromBinary { + FloatDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new float[dims]); + } + + @Override + protected void 
writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder builder) { for (float value : scratch) { builder.appendFloat(value); } - builder.endPositionEntry(); } @Override - public int docId() { - return docID; + protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); } @Override public String toString() { - return "DenseVectorFromBinary.Bytes"; + return "FloatDenseVectorFromBinary.Bytes"; + } + } + + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { + ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new byte[dims]); + } + + @Override + public String toString() { + return "ByteDenseVectorFromBinary.Bytes"; + } + + protected void writeScratchToBuilder(byte[] scratch, BlockLoader.FloatBuilder builder) { + for (byte value : scratch) { + builder.appendFloat(value); + } + } + + protected void decodeDenseVector(BytesRef bytesRef, byte[] scratch) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c9c14d027ebfd..4edd6475b890d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp this, denseVectorFieldType.dims, denseVectorFieldType.indexed, - denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && denseVectorFieldType.indexed - && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? 
r -> new FilterLeafReader(r) { - @Override - public CacheHelper getCoreCacheHelper() { - return r.getCoreCacheHelper(); - } + denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } - @Override - public CacheHelper getReaderCacheHelper() { - return r.getReaderCacheHelper(); - } + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } - @Override - public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { - FloatVectorValues values = in.getFloatVectorValues(fieldName); - if (values == null) { - return null; - } - return new DenormalizedCosineFloatVectorValues( - values, - in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; } - } : r -> r + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r ); } @@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie fieldMapper.checkDimensionMatches(index, context); checkVectorBounds(vector); checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); - if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) - && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); for (int i = 0; i < vector.length; i++) { vector[i] /= length; @@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float 
vectorSimilarity) return knnQuery; } + public boolean isNormalized() { + return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity); + } + private Query createExactKnnBitQuery(byte[] queryVector) { elementType.checkDimensions(dims, queryVector.length); return new DenseVectorQuery.Bytes(queryVector, name()); @@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery( if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2795,7 +2791,7 @@ int getVectorDimensions() { return dims; } - ElementType getElementType() { + public ElementType getElementType() { return elementType; } @@ -2805,8 +2801,8 @@ public DenseVectorIndexOptions 
getIndexOptions() { @Override public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { - if (elementType != ElementType.FLOAT) { - // Just float dense vector support for now + if (elementType == ElementType.BIT) { + // Just float and byte dense vector support for now return null; } @@ -2816,11 +2812,11 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } if (indexed) { - return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); + return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this); } if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { - return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated); + return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(name(), dims, indexVersionCreated, elementType); } BlockSourceReader.LeafIteratorLookup lookup = BlockSourceReader.lookupMatchingAll(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 54b369ab1f377..9dec4a4f2dd61 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -84,6 +84,26 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB } } + /** + * Decodes a BytesRef into the provided array of bytes + * @param vectorBR - dense vector encoded in BytesRef + * @param vector - array of bytes where the decoded vector should be stored + */ + public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, byte[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); + } + if 
(indexVersion.onOrAfter(LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION)) { + ByteBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN); + fb.get(vector); + } else { + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + for (int dim = 0; dim < vector.length; dim++) { + vector[dim] = byteBuffer.get(vectorBR.offset + dim); // absolute get() is indexed from the start of the backing array, not from the buffer position + } + } + } + public static float[] getMultiMagnitudes(BytesRef magnitudes) { assert magnitudes.length % Float.BYTES == 0; float[] multiMagnitudes = new float[magnitudes.length / Float.BYTES]; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv index d24c9f8543b53..dc5b58e354cf5 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv @@ -1,5 +1,5 @@ -id:l, vector:dense_vector -0, [1.0, 2.0, 3.0] -1, [4.0, 5.0, 6.0] -2, [9.0, 8.0, 7.0] -3, [0.054, 0.032, 0.012] +id:l, float_vector:dense_vector, byte_vector:dense_vector +0, [1.0, 2.0, 3.0], [10, 20, 30] +1, [4.0, 5.0, 6.0], [40, 50, 60] +2, [9.0, 8.0, 7.0], [90, 80, 70] +3, [0.054, 0.032, 0.012], [100, 110, 120] diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec new file mode 100644 index 0000000000000..f3fc819f9867e --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-byte.csv-spec @@ -0,0 +1,47 @@ +retrieveByteVectorData +required_capability: dense_vector_field_type_byte_elements + +FROM dense_vector +| KEEP id, byte_vector +| SORT id +; + +id:l | byte_vector:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; + +denseByteVectorWithEval +required_capability: 
dense_vector_field_type_byte_elements + +FROM dense_vector +| EVAL v = byte_vector +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; + +denseByteVectorWithRenameAndDrop +required_capability: dense_vector_field_type_byte_elements + +FROM dense_vector +| EVAL v = byte_vector +| RENAME v AS new_vector +| DROP float_vector, byte_vector +| SORT id +; + +id:l | new_vector:dense_vector +0 | [10, 20, 30] +1 | [40, 50, 60] +2 | [90, 80, 70] +3 | [100, 110, 120] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec index 74ef532313055..077565b8b8997 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec @@ -1,13 +1,12 @@ - retrieveDenseVectorData required_capability: dense_vector_field_type FROM dense_vector -| KEEP id, vector +| KEEP id, float_vector | SORT id ; -id:l | vector:dense_vector +id:l | float_vector:dense_vector 0 | [1.0, 2.0, 3.0] 1 | [4.0, 5.0, 6.0] 2 | [9.0, 8.0, 7.0] @@ -18,7 +17,7 @@ denseVectorWithEval required_capability: dense_vector_field_type FROM dense_vector -| EVAL v = vector +| EVAL v = float_vector | KEEP id, v | SORT id ; @@ -34,9 +33,9 @@ denseVectorWithRenameAndDrop required_capability: dense_vector_field_type FROM dense_vector -| EVAL v = vector +| EVAL v = float_vector | RENAME v AS new_vector -| DROP vector +| DROP float_vector, byte_vector | SORT id ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json index 9c7d34f0f15e4..3e21d490583bb 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector.json 
@@ -3,7 +3,7 @@ "id": { "type": "long" }, - "vector": { + "float_vector": { "type": "dense_vector", "similarity": "l2_norm", "index_options": { @@ -11,6 +11,16 @@ "m": 16, "ef_construction": 100 } + }, + "byte_vector": { + "type": "dense_vector", + "similarity": "l2_norm", + "element_type": "byte", + "index_options": { + "type": "hnsw", + "m": 16, + "ef_construction": 100 + } } } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 0673856cbcc3b..8f9e613d2acec 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.script.field.vectors.DenseVector; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -23,6 +26,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -32,7 +36,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - private static final Set DENSE_VECTOR_INDEX_TYPES = Set.of( + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( "int8_hnsw", "hnsw", "int4_hnsw", @@ -42,32 +46,50 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "bbq_flat", "flat" ); + public static final 
Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat"); + public static final float DELTA = 1e-7F; - private final String indexType; - private final boolean index; + private final ElementType elementType; + private final DenseVectorFieldMapper.VectorSimilarity similarity; private final boolean synthetic; + private final boolean index; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); - // Indexed field types - for (String indexType : DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { indexType, true, false }); + List similarities = List.of( + DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT, + DenseVectorFieldMapper.VectorSimilarity.L2_NORM, + DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT + ); + + for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { + // Test all similarities for element types + for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) { + params.add(new Object[] { elementType, similarity, true, false }); + } + + // No indexing + params.add(new Object[] { elementType, null, false, false }); + // No indexing, synthetic source + params.add(new Object[] { elementType, null, false, true }); } - // No indexing - params.add(new Object[] { null, false, false }); - // No indexing, synthetic source - params.add(new Object[] { null, false, true }); return params; } - public DenseVectorFieldTypeIT(@Name("indexType") String indexType, @Name("index") boolean index, @Name("synthetic") boolean synthetic) { - this.indexType = indexType; + public DenseVectorFieldTypeIT( + @Name("elementType") ElementType elementType, + @Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity, + @Name("index") boolean index, + @Name("synthetic") boolean synthetic + ) { + this.elementType = elementType; + this.similarity = similarity; this.index = index; this.synthetic = synthetic; } - private final Map> indexedVectors = new HashMap<>(); + 
private final Map> indexedVectors = new HashMap<>(); public void testRetrieveFieldType() { var query = """ @@ -90,17 +112,17 @@ public void testRetrieveTopNDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); - indexedVectors.forEach((id, vector) -> { + indexedVectors.forEach((id, expectedVector) -> { var values = valuesList.get(id); assertEquals(id, values.get(0)); - List vectors = (List) values.get(1); - if (vector == null) { - assertNull(vectors); + List actualVector = (List) values.get(1); + if (expectedVector == null) { + assertNull(actualVector); } else { - assertNotNull(vectors); - assertEquals(vector.size(), vectors.size()); - for (int i = 0; i < vector.size(); i++) { - assertEquals(vector.get(i), vectors.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < expectedVector.size(); i++) { + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); } } }); @@ -117,21 +139,24 @@ public void testRetrieveDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); + // every returned row must match the vector that was indexed for its id valuesList.forEach(value -> { - ; assertEquals(2, value.size()); Integer id = (Integer) value.get(0); - List expectedVector = indexedVectors.get(id); - List vector = (List) value.get(1); + List expectedVector = indexedVectors.get(id); + List actualVector = (List) value.get(1); if (expectedVector == null) { - assertNull(vector); + assertNull(actualVector); } else { - assertNotNull(vector); - assertEquals(expectedVector.size(), vector.size()); - assertNotNull(vector); - assertNotNull(expectedVector); - for (int i = 0; i < vector.size(); i++) { - assertEquals(expectedVector.get(i), vector.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i 
< actualVector.size(); i++) { + assertEquals( + "Actual: " + actualVector + "; expected: " + expectedVector, + expectedVector.get(i).floatValue(), + actualVector.get(i).floatValue(), + DELTA + ); } } }); @@ -177,13 +202,22 @@ public void setup() throws IOException { int numDocs = randomIntBetween(10, 100); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); if (rarely()) { docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i)); indexedVectors.put(i, null); } else { for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); + switch (elementType) { + case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); + case BYTE -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f)); + default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); + } + } + if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { + // Normalize the vector + float magnitude = DenseVector.getMagnitude(vector); + vector.replaceAll(number -> number.floatValue() / magnitude); } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); @@ -203,11 +237,13 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") + .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field("similarity", "l2_norm"); - } - if (indexType != null) { + mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); + String indexType = elementType == ElementType.FLOAT + ? 
randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES) + : randomFrom(NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES); mapping.startObject("index_options").field("type", indexType).endObject(); } mapping.endObject().endObject().endObject(); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..d44a9b458b082 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -7,11 +7,15 @@ package org.elasticsearch.xpack.esql.plugin; +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -30,42 +34,73 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class KnnFunctionIT extends AbstractEsqlIntegTestCase { - private final Map> indexedVectors = new HashMap<>(); + 
private final Map> indexedVectors = new HashMap<>(); private int numDocs; private int numDims; + private final DenseVectorFieldMapper.ElementType elementType; + private final String indexType; + + @ParametersFactory + public static Iterable parameters() throws Exception { + List params = new ArrayList<>(); + for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); + } + for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); + } + + // Remove flat index types, as knn does not do a top k for flat + params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); + return params; + } + + public KnnFunctionIT(@Name("elementType") DenseVectorFieldMapper.ElementType elementType, @Name("indexType") String indexType) { + this.elementType = elementType; + this.indexType = indexType; + } + public void testKnnDefaults() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 10) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); - for (int i = 0; i < valuesList.size(); i++) { - List row = valuesList.get(i); - // Vectors should be in order of ID, as they're less similar than the query vector as 
the ID increases - assertEquals(i, row.getFirst()); + double previousScore = Float.MAX_VALUE; + for (List row : valuesList) { + // Vectors should be in score order + double currentScore = (Double) row.get(1); + assertThat(currentScore, lessThanOrEqualTo(previousScore)); + previousScore = currentScore; @SuppressWarnings("unchecked") // Vectors should be the same - List floats = (List) row.get(1); - for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + List actualVector = (List) row.get(2); + List expectedVector = indexedVectors.get(row.get(0)); + for (int j = 0; j < actualVector.size(); j++) { + float expected = expectedVector.get(j).floatValue(); + float actual = actualVector.get(j).floatValue(); + assertEquals(expected, actual, 0f); } - var score = (Double) row.get(2); + var score = (Double) row.get(1); assertNotNull(score); assertTrue(score > 0.0); } @@ -74,18 +109,18 @@ public void testKnnDefaults() { public void testKnnOptions() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); @@ -94,42 +129,41 @@ public void testKnnOptions() { public void testKnnNonPushedDown() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // TODO 
we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR id > 10 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) OR id > 100 + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); - // K = 5, 1 more for every id > 10 - assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + assertEquals(5, valuesList.size()); } } public void testKnnWithPrefilters() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) AND id > 5 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 + | KEEP id, _score, vector | SORT _score DESC | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); // K = 5, 1 more for every id > 10 @@ -139,12 +173,12 @@ public void 
testKnnWithPrefilters() { public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR id > 10 + | WHERE KNN(lookup_vector, %s, 5) OR id > 100 """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -171,10 +205,14 @@ public void setup() throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") - .field("similarity", "l2_norm") + .field( + "similarity", + // Let's not use others to avoid vector normalization + randomFrom("l2_norm", "max_inner_product") + ) + .startObject("index_options") + .field("type", indexType) .endObject() - .startObject("floats") - .field("type", "float") .endObject() .endObject() .endObject(); @@ -186,16 +224,24 @@ public void setup() throws IOException { var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(15, 25); - numDims = randomIntBetween(3, 10); + numDocs = randomIntBetween(20, 35); + numDims = 64 + randomIntBetween(1, 10) * 2; IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - float value = 0.0f; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(value++); + switch (elementType) { + case FLOAT: + vector.add(randomFloatBetween(0F, 1F, true)); + break; + case BYTE: + vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127)); + break; + default: + throw new IllegalArgumentException("Unexpected element type: " + elementType); + } } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = 
prepareIndex("test").setId(String.valueOf(i)).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index c7de35e94b464..bfe21b4019b87 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1349,7 +1349,12 @@ public enum Cap { /** * Support correct counting of skipped shards. */ - CORRECT_SKIPPED_SHARDS_COUNT; + CORRECT_SKIPPED_SHARDS_COUNT, + + /** + * Byte elements dense vector field type support. + */ + DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index e94fff4c682f1..e040067458408 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2341,11 +2341,15 @@ public void testImplicitCasting() { public void testDenseVectorImplicitCastingKnn() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); - var plan = analyze(""" - from test | where knn(vector, [0.342, 0.164, 0.234], 10) - """, "mapping-dense_vector.json"); + checkDenseVectorCastingKnn("float_vector"); + } + + private static void checkDenseVectorCastingKnn(String fieldName) { + var plan = 
analyze(String.format(Locale.ROOT, """ + from test | where knn(%s, [0.342, 0.164, 0.234], 10) + """, fieldName), "mapping-dense_vector.json"); var limit = as(plan, Limit.class); var filter = as(limit.child(), Filter.class); @@ -2358,23 +2362,32 @@ public void testDenseVectorImplicitCastingKnn() { public void testDenseVectorImplicitCastingSimilarityFunctions() { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_cosine(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( - "v_dot_product(vector, [0.342, 0.164, 0.234])", + "v_dot_product(float_vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_l1_norm(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); + checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } if 
(EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction( + "v_l2_norm(float_vector, [0.342, 0.164, 0.234])", + List.of(0.342f, 0.164f, 0.234f) + ); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); } } @@ -2389,7 +2402,7 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity assertEquals("similarity", alias.name()); var similarity = as(alias.child(), VectorSimilarityFunction.class); var left = as(similarity.left(), FieldAttribute.class); - assertEquals("vector", left.name()); + assertThat(List.of("float_vector", "byte_vector"), hasItem(left.name())); var right = as(similarity.right(), Literal.class); assertThat(right.dataType(), is(DENSE_VECTOR)); assertThat(right.value(), equalTo(expectedElems));