diff --git a/docs/changelog/137072.yaml b/docs/changelog/137072.yaml
new file mode 100644
index 0000000000000..918d754a33ea2
--- /dev/null
+++ b/docs/changelog/137072.yaml
@@ -0,0 +1,5 @@
+pr: 137072
+summary: Adding base64 indexing for vector values
+area: Vector Search
+type: enhancement
+issues: []
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml
index f989e17e6ec30..4ccff8982f606 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml
@@ -65,7 +65,7 @@ setup:
   # [-128, 127, 10] - is encoded as '807f0a'
   - do:
-      catch: /Failed to parse object./
+      catch: bad_request
       index:
         index: knn_hex_vector_index
         id: "5"
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/171_knn_index_base64_encoded_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/171_knn_index_base64_encoded_vectors.yml
new file mode 100644
index 0000000000000..37cac56e56f56
--- /dev/null
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/171_knn_index_base64_encoded_vectors.yml
@@ -0,0 +1,229 @@
+setup:
+  - requires:
+      cluster_features: "mapper.base64_dense_vectors"
+      reason: 'base64 encoding for vectors feature required'
+
+  - do:
+      indices.create:
+        index: knn_base64_vector_index
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                dims: 3
+                index : true
+                similarity : l2_norm
+                element_type: byte
+              my_vector_float:
+                type: dense_vector
+                dims: 3
+                index: true
+                element_type: float
+                similarity : l2_norm
+
+  # [0.8837743, 0.6310808, 0.7800066] - is encoded as 'P2I/CD8hjoM/R66D'
+  # [-128, 127, 10] - is encoded as 'gH8K'
+  - do:
+      index:
+        index: knn_base64_vector_index
+        id: "1"
+        body:
+          my_vector_float: "P2I/CD8hjoM/R66D"
+          my_vector_byte: "gH8K"
+
+
+  # [0.27721548, 0.9202792 , 0.46455473] - is encoded as 'Po3vMD9rl2s+7doe'
+  # [0, 1, 0] - is encoded as 'AAEA'
+  - do:
+      index:
+        index: knn_base64_vector_index
+        id: "2"
+        body:
+          my_vector_float: "Po3vMD9rl2s+7doe"
+          my_vector_byte: "AAEA"
+
+  - do:
+      index:
+        index: knn_base64_vector_index
+        id: "3"
+        body:
+          my_vector_float: [0.2509804, -0.039215684, -0.11764706]
+          my_vector_byte: [64, -10, -30]
+
+  - do:
+      indices.refresh: {}
+
+---
+"Fail to index hex-encoded vector on float field":
+
+  # [-128, 127, 10] - is encoded as '807f0a'
+  - do:
+      catch: bad_request
+      index:
+        index: knn_base64_vector_index
+        id: "5"
+        body:
+          my_vector_float: "807f0a"
+
+---
+"Knn retrieve base64 encoded vectors" :
+  - do:
+      get:
+        index: knn_base64_vector_index
+        id: "1"
+        _source_exclude_vectors: false
+
+  - match: { _source.my_vector_float: [0.8837743, 0.6310808, 0.7800066] }
+  - match: { _source.my_vector_byte: [-128, 127, 10] }
+---
+"Base64 bytes infers the dimensions correctly":
+  - do:
+      indices.create:
+        index: knn_base64_vector_index_infer_dims
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                index : true
+                similarity : l2_norm
+                element_type: byte
+
+  # [-128, 127, 10, 0] - is encoded as 'gH8KAA=='
+  - do:
+      index:
+        index: knn_base64_vector_index_infer_dims
+        id: "1"
+        body:
+          my_vector_byte: "gH8KAA=="
+
+  - do:
+      cluster.health:
+        wait_for_events: languid
+
+  - do:
+      indices.get_mapping:
+        index: knn_base64_vector_index_infer_dims
+
+  # sanity
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.type: dense_vector }
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.index: true }
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.dims: 4 }
+---
+"Base64 floats infers the dimensions correctly":
+  - do:
+      indices.create:
+        index: knn_base64_vector_index_infer_dims
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                index : true
+                similarity : l2_norm
+                element_type: float
+
+  # [0.8837743, 0.6310808, 0.7800066, 0.0] - is encoded as 'P2I/CD8hjoM/R66DAAAAAA=='
+  - do:
+      index:
+        index: knn_base64_vector_index_infer_dims
+        id: "1"
+        body:
+          my_vector_byte: "P2I/CD8hjoM/R66DAAAAAA=="
+
+  - do:
+      cluster.health:
+        wait_for_events: languid
+  - do:
+      indices.get_mapping:
+        index: knn_base64_vector_index_infer_dims
+
+  # sanity
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.type: dense_vector }
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.index: true }
+  - match: { knn_base64_vector_index_infer_dims.mappings.properties.my_vector_byte.dims: 4 }
+---
+"Retrieve Base64 encoded vectors when exclude vectors from source is false":
+  - do:
+      indices.create:
+        index: knn_base64_vector_index_with_source_vectors
+        body:
+          settings:
+            number_of_shards: 1
+            index:
+              mapping:
+                exclude_source_vectors: false
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                dims: 3
+                index : true
+                similarity : l2_norm
+                element_type: byte
+              my_vector_float:
+                type: dense_vector
+                dims: 3
+                index: true
+                element_type: float
+                similarity : l2_norm
+
+  - do:
+      index:
+        index: knn_base64_vector_index_with_source_vectors
+        id: "1"
+        body:
+          my_vector_float: "P2I/CD8hjoM/R66D"
+          my_vector_byte: "gH8K"
+
+  - do:
+      index:
+        index: knn_base64_vector_index_with_source_vectors
+        id: "3"
+        body:
+          my_vector_float: [0.2509804, -0.039215684, -0.11764706]
+          my_vector_byte: [64, -10, -30]
+
+  - do:
+      indices.refresh: {}
+
+  - do:
+      search:
+        index: knn_base64_vector_index_with_source_vectors
+        body:
+          query:
+            ids:
+              values: ["1"]
+          _source: false
+          fields:
+            - my_vector_float
+            - my_vector_byte
+
+  - match: { hits.hits.0.fields.my_vector_float: ["P2I/CD8hjoM/R66D"] }
+  - match: { hits.hits.0.fields.my_vector_byte: ["gH8K"] }
+
+  - do:
+      search:
+        index: knn_base64_vector_index_with_source_vectors
+        body:
+          query:
+            ids:
+              values: ["3"]
+          _source: false
+          fields:
+            - my_vector_float
+            - my_vector_byte
+
+  - match: { hits.hits.0.fields.my_vector_float: [0.2509804, -0.039215684, -0.11764706] }
+  - match: { hits.hits.0.fields.my_vector_byte: [64, -10, -30] }
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml
index cd94275234661..428833bbe6dff 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml
@@ -65,7 +65,7 @@ setup:
   # [-128, 127, 10] - is encoded as '807f0a'
   - do:
-      catch: /Failed to parse object./
+      catch: bad_request
       index:
         index: knn_hex_vector_index
         id: "5"
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
index 279a37fcc2061..b540cd8ab4a61 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java
@@ -58,6 +58,7 @@ public class MapperFeatures implements FeatureSpecification {
         "mapper.ignore_dynamic_field_names_beyond_limit"
     );
     static final NodeFeature EXCLUDE_VECTORS_DOCVALUE_BUGFIX = new NodeFeature("mapper.exclude_vectors_docvalue_bugfix");
+    static final NodeFeature BASE64_DENSE_VECTORS = new NodeFeature("mapper.base64_dense_vectors");
 
     @Override
     public Set<NodeFeature> getTestFeatures() {
@@ -99,7 +100,8 @@ public Set<NodeFeature> getTestFeatures() {
             DISKBBQ_ON_DISK_RESCORING,
             PROVIDE_INDEX_SORT_SETTING_DEFAULTS,
             INDEX_MAPPING_IGNORE_DYNAMIC_BEYOND_FIELD_NAME_LIMIT,
-            EXCLUDE_VECTORS_DOCVALUE_BUGFIX
+            EXCLUDE_VECTORS_DOCVALUE_BUGFIX,
+            BASE64_DENSE_VECTORS
         );
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index a67ffd77e97f2..27d578d62befe 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -57,7 +57,6 @@
 import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
-import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
 import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.BlockSourceReader;
 import org.elasticsearch.index.mapper.DocumentParserContext;
@@ -102,6 +101,8 @@
 import java.time.ZoneId;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Base64;
+import java.util.Collections;
 import java.util.HexFormat;
 import java.util.List;
 import java.util.Locale;
@@ -109,6 +110,7 @@
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
 import java.util.stream.Stream;
@@ -791,9 +793,13 @@ VectorData parseVectorArray(DocumentParserContext context, int dims, IntBooleanC
             return VectorData.fromBytes(vector);
         }
 
-        VectorData parseHexEncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
-            throws IOException {
-            byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
+        VectorData parseStringValue(
+            String s,
+            IntBooleanConsumer dimChecker,
+            VectorSimilarity similarity,
+            Function<String, byte[]> decoder
+        ) {
+            byte[] decodedVector = decoder.apply(s);
             dimChecker.accept(decodedVector.length, true);
             VectorData vectorData = VectorData.fromBytes(decodedVector);
             double squaredMagnitude = computeSquaredMagnitude(vectorData);
@@ -801,6 +807,14 @@ VectorData parseHexEncodedVector(DocumentParserContext context, IntBooleanConsum
             return vectorData;
         }
 
+        VectorData parseHexEncodedVector(String s, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
+            return parseStringValue(s, dimChecker, similarity, HexFormat.of()::parseHex);
+        }
+
+        VectorData parseBase64EncodedVector(String s, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
+            return parseStringValue(s, dimChecker, similarity, Base64.getDecoder()::decode);
+        }
+
         @Override
         public VectorData parseKnnVector(
             DocumentParserContext context,
@@ -811,7 +825,22 @@ public VectorData parseKnnVector(
             XContentParser.Token token = context.parser().currentToken();
             return switch (token) {
                 case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
-                case VALUE_STRING -> parseHexEncodedVector(context, dimChecker, similarity);
+                case VALUE_STRING -> {
+                    String s = context.parser().text();
+                    if (s.length() == dims * 2) {
+                        try {
+                            yield parseHexEncodedVector(s, dimChecker, similarity);
+                        } catch (IllegalArgumentException e) {
+                            yield parseBase64EncodedVector(s, dimChecker, similarity);
+                        }
+                    } else {
+                        try {
+                            yield parseBase64EncodedVector(s, dimChecker, similarity);
+                        } catch (IllegalArgumentException e) {
+                            yield parseHexEncodedVector(s, dimChecker, similarity);
+                        }
+                    }
+                }
                 default -> throw new ParsingException(
                     context.parser().getTokenLocation(),
                     format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
@@ -829,6 +858,20 @@ public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
             return ByteBuffer.wrap(new byte[numBytes]);
         }
 
+        static boolean isMaybeHexString(String s) {
+            int len = s.length();
+            if (len % 2 != 0) {
+                return false;
+            }
+            for (int i = 0; i < len; i++) {
+                char c = s.charAt(i);
+                if (HexFormat.isHexDigit(c) == false) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
         @Override
         public int parseDimensionCount(DocumentParserContext context) throws IOException {
             XContentParser.Token currentToken = context.parser().currentToken();
@@ -841,8 +884,21 @@ public int parseDimensionCount(DocumentParserContext context) throws IOException
                     yield index;
                 }
                 case VALUE_STRING -> {
-                    byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
-                    yield decodedVector.length;
+                    String v = context.parser().text();
+                    // Base64-encoded strings always have a length divisible by 4, so if this one's length is not, try hex first
+                    if (v.length() % 4 != 0) {
+                        try {
+                            yield HexFormat.of().parseHex(v).length;
+                        } catch (IllegalArgumentException e) {
+                            yield Base64.getDecoder().decode(v).length;
+                        }
+                    } else {
+                        try {
+                            yield Base64.getDecoder().decode(v).length;
+                        } catch (IllegalArgumentException e) {
+                            yield HexFormat.of().parseHex(v).length;
+                        }
+                    }
                 }
                 default -> throw new ParsingException(
                     context.parser().getTokenLocation(),
@@ -968,23 +1024,56 @@ public double computeSquaredMagnitude(VectorData vectorData) {
         }
 
         @Override
-        public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
-            int index = 0;
-            float[] vector = new float[fieldMapper.fieldType().dims];
-            float squaredMagnitude = 0;
-            for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
-                fieldMapper.checkDimensionExceeded(index, context);
-                ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
+        public int parseDimensionCount(DocumentParserContext context) throws IOException {
+            XContentParser.Token currentToken = context.parser().currentToken();
+            return switch (currentToken) {
-                float value = context.parser().floatValue(true);
-                vector[index++] = value;
-                squaredMagnitude += value * value;
-            }
-            fieldMapper.checkDimensionMatches(index, context);
-            checkVectorBounds(vector);
-            checkVectorMagnitude(fieldMapper.fieldType().similarity, errorElementsAppender(vector), squaredMagnitude);
-            if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) {
-                float length = (float) Math.sqrt(squaredMagnitude);
+                case START_ARRAY -> {
+                    int index = 0;
+                    for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
+                        index++;
+                    }
+                    yield index;
+                }
+                case VALUE_STRING -> {
+                    byte[] decodedVectorBytes = Base64.getDecoder().decode(context.parser().text());
+                    if (decodedVectorBytes.length % Float.BYTES != 0) {
+                        throw new ParsingException(
+                            context.parser().getTokenLocation(),
+                            "Failed to parse object: Base64 decoded vector byte length ["
+                                + decodedVectorBytes.length
+                                + "] is not a multiple of ["
+                                + Float.BYTES
+                                + "]"
+                        );
+                    }
+                    yield decodedVectorBytes.length / Float.BYTES;
+                }
+                default -> throw new ParsingException(
+                    context.parser().getTokenLocation(),
+                    format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text())
+                );
+            };
+        }
+
+        @Override
+        public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+            var vandm = parseFloatVectorInput(context, fieldMapper.fieldType().dims, (i, end) -> {
+                if (end) {
+                    fieldMapper.checkDimensionMatches(i, context);
+                } else {
+                    fieldMapper.checkDimensionExceeded(i, context);
+                }
+            });
+            checkVectorBounds(vandm.vectorData.asFloatVector());
+            checkVectorMagnitude(
+                fieldMapper.fieldType().similarity,
+                errorElementsAppender(vandm.vectorData().floatVector()),
+                vandm.squaredMagnitude
+            );
+            float[] vector = vandm.vectorData.asFloatVector();
+            if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(vandm.squaredMagnitude)) {
+                float length = (float) Math.sqrt(vandm.squaredMagnitude);
                 for (int i = 0; i < vector.length; i++) {
                     vector[i] /= length;
                 }
@@ -1007,23 +1096,73 @@ public VectorData parseKnnVector(
             IntBooleanConsumer dimChecker,
             VectorSimilarity similarity
         ) throws IOException {
+            var v = parseFloatVectorInput(context, dims, (i, end) -> {
+                if (end) {
+                    dimChecker.accept(i, true);
+                } else {
+                    dimChecker.accept(i, false);
+                }
+            });
+            checkVectorBounds(v.vectorData.asFloatVector());
+            checkVectorMagnitude(similarity, errorElementsAppender(v.vectorData.asFloatVector()), v.squaredMagnitude);
+            return v.vectorData;
+        }
+
+        VectorDataAndMagnitude parseFloatVectorInput(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker)
+            throws IOException {
+            XContentParser.Token token = context.parser().currentToken();
+            return switch (token) {
+                case START_ARRAY -> parseVectorArray(context, dimChecker, dims);
+                case VALUE_STRING -> parseBase64EncodedVector(context, dimChecker, dims);
+                default -> throw new ParsingException(
+                    context.parser().getTokenLocation(),
+                    format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
+                );
+            };
+        }
+
+        VectorDataAndMagnitude parseVectorArray(DocumentParserContext context, IntBooleanConsumer dimChecker, int dims) throws IOException {
             int index = 0;
-            float squaredMagnitude = 0;
             float[] vector = new float[dims];
-            for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
+            float squaredMagnitude = 0;
+            for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
+                .nextToken()) {
                 dimChecker.accept(index, false);
                 ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
                 float value = context.parser().floatValue(true);
-                vector[index] = value;
+                vector[index++] = value;
                 squaredMagnitude += value * value;
-                index++;
             }
             dimChecker.accept(index, true);
-            checkVectorBounds(vector);
-            checkVectorMagnitude(similarity, errorElementsAppender(vector), squaredMagnitude);
-            return VectorData.fromFloats(vector);
+            return new VectorDataAndMagnitude(VectorData.fromFloats(vector), squaredMagnitude);
+        }
+
+        VectorDataAndMagnitude parseBase64EncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker, int dims)
+            throws IOException {
+            // BIG_ENDIAN is the default, but just being explicit here
+            ByteBuffer byteBuffer = ByteBuffer.wrap(Base64.getDecoder().decode(context.parser().text())).order(ByteOrder.BIG_ENDIAN);
+            if (byteBuffer.remaining() != dims * Float.BYTES) {
+                throw new ParsingException(
+                    context.parser().getTokenLocation(),
+                    "Failed to parse object: Base64 decoded vector byte length ["
+                        + byteBuffer.remaining()
+                        + "] does not match the expected length of ["
+                        + (dims * Float.BYTES)
+                        + "] for dimension count ["
+                        + dims
+                        + "]"
+                );
+            }
+            float[] decodedVector = new float[dims];
+            byteBuffer.asFloatBuffer().get(decodedVector);
+            dimChecker.accept(decodedVector.length, true);
+            VectorData vectorData = VectorData.fromFloats(decodedVector);
+            float squaredMagnitude = (float) computeSquaredMagnitude(vectorData);
+            return new VectorDataAndMagnitude(vectorData, squaredMagnitude);
         }
 
+        record VectorDataAndMagnitude(VectorData vectorData, float squaredMagnitude) {}
+
         @Override
         public int getNumBytes(int dimensions) {
             return dimensions * Float.BYTES;
@@ -1095,9 +1234,13 @@ VectorData parseVectorArray(DocumentParserContext context, int dims, IntBooleanC
         }
 
         @Override
-        VectorData parseHexEncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
-            throws IOException {
-            byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
+        VectorData parseStringValue(
+            String s,
+            IntBooleanConsumer dimChecker,
+            VectorSimilarity similarity,
+            Function<String, byte[]> decoder
+        ) {
+            byte[] decodedVector = decoder.apply(s);
             dimChecker.accept(decodedVector.length * Byte.SIZE, true);
             return VectorData.fromBytes(decodedVector);
         }
@@ -2299,19 +2442,40 @@ public String typeName() {
 
         @Override
         public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
+            // TODO add support to `binary` and `vector` formats to unify the formats
            if (format != null) {
                 throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support formats.");
             }
-            return new ArraySourceValueFetcher(name(), context) {
+            Set<String> sourcePaths = context.isSourceEnabled() ? context.sourcePath(name()) : Collections.emptySet();
+            return new SourceValueFetcher(name(), context) {
+                @Override
+                public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValues) {
+                    ArrayList<Object> values = new ArrayList<>();
+                    for (var path : sourcePaths) {
+                        Object sourceValue = source.extractValue(path, null);
+                        if (sourceValue == null) {
+                            return List.of();
+                        }
+                        switch (sourceValue) {
+                            case List<?> v -> values.addAll(v);
+                            case String s -> values.add(s);
+                            default -> ignoredValues.add(sourceValue);
+                        }
+                    }
+                    values.trimToSize();
+                    return values;
+                }
+
                 @Override
                 protected Object parseSourceValue(Object value) {
-                    return value;
+                    throw new IllegalStateException("parsing dense vector from source is not supported here");
                 }
             };
         }
 
         @Override
         public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
+            // TODO we should add DENSE_VECTOR_BINARY?
             return DocValueFormat.DENSE_VECTOR;
         }
@@ -2738,18 +2902,57 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
         private SourceValueFetcher sourceValueFetcher(Set<String> sourcePaths, IndexSettings indexSettings) {
             return new SourceValueFetcher(sourcePaths, null, indexSettings.getIgnoredSourceFormat()) {
                 @Override
-                protected Object parseSourceValue(Object value) {
-                    if (value.equals("")) {
-                        return null;
+                public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValues) {
+                    ArrayList<Object> values = new ArrayList<>();
+                    for (var path : sourcePaths) {
+                        Object sourceValue = source.extractValue(path, null);
+                        if (sourceValue == null) {
+                            return List.of();
+                        }
+                        try {
+                            switch (sourceValue) {
+                                case List<?> v -> {
+                                    for (Object o : v) {
+                                        values.add(NumberFieldMapper.NumberType.FLOAT.parse(o, false));
+                                    }
+                                }
+                                case String s -> {
+                                    if ((element.elementType() == ElementType.BYTE || element.elementType() == ElementType.BIT)
+                                        && s.length() == dims * 2
+                                        && ByteElement.isMaybeHexString(s)) {
+                                        byte[] bytes;
+                                        try {
+                                            bytes = HexFormat.of().parseHex(s);
+                                        } catch (IllegalArgumentException e) {
+                                            bytes = Base64.getDecoder().decode(s);
+                                        }
+                                        for (byte b : bytes) {
+                                            values.add((float) b);
+                                        }
+                                    } else {
+                                        byte[] floatBytes = Base64.getDecoder().decode(s);
+                                        float[] floats = new float[dims];
+                                        ByteBuffer.wrap(floatBytes).asFloatBuffer().get(floats);
+                                        for (float f : floats) {
+                                            values.add(f);
+                                        }
+                                    }
+                                }
+                                default -> ignoredValues.add(sourceValue);
+                            }
+                        } catch (Exception e) {
+                            // if parsing fails here then it would have failed at index time
+                            // as well, meaning that we must be ignoring malformed values.
+                            ignoredValues.add(sourceValue);
+                        }
                     }
-                    return NumberFieldMapper.NumberType.FLOAT.parse(value, false);
+                    values.trimToSize();
+                    return values;
                 }
 
                 @Override
-                public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValues) {
-                    List<Object> result = super.fetchValues(source, doc, ignoredValues);
-                    assert result.size() == dims : "Unexpected number of dimensions; got " + result.size() + " but expected " + dims;
-                    return result;
+                protected Object parseSourceValue(Object value) {
+                    throw new IllegalStateException("parsing dense vector from source is not supported here");
                 }
             };
         }
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
index 645c61ee45e2b..e2a30e7df9f31 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
@@ -59,8 +59,10 @@
 import org.junit.AssumptionViolatedException;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -149,12 +151,31 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I
     }
 
     @Override
-    protected Object getSampleValueForDocument() {
+    protected Object getSampleValueForDocument(boolean binaryFormat) {
+        if (binaryFormat) {
+            final byte[] toEncode;
+            if (elementType == ElementType.FLOAT) {
+                float[] array = randomNormalizedVector(this.dims);
+                final ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * array.length);
+                buffer.asFloatBuffer().put(array);
+                toEncode = buffer.array();
+            } else {
+                toEncode = elementType == ElementType.BIT
+                    ? randomByteArrayOfLength(this.dims / Byte.SIZE)
+                    : randomByteArrayOfLength(this.dims);
+            }
+            return Base64.getEncoder().encodeToString(toEncode);
+        }
         return elementType == ElementType.FLOAT
             ? convertToList(randomNormalizedVector(this.dims))
             : convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? this.dims / Byte.SIZE : dims));
     }
 
+    @Override
+    protected Object getSampleValueForDocument() {
+        return getSampleValueForDocument(randomBoolean());
+    }
+
     public static List<Float> convertToList(float[] vector) {
         List<Float> list = new ArrayList<>(vector.length);
         for (float v : vector) {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java
index 5448fd792625a..badb74fabd636 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java
@@ -97,6 +97,11 @@ protected Object getSampleValueForDocument() {
         );
     }
 
+    @Override
+    protected Object getSampleValueForDocument(boolean binaryFormat) {
+        return getSampleValueForDocument();
+    }
+
     @Override
     protected Object getSampleObjectForDocument() {
         return getSampleValueForDocument();
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SyntheticVectorsMapperTestCase.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SyntheticVectorsMapperTestCase.java
index dbf4bd9846165..35d3afcde8a69 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SyntheticVectorsMapperTestCase.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SyntheticVectorsMapperTestCase.java
@@ -30,6 +30,9 @@
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 
 public abstract class SyntheticVectorsMapperTestCase extends MapperTestCase {
+
+    protected abstract Object getSampleValueForDocument(boolean binaryFormat);
+
     public void testSyntheticVectorsMinimalValidDocument() throws IOException {
         for (XContentType type : XContentType.values()) {
             BytesReference source = generateRandomDoc(type, true, true, false, false, false);
@@ -162,7 +165,7 @@ private BytesReference generateRandomDoc(
         }
 
         if (includeVector) {
-            builder.field("emb", getSampleValueForDocument());
+            builder.field("emb", getSampleValueForDocument(false));
             // builder.array("emb", new float[] { 1, 2, 3 });
         }
 
@@ -187,13 +190,13 @@ private BytesReference generateRandomDoc(
         if (includeDoubleNested) {
             builder.startObject();
             // builder.array("emb", new float[] { 1, 2, 3 });
-            builder.field("emb", getSampleValueForDocument());
+            builder.field("emb", getSampleValueForDocument(false));
             builder.field("field", "nested_val");
             builder.startArray("double_nested");
             for (int i = 0; i < 2; i++) {
                 builder.startObject();
                 // builder.array("emb", new float[] { 1, 2, 3 });
-                builder.field("emb", getSampleValueForDocument());
+                builder.field("emb", getSampleValueForDocument(false));
                 builder.field("field", "dn_field");
                 builder.endObject();
             }
@@ -216,20 +219,20 @@ private BytesReference generateRandomDocWithFlatPath(XContentType xContentType)
 
         // Root-level fields
         builder.field("field", randomAlphaOfLengthBetween(1, 2));
-        builder.field("emb", getSampleValueForDocument());
+        builder.field("emb", getSampleValueForDocument(false));
         builder.field("another_field", randomAlphaOfLengthBetween(3, 5));
 
         // Simulated flattened "obj.nested"
         builder.startObject("obj.nested");
         builder.field("field", randomAlphaOfLengthBetween(4, 8));
-        builder.field("emb", getSampleValueForDocument());
+        builder.field("emb", getSampleValueForDocument(false));
 
         builder.startArray("double_nested");
         for (int i = 0; i < randomIntBetween(1, 2); i++) {
             builder.startObject();
             builder.field("field", randomAlphaOfLengthBetween(4, 8));
-            builder.field("emb", getSampleValueForDocument());
+            builder.field("emb", getSampleValueForDocument(false));
             builder.endObject();
         }
         builder.endArray();
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java
index fc24e19f1cad5..d852ddeaf1d8b 100644
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java
@@ -22,9 +22,12 @@
 import org.junit.Before;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.HashMap;
+import java.util.HexFormat;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -32,12 +35,19 @@
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
+import static org.elasticsearch.index.IndexSettings.INDEX_MAPPING_EXCLUDE_SOURCE_VECTORS_SETTING;
 import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION;
 
 public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
 
+    private enum VectorSourceOptions {
+        DEFAULT,
+        SYNTHETIC,
+        INCLUDE_SOURCE_VECTORS
+    }
+
     public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
         .map(v -> v.getName().toLowerCase(Locale.ROOT))
         .collect(Collectors.toSet());
@@ -51,7 +61,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
 
     private final ElementType elementType;
     private final DenseVectorFieldMapper.VectorSimilarity similarity;
-    private final boolean synthetic;
+    private final VectorSourceOptions sourceOptions;
     private final boolean index;
 
     @ParametersFactory
@@ -64,13 +74,14 @@ public static Iterable<Object[]> parameters() throws Exception {
                 if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) {
                     continue;
                 }
-                params.add(new Object[] { elementType, similarity, true, false });
+                params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.DEFAULT });
+                params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.SYNTHETIC });
+                params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.INCLUDE_SOURCE_VECTORS });
             }
 
-            // No indexing
-            params.add(new Object[] { elementType, null, false, false });
-            // No indexing, synthetic source
-            params.add(new Object[] { elementType, null, false, true });
+            params.add(new Object[] { elementType, null, false, VectorSourceOptions.DEFAULT });
+            params.add(new Object[] { elementType, null, false, VectorSourceOptions.SYNTHETIC });
+            params.add(new Object[] { elementType, null, false, VectorSourceOptions.INCLUDE_SOURCE_VECTORS });
         }
 
         return params;
@@ -80,12 +91,12 @@ public DenseVectorFieldTypeIT(
         @Name("elementType") ElementType elementType,
         @Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity,
         @Name("index") boolean index,
-        @Name("synthetic") boolean synthetic
+        @Name("sourceOptions") VectorSourceOptions sourceOptions
     ) {
         this.elementType = elementType;
         this.similarity = similarity;
         this.index = index;
-        this.synthetic = synthetic;
+        this.sourceOptions = sourceOptions;
     }
 
     private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
@@ -169,7 +180,7 @@ public void testRetrieveDenseVectorFieldData() {
     }
 
     public void testNonIndexedDenseVectorField() throws IOException {
-        createIndexWithDenseVector("no_dense_vectors");
+        createIndexWithDenseVector("no_dense_vectors", 64);
 
         int numDocs = randomIntBetween(10, 100);
         IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
@@ -199,9 +210,11 @@ public void testNonIndexedDenseVectorField() throws IOException {
 
     @Before
     public void setup() throws IOException {
-        createIndexWithDenseVector("test");
-
-        int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
+        int numDims = randomIntBetween(8, 16) * 8; // min 64, multiple of 8
+        createIndexWithDenseVector("test", numDims);
+        if (elementType == ElementType.BIT) {
+            numDims = numDims / 8;
+        }
         int numDocs = randomIntBetween(10, 100);
         IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
         for (int i = 0; i < numDocs; i++) {
@@ -222,7 +235,31 @@ public void setup() throws IOException {
                 float magnitude = DenseVector.getMagnitude(vector);
                 vector.replaceAll(number -> number.floatValue() / magnitude);
             }
-            docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
+            Object vectorToIndex;
+            if (randomBoolean()) {
+                vectorToIndex = vector;
+            } else {
+                // Test base64/hex encoded string input
+                vectorToIndex = switch (elementType) {
+                    case FLOAT -> {
+                        float[] array = new float[numDims];
+                        for (int k = 0; k < numDims; k++) {
+                            array[k] = vector.get(k).floatValue();
+                        }
+                        final ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * numDims);
+                        buffer.asFloatBuffer().put(array);
+                        yield Base64.getEncoder().encodeToString(buffer.array());
+                    }
+                    case BYTE, BIT -> {
+                        byte[] array = new byte[numDims];
+                        for (int k = 0; k < numDims; k++) {
+                            array[k] = vector.get(k).byteValue();
+                        }
+                        yield randomBoolean() ? Base64.getEncoder().encodeToString(array) : HexFormat.of().formatHex(array);
+                    }
+                };
+            }
+            docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vectorToIndex);
             indexedVectors.put(i, vector);
         }
     }
@@ -230,7 +267,7 @@ public void setup() throws IOException {
         indexRandom(true, docs);
     }
 
-    private void createIndexWithDenseVector(String indexName) throws IOException {
+    private void createIndexWithDenseVector(String indexName, int dims) throws IOException {
         var client = client().admin().indices();
         XContentBuilder mapping = XContentFactory.jsonBuilder()
             .startObject()
@@ -240,6 +277,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
             .endObject()
             .startObject("vector")
             .field("type", "dense_vector")
+            .field("dims", dims)
             .field("element_type", elementType.toString().toLowerCase(Locale.ROOT))
             .field("index", index);
         if (index) {
@@ -256,8 +294,14 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
         Settings.Builder settingsBuilder = Settings.builder()
             .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
-        if (synthetic) {
-            settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC);
+        switch (sourceOptions) {
+            // ensure vectors are actually in _source
+            case INCLUDE_SOURCE_VECTORS -> settingsBuilder.put(INDEX_MAPPING_EXCLUDE_SOURCE_VECTORS_SETTING.getKey(), false);
+            // ensure synthetic source is on
+            case SYNTHETIC -> settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC);
+            // default, which is vectors outside of source and synthetic off
+            case DEFAULT -> {
+            }
         }
 
         var createRequest = client.prepareCreate(indexName)
diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java
index 5834aca5fa0a5..32863cbd96d25 100644
--- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java
+++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java
@@ -89,6 +89,11 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I
         }
     }
 
+    @Override
+    protected Object getSampleValueForDocument(boolean binaryFormat) {
+        return getSampleValueForDocument();
+    }
+
     @Override
     protected Object getSampleValueForDocument() {
         int numVectors = randomIntBetween(1, 16);