diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java index 7b7fa6b4ab6d1..f517e52ce423a 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java @@ -96,7 +96,8 @@ public class CcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { // geohex_grid requires gold license .setting("xpack.license.self_generated.type", "trial") .feature(FeatureFlag.TIME_SERIES_MODE) - .feature(FeatureFlag.SYNTHETIC_VECTORS); + .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT); private static ElasticsearchCluster remoteCluster = ElasticsearchCluster.local() .name(REMOTE_CLUSTER_NAME) diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java index e5362c31f32f9..7d1ed9d92238a 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java @@ -96,6 +96,7 @@ public class RcsCcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .setting("xpack.security.remote_cluster_client.ssl.enabled", "false") .feature(FeatureFlag.TIME_SERIES_MODE) .feature(FeatureFlag.SYNTHETIC_VECTORS) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .user("test_admin", "x-pack-test-password"); private static ElasticsearchCluster fulfillingCluster = ElasticsearchCluster.local() diff --git a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java index 529d7a7155264..8af1760f2ebd3 100644 --- a/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java +++ b/qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java @@ -38,6 +38,7 @@ public class SmokeTestMultiNodeClientYamlTestSuiteIT extends ESClientYamlSuiteTe .feature(FeatureFlag.DOC_VALUES_SKIPPER) .feature(FeatureFlag.SYNTHETIC_VECTORS) .feature(FeatureFlag.RANDOM_SAMPLING) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .build(); public SmokeTestMultiNodeClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java index 2cebc6a743703..45b049784c380 100644 --- a/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java +++ b/rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java @@ -38,6 +38,7 @@ public class ClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .feature(FeatureFlag.DOC_VALUES_SKIPPER) .feature(FeatureFlag.SYNTHETIC_VECTORS) .feature(FeatureFlag.RANDOM_SAMPLING) + .feature(FeatureFlag.GENERIC_VECTOR_FORMAT) .build(); public ClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml index 161fc23a84651..6edbcf5ef28ff 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/200_dense_vector_docvalue_fields.yml @@ -123,7 +123,6 @@ setup: - match: {hits.hits.0.fields.vector5.0: [1, 111, -13, 15, -128]} - match: {hits.hits.0.fields.vector6.0: [-1, 11, 0, 12, 111]} - - match: {hits.hits.1._id: "2"} - match: {hits.hits.1.fields.name.0: "moose.jpg"} @@ -143,7 +142,6 @@ setup: - match: {hits.hits.1.fields.vector4.0: [-1, 50, -1, 1, 120]} - match: {hits.hits.1.fields.vector5.0: [1, 111, -13, 15, -128]} - - match: {hits.hits.2._id: "3"} - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} @@ -161,3 +159,103 @@ setup: - close_to: { hits.hits.2.fields.vector2.0.4: { value: -100.0, error: 0.001 } } - match: {hits.hits.2.fields.vector3.0: [-1, 100, -13, 15, -128]} + +--- +"dense_vector docvalues with bfloat16": + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - do: + indices.create: + index: test-bfloat16 + body: + mappings: + properties: + name: + type: keyword + vector7: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: true + vector8: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: false + + - do: + index: + index: test-bfloat16 + id: "1" + body: + name: cow.jpg + vector7: [230.0, 300.33, -34.8988, 15.555, -200.0] + vector8: [130.0, 115.0, -1.02, 15.555, -100.0] + - do: + index: + index: test-bfloat16 + id: "2" + body: + name: moose.jpg + vector7: [-0.5, 100.0, -13, 14.8, -156.0] + - do: + index: + index: test-bfloat16 + id: "3" + body: + name: rabbit.jpg + vector8: [130.0, 115.0, -1.02, 15.555, -100.0] + + - do: + indices.refresh: {} + + - do: + search: + _source: false + index: test-bfloat16 + body: + docvalue_fields: [name, vector7, vector8] + sort: name + + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.0.fields.name.0: "cow.jpg"} + + - length: {hits.hits.0.fields.vector7.0: 5} + - length: {hits.hits.0.fields.vector8.0: 5} + + - close_to: { hits.hits.0.fields.vector7.0.0: { value: 230.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.1: { value: 300.33, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.2: { value: -34.8988, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector7.0.4: { value: -200.0, error: 0.1 } } + + - close_to: { hits.hits.0.fields.vector8.0.0: { value: 130.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.1: { value: 115.0, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.2: { value: -1.02, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.0.fields.vector8.0.4: { value: -100.0, error: 0.1 } } + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + + - length: {hits.hits.1.fields.vector7.0: 5} + - match: {hits.hits.1.fields.vector8: null} + + - close_to: { hits.hits.1.fields.vector7.0.0: { value: -0.5, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.1: { value: 100.0, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.2: { value: -13, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.3: { value: 14.8, error: 0.1 } } + - close_to: { hits.hits.1.fields.vector7.0.4: { value: -156.0, error: 0.1 } } + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2.fields.name.0: "rabbit.jpg"} + + - length: {hits.hits.2.fields.vector8.0: 5} + - match: {hits.hits.2.fields.vector7: null} + + - close_to: { hits.hits.2.fields.vector8.0.0: { value: 130.0, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.1: { value: 115.0, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.2: { value: -1.02, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.3: { value: 15.555, error: 0.1 } } + - close_to: { hits.hits.2.fields.vector8.0.4: { value: -100.0, error: 0.1 } } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml new file mode 100644 index 0000000000000..51adafc624469 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search_bfloat16.yml @@ -0,0 +1,629 @@ +setup: + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - do: + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: true + similarity: l2_norm + index_options: + type: hnsw + m: 16 + ef_construction: 200 + another_vector: + type: dense_vector + element_type: bfloat16 + dims: 5 + index: true + similarity: l2_norm + index_options: + type: hnsw + m: 16 + ef_construction: 200 + - do: + index: + index: test + id: "1" + body: + name: cow.jpg + vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ] + another_vector: [ 130.0, 115.0, -1.02, 15.555, -100.0 ] + + - do: + index: + index: test + id: "2" + body: + name: moose.jpg + vector: [ -0.5, 100.0, -13, 14.8, -156.0 ] + another_vector: [ -0.5, 50.0, -1, 1, 120 ] + + - do: + index: + index: test + id: "3" + body: + name: rabbit.jpg + vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + another_vector: [ -0.5, 11.0, 0, 12, 111.0 ] + + - do: + indices.refresh: { } + +--- +"kNN search only": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1.fields.name.0: "rabbit.jpg" } +--- +"kNN multi-field search only": + - requires: + cluster_features: "gte_v8.7.0" + reason: 'multi-field kNN search added to search endpoint in 8.7' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + - { field: vector, query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ], k: 2, num_candidates: 3 } + - { field: another_vector, query_vector: [ -0.5, 11.0, 0, 12, 111.0 ], k: 2, num_candidates: 3 } + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } +--- +"kNN search plus query": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + query: + term: + name: cow.jpg + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0.fields.name.0: "cow.jpg" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } + + - match: { hits.hits.2._id: "3" } + - match: { hits.hits.2.fields.name.0: "rabbit.jpg" } +--- +"kNN multi-field search with query": + - requires: + cluster_features: "gte_v8.7.0" + reason: 'multi-field kNN search added to search endpoint in 8.7' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + - { field: vector, query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ], k: 2, num_candidates: 3 } + - { field: another_vector, query_vector: [ -0.5, 11.0, 0, 12, 111.0 ], k: 2, num_candidates: 3 } + query: + term: + name: cow.jpg + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1.fields.name.0: "cow.jpg" } + + - match: { hits.hits.2._id: "2" } + - match: { hits.hits.2.fields.name.0: "moose.jpg" } +--- +"kNN search with filter": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.fields.name.0: "rabbit.jpg" } + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: { hits.total.value: 0 } + +--- +"kNN search with explicit search_type": + - requires: + cluster_features: "gte_v8.4.0" + reason: 'kNN added to search endpoint in 8.4' + - do: + catch: bad_request + search: + index: test + search_type: query_then_fetch + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { error.root_cause.0.type: "illegal_argument_exception" } + - match: { error.root_cause.0.reason: "cannot set [search_type] when using [knn] search, since the search type is determined automatically" } + +--- +"Test nonexistent field is match none": + - requires: + cluster_features: "gte_v8.16.0" + reason: 'non-existent field handling improved in 8.16' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - length: { hits.hits: 0 } + + - do: + indices.create: + index: test_nonexistent + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: float + dims: 5 + index: true + similarity: l2_norm + settings: + index.query.parse.allow_unmapped_fields: false + + - do: + catch: bad_request + search: + index: test_nonexistent + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { error.root_cause.0.type: "query_shard_exception" } + - match: { error.root_cause.0.reason: "No field mapping can be found for the field with name [nonexistent]" } + +--- +"KNN Vector similarity search only": + - requires: + cluster_features: "gte_v8.8.0" + reason: 'kNN similarity added in 8.8' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 11 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + + - length: { hits.hits: 1 } + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } +--- +"Vector similarity with filter only": + - requires: + cluster_features: "gte_v8.8.0" + reason: 'kNN similarity added in 8.8' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 11 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: { "term": { "name": "moose.jpg" } } + + - length: { hits.hits: 1 } + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 110 + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: { "term": { "name": "cow.jpg" } } + + - length: { hits.hits: 0 } +--- +"Knn search with mip": + - requires: + cluster_features: "gte_v8.11.0" + reason: 'mip similarity added in 8.11' + test_runner_features: "close_to" + + - do: + indices.create: + index: mip + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + dims: 5 + index: true + similarity: max_inner_product + index_options: + type: hnsw + m: 16 + ef_construction: 200 + + - do: + index: + index: mip + id: "1" + body: + name: cow.jpg + vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ] + + - do: + index: + index: mip + id: "2" + body: + name: moose.jpg + vector: [ -0.5, 100.0, -13, 14.8, -156.0 ] + + - do: + index: + index: mip + id: "3" + body: + name: rabbit.jpg + vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + + - do: + indices.refresh: { } + + - do: + search: + index: mip + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + + + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 58694.902, error: 0.01 } } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 34702.79, error: 0.01 } } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 33686.29, error: 0.01 } } + + - do: + search: + index: mip + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + filter: { "term": { "name": "moose.jpg" } } + + + + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 33686.29, error: 0.01 } } +--- +"Knn search with _name": + - requires: + cluster_features: "gte_v8.15.0" + reason: 'support for _name in knn was added in 8.15' + test_runner_features: "close_to" + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + _name: "my_knn_query" + query: + term: + name: + term: cow.jpg + _name: "my_query" + + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0.fields.name.0: "cow.jpg" } + - match: { hits.hits.0.matched_queries.0: "my_knn_query" } + - match: { hits.hits.0.matched_queries.1: "my_query" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.fields.name.0: "moose.jpg" } + - match: { hits.hits.1.matched_queries.0: "my_knn_query" } + + - match: { hits.hits.2._id: "3" } + - match: { hits.hits.2.fields.name.0: "rabbit.jpg" } + - match: { hits.hits.2.matched_queries.0: "my_knn_query" } + +--- +"kNN search on empty index should return 0 results and not an error": + - requires: + cluster_features: "gte_v8.15.1" + reason: 'Error fixed in 8.15.1' + - do: + indices.create: + index: test_empty + body: + mappings: + properties: + vector: + type: dense_vector + - do: + search: + index: test_empty + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - match: { hits.total.value: 0 } +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } + +--- +"Dimensions are dynamically set": + - do: + indices.create: + index: test_index + body: + mappings: + properties: + embedding: + type: dense_vector + + - do: + index: + index: test_index + id: "0" + refresh: true + body: + embedding: [ 0.5, 111.3, -13.0, 14.8, -156.0 ] + + # wait and ensure that the mapping update is replicated + - do: + cluster.health: + wait_for_events: languid + + - do: + indices.get_mapping: + index: test_index + + - match: { test_index.mappings.properties.embedding.type: dense_vector } + - match: { test_index.mappings.properties.embedding.dims: 5 } + + - do: + catch: bad_request + index: + index: test_index + id: "0" + body: + embedding: [ 0.5, 111.3 ] + +--- +"Updating dim to null is not allowed": + - requires: + cluster_features: "mapper.npe_on_dims_update_fix" + reason: "dims update fix" + - do: + indices.create: + index: test_index + + - do: + indices.put_mapping: + index: test_index + body: + properties: + embedding: + type: dense_vector + dims: 4 + - do: + catch: bad_request + indices.put_mapping: + index: test_index + body: + properties: + embedding: + type: dense_vector + + +--- +"Searching with no data dimensions specified": + - requires: + cluster_features: "search.vectors.no_dimensions_bugfix" + reason: "Search with no dimensions bugfix" + + - do: + indices.create: + index: empty-test + body: + mappings: + properties: + vector: + type: dense_vector + index: true + + - do: + search: + index: empty-test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + similarity: 0.1 + + - match: { hits.total.value: 0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index e3c1155ed2000..ed15ed4d09806 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -338,6 +338,68 @@ setup: - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } --- +"Test index configured rescore vector with on-disk rescoring": + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - skip: + features: "headers" + - do: + indices.create: + index: bbq_on_disk_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + on_disk_rescore: true + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_on_disk_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_on_disk_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } +--- "Test index configured rescore vector updateable and settable to 0": - requires: cluster_features: ["mapper.dense_vector.rescore_zero_vector"] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml new file mode 100644 index 0000000000000..358089c5342ad --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -0,0 +1,580 @@ +setup: + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: false + index_options: + type: bbq_hnsw +--- +"Test few dimensions fail indexing": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + index_options: + type: bbq_hnsw + + - do: + indices.create: + index: dynamic_dim_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + - do: + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + + - do: + indices.create: + index: bbq_rescore_0_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_hnsw + + - match: { .bbq_rescore_update_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 0 } +--- +"Test index configured rescore vector score consistency": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_zero_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + bulk: + index: bbq_rescore_zero_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: raw_score0 } + - set: { hits.hits.1._score: raw_score1 } + - set: { hits.hits.2._score: raw_score2 } + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_hnsw + + - match: { bbq_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml new file mode 100644 index 0000000000000..6cbeb9ecfd189 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -0,0 +1,512 @@ +setup: + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - do: + indices.create: + index: bbq_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "3" + body: + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + indices.forcemerge: + index: bbq_flat + max_num_segments: 1 +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + index_options: + type: bbq_flat + m: 42 +--- +"Test bad raw vector size": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + index_options: + type: bbq_flat + raw_vector_size: 25 +--- +"Test few dimensions fail indexing": + # verify index creation fails + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # verify dynamic dimension fails + - do: + indices.create: + index: dynamic_dim_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # verify index fails for odd dim vector + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + # verify that we can index an even dim vector after the odd dim vector failure + - do: + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_flat + + - match: { bbq_flat.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } +--- +"Test nested queries": + - do: + indices.create: + index: bbq_flat_nested + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + name: + type: keyword + nested: + type: nested + properties: + paragraph_id: + type: keyword + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat_nested + id: "1" + body: + nested: + - paragraph_id: "1" + vector: [ 0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45 ] + - paragraph_id: "2" + vector: [ 0.7, 0.2 , 0.205, 0.63 , 0.032, 0.201, 0.167, 0.313, + 0.176, 0.1, 0.375, 0.334, 0.046, 0.078, 0.349, 0.272, + 0.307, 0.083, 0.504, 0.255, 0.404, 0.289, 0.226, 0.132, + 0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , 0.265, + 0.285, 0.336, 0.272, 0.369, -0.282, 0.086, 0.132, 0.475, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.028, 0.321, 0.022, 0.009, 0.001 , 0.031, -0.533, 0.45] + - do: + index: + index: bbq_flat_nested + id: "2" + body: + nested: + - paragraph_id: 0 + vector: [ 0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, -0.013 ] + - paragraph_id: 2 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, 0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, 0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, 0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, 0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, 0.013 ] + - paragraph_id: 3 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + 0.08 , 0.442, -0.067, -0.05 , 0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, 0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, 0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, 0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, 0.209, -0.153, -0.27, -0.013 ] + + - do: + index: + index: bbq_flat_nested + id: "3" + body: + nested: + - paragraph_id: 0 + vector: [ 0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058 ] + + - do: + indices.flush: + index: bbq_flat_nested + + - do: + indices.forcemerge: + index: bbq_flat_nested + max_num_segments: 1 + + - do: + search: + index: bbq_flat_nested + body: + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} + + - do: + search: + index: bbq_flat_nested + body: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml new file mode 100644 index 0000000000000..b47d337120c54 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/46_knn_search_bbq_ivf_bfloat16.yml @@ -0,0 +1,629 @@ +setup: + - requires: + cluster_features: [ "mapper.vectors.generic_vector_format" ] + reason: Needs generic vector support + - skip: + features: "headers" + - do: + indices.create: + index: bbq_disk + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + + - do: + index: + index: bbq_disk + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + index: + index: bbq_disk + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + index: + index: bbq_disk + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_disk + + - do: + indices.forcemerge: + index: bbq_disk + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - do: + search: + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Test knn search with visit_percentage": + - do: + search: + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + visit_percentage: 1.0 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_disk + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_ivf + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + element_type: byte + index: true + index_options: + type: bbq_disk + + - do: + catch: bad_request + indices.create: + index: bad_bbq_ivf + body: + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: false + index_options: + type: bbq_disk +--- +"Test index configured rescore vector": + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_ivf + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test index configured rescore vector with on-disk rescore": + - requires: + cluster_features: [ "mapper.vectors.diskbbq_on_disk_rescoring" ] + reason: Needs on_disk_rescoring feature for DiskBBQ + - skip: + features: "headers" + - do: + indices.create: + index: bbq_on_disk_rescore_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + on_disk_rescore: true + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_on_disk_rescore_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_on_disk_rescore_ivf + body: + knn: + field: vector + query_vector: [ 0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158 ] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_on_disk_rescore_ivf + body: + query: + script_score: + query: { match_all: { } } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [ 0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158 ] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - do: + indices.create: + index: bbq_rescore_0_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_ivf + body: + properties: + vector: + type: dense_vector + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_ivf + + - match: { .bbq_rescore_update_ivf.mappings.properties.vector.index_options.rescore_vector.oversample: 0 } +--- +"Test index configured rescore vector score consistency": + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_zero_ivf + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + bulk: + index: bbq_rescore_zero_ivf + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_ivf + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_ivf + body: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_disk + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_ivf + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - do: + indices.get_mapping: + index: bbq_disk + + - match: { bbq_disk.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java index efbc19b30079c..c8ceb200ea0b3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/DirectIOIT.java @@ -73,7 +73,7 @@ protected boolean useDirectIO(String name, IOContext context, OptionalLong fileL @ParametersFactory public static Iterable parameters() { - return List.of(new Object[] { "bbq_disk" }); + return List.of(new Object[] { "bbq_hnsw" }, new Object[] { "bbq_disk" }); } public DirectIOIT(String type) { @@ -113,15 +113,14 @@ private String indexVectors(boolean directIO) { indexDoc(indexName, Integer.toString(i), "fooVector", IntStream.range(0, 64).mapToDouble(d -> randomFloat()).toArray()); } refresh(); - assertBBQIndexType(indexName, type); // test assertion to ensure that the correct index type is being used + assertIndexType(indexName, type); // test assertion to ensure that the correct index type is being used return indexName; } - @SuppressWarnings("unchecked") - static void assertBBQIndexType(String indexName, String type) { + static void assertIndexType(String indexName, String type) { var response = indicesAdmin().prepareGetFieldMappings(indexName).setFields("fooVector").get(); - var map = (Map) response.fieldMappings(indexName, "fooVector").sourceAsMap().get("fooVector"); - assertThat((String) ((Map) map.get("index_options")).get("type"), is(equalTo(type))); + var map = (Map) response.fieldMappings(indexName, "fooVector").sourceAsMap().get("fooVector"); + assertThat(((Map) map.get("index_options")).get("type"), is(equalTo(type))); } @TestLogging(value = "org.elasticsearch.index.store.FsDirectoryFactory:DEBUG", reason = "to capture trace logging for direct IO") diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 172bdc67e7872..1352c40890fc4 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -189,7 +189,6 @@ private static Version parseUnchecked(String version) { public static final IndexVersion BACKPORT_UPGRADE_TO_LUCENE_10_3_1 = def(9_039_0_01, Version.LUCENE_10_3_1); public static final IndexVersion KEYWORD_MULTI_FIELDS_NOT_STORED_WHEN_IGNORED = def(9_040_0_00, Version.LUCENE_10_3_0); public static final IndexVersion UPGRADE_TO_LUCENE_10_3_1 = def(9_041_0_00, Version.LUCENE_10_3_1); - public static final IndexVersion REENABLED_TIMESTAMP_DOC_VALUES_SPARSE_INDEX = def(9_042_0_00, Version.LUCENE_10_3_1); public static final IndexVersion SKIPPERS_ENABLED_BY_DEFAULT = def(9_043_0_00, Version.LUCENE_10_3_1); public static final IndexVersion TIME_SERIES_USE_SYNTHETIC_ID = def(9_044_0_00, Version.LUCENE_10_3_1); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java index 8d25ab54d8ca1..8e558fe5f9191 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java @@ -11,7 +11,6 @@ import org.apache.lucene.util.BitUtil; -import java.nio.ByteOrder; import java.nio.ShortBuffer; public final class BFloat16 { @@ -29,13 +28,15 @@ public static short floatToBFloat16(float f) { return (short) (Float.floatToIntBits(f) >>> 16); } + public static float truncateToBFloat16(float f) { + return Float.intBitsToFloat(Float.floatToIntBits(f) & 0xffff0000); + } + public static float bFloat16ToFloat(short bf) { return Float.intBitsToFloat(bf << 16); } public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { - assert bFloats.remaining() == floats.length; - assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; for (float v : floats) { bFloats.put(floatToBFloat16(v)); } @@ -49,8 +50,6 @@ public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) { } public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) { - assert floats.length == bFloats.remaining(); - assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; for (int i = 0; i < floats.length; i++) { floats[i] = bFloat16ToFloat(bFloats.get()); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java index 99bc9a9d7bdb2..25b711da8c18f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java @@ -18,6 +18,8 @@ import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; @@ -61,9 +63,14 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); private static final Map supportedFormats = Map.of( float32VectorFormat.getName(), - float32VectorFormat + float32VectorFormat, + bfloat16VectorFormat.getName(), + bfloat16VectorFormat ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. @@ -78,14 +85,19 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat { private final int vectorPerCluster; private final int centroidsPerParentCluster; - private final boolean useDirectIO; private final DirectIOCapableFlatVectorsFormat rawVectorFormat; + private final boolean useDirectIO; public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { - this(vectorPerCluster, centroidsPerParentCluster, false); + this(vectorPerCluster, centroidsPerParentCluster, DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) { + public ES920DiskBBQVectorsFormat( + int vectorPerCluster, + int centroidsPerParentCluster, + DenseVectorFieldMapper.ElementType elementType, + boolean useDirectIO + ) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -109,8 +121,12 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu } this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; + this.rawVectorFormat = switch (elementType) { + case FLOAT -> float32VectorFormat; + case BFLOAT16 -> bfloat16VectorFormat; + default -> throw new IllegalArgumentException("Unsupported element type " + elementType); + }; this.useDirectIO = useDirectIO; - this.rawVectorFormat = float32VectorFormat; } /** Constructs a format using the given graph construction parameters and scalar quantization. */ diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java index ed224c82a5aaa..290b010fef78c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java @@ -30,6 +30,7 @@ import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; @@ -97,10 +98,10 @@ public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat private final ES93GenericFlatVectorsFormat rawFormat; public ES93BinaryQuantizedVectorsFormat() { - this(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); rawFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java index d3bded4080088..d2278bc0f30bf 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java @@ -15,20 +15,17 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat; import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Map; public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat { - // TODO: replace with DenseVectorFieldMapper.ElementType - public enum ElementType { - STANDARD, - BIT, // only supports byte[] - BFLOAT16 // only supports float[] - } + public static final FeatureFlag GENERIC_VECTOR_FORMAT = new FeatureFlag("generic_vector_format"); static final String NAME = "ES93GenericFlatVectorsFormat"; static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi"; @@ -44,7 +41,7 @@ public enum ElementType { VERSION_CURRENT ); - private static final DirectIOCapableFlatVectorsFormat standardVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( + private static final DirectIOCapableFlatVectorsFormat defaultVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); private static final DirectIOCapableFlatVectorsFormat bitVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat( @@ -61,10 +58,10 @@ public String getName() { ); private static final Map supportedFormats = Map.of( + defaultVectorFormat.getName(), + defaultVectorFormat, bitVectorFormat.getName(), bitVectorFormat, - standardVectorFormat.getName(), - standardVectorFormat, bfloat16VectorFormat.getName(), bfloat16VectorFormat ); @@ -73,13 +70,13 @@ public String getName() { private final boolean useDirectIO; public ES93GenericFlatVectorsFormat() { - this(ElementType.STANDARD, false); + this(DenseVectorFieldMapper.ElementType.FLOAT, false); } - public ES93GenericFlatVectorsFormat(ElementType elementType, boolean useDirectIO) { + public ES93GenericFlatVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); writeFormat = switch (elementType) { - case STANDARD -> standardVectorFormat; + case FLOAT, BYTE -> defaultVectorFormat; case BIT -> bitVectorFormat; case BFLOAT16 -> bfloat16VectorFormat; }; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java index 78f356f8762da..15dce38cb742c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -49,7 +50,7 @@ public ES93HnswBinaryQuantizedVectorsFormat() { * * @param useDirectIO whether to use direct IO when reading raw vectors */ - public ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO) { super(NAME); flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO); } @@ -64,7 +65,7 @@ public ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.Element public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO ) { super(NAME, maxConn, beamWidth); @@ -85,7 +86,7 @@ public ES93HnswBinaryQuantizedVectorsFormat( public ES93HnswBinaryQuantizedVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, + DenseVectorFieldMapper.ElementType elementType, boolean useDirectIO, int numMergeWorkers, ExecutorService mergeExec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java index bd191f2dfed64..bfa14632b4a4a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.concurrent.ExecutorService; @@ -32,26 +33,25 @@ public ES93HnswVectorsFormat() { flatVectorsFormat = new ES93GenericFlatVectorsFormat(); } - public ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType elementType) { super(NAME); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } - public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) { + public ES93HnswVectorsFormat(int maxConn, int beamWidth, DenseVectorFieldMapper.ElementType elementType) { super(NAME, maxConn, beamWidth); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } public ES93HnswVectorsFormat( int maxConn, int beamWidth, - ES93GenericFlatVectorsFormat.ElementType elementType, - boolean useDirectIO, + DenseVectorFieldMapper.ElementType elementType, int numMergeWorkers, ExecutorService mergeExec ) { super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec); - flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO); + flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, false); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java index 42f02d2d21366..ecae96b08c576 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es93/OffHeapBFloat16VectorValues.java @@ -21,6 +21,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; @@ -63,6 +64,10 @@ abstract class OffHeapBFloat16VectorValues extends FloatVectorValues { this.flatVectorsScorer = flatVectorsScorer; bfloatBytes = new byte[dimension * BFloat16.BYTES]; value = new float[dimension]; + + assert (this instanceof HasIndexSlice) == false + : "BFloat16 should not implement HasIndexSlice until a bfloat16 scorer is created," + + " else Lucene99MemorySegmentFlatVectorsScorer will try to access 4-byte floats here"; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index b540cd8ab4a61..b0cedb3e779e9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -11,7 +11,9 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import java.util.HashSet; import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.RESCORE_VECTOR_QUANTIZED_VECTOR_MAPPING; @@ -59,10 +61,11 @@ public class MapperFeatures implements FeatureSpecification { ); static final NodeFeature EXCLUDE_VECTORS_DOCVALUE_BUGFIX = new NodeFeature("mapper.exclude_vectors_docvalue_bugfix"); static final NodeFeature BASE64_DENSE_VECTORS = new NodeFeature("mapper.base64_dense_vectors"); + public static final NodeFeature GENERIC_VECTOR_FORMAT = new NodeFeature("mapper.vectors.generic_vector_format"); @Override public Set getTestFeatures() { - return Set.of( + var features = Set.of( RangeFieldMapper.DATE_RANGE_INDEXING_FIX, IgnoredSourceFieldMapper.DONT_EXPAND_DOTS_IN_IGNORED_SOURCE, SourceFieldMapper.REMOVE_SYNTHETIC_SOURCE_ONLY_VALIDATION, @@ -103,5 +106,10 @@ public Set getTestFeatures() { EXCLUDE_VECTORS_DOCVALUE_BUGFIX, BASE64_DENSE_VECTORS ); + if (ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled()) { + features = new HashSet<>(features); + features.add(GENERIC_VECTOR_FORMAT); + } + return features; } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java index e5561428364de..32869f8743878 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorBlockLoader.java @@ -53,7 +53,7 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { switch (fieldType.getElementType()) { - case FLOAT -> { + case FLOAT, BFLOAT16 -> { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java index f5f9e8dc88295..7e729eca933e5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/blockloader/docvalues/DenseVectorFromBinaryBlockLoader.java @@ -50,6 +50,7 @@ public AllReader reader(LeafReaderContext context) throws IOException { } return switch (elementType) { case FLOAT -> new FloatDenseVectorFromBinary(docValues, dims, indexVersion); + case BFLOAT16 -> new BFloat16DenseVectorFromBinary(docValues, dims, indexVersion); case BYTE -> new ByteDenseVectorFromBinary(docValues, dims, indexVersion); case BIT -> new BitDenseVectorFromBinary(docValues, dims, indexVersion); }; @@ -132,6 +133,29 @@ public String toString() { } } + private static class BFloat16DenseVectorFromBinary extends AbstractDenseVectorFromBinary { + BFloat16DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { + super(docValues, dims, indexVersion, new float[dims]); + } + + @Override + protected void writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder builder) { + for (float value : scratch) { + builder.appendFloat(value); + } + } + + @Override + protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) { + VectorEncoderDecoder.decodeBFloat16DenseVector(bytesRef, scratch); + } + + @Override + public String toString() { + return "BFloat16DenseVectorFromBinary.Bytes"; + } + } + private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary { ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) { this(docValues, dims, indexVersion, dims); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 27d578d62befe..2e83828944a67 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -47,6 +47,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; @@ -55,6 +56,10 @@ import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -127,15 +132,10 @@ */ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; - private static final float EPS = 1e-3f; public static final int BBQ_MIN_DIMS = 64; private static final boolean DEFAULT_HNSW_EARLY_TERMINATION = false; - public static boolean isNotUnitVector(float magnitude) { - return Math.abs(magnitude - 1.0f) > EPS; - } - /** * The heuristic to utilize when executing a filtered search against vectors indexed in an HNSW graph. */ @@ -241,7 +241,8 @@ public static class Builder extends FieldMapper.Builder { private final Parameter elementType = new Parameter<>("element_type", false, () -> ElementType.FLOAT, (n, c, o) -> { ElementType elementType = namesToElementType.get((String) o); - if (elementType == null) { + if (elementType == null + || (elementType == ElementType.BFLOAT16 && ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() == false)) { throw new MapperParsingException("invalid element_type [" + o + "]; available types are " + namesToElementType.keySet()); } return elementType; @@ -395,6 +396,7 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo return new BBQHnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + false, new RescoreVector(DEFAULT_OVERSAMPLE) ); } else if (defaultInt8Hnsw) { @@ -464,6 +466,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) { public enum ElementType { BYTE, FLOAT, + BFLOAT16, BIT; public static ElementType fromString(String name) { @@ -478,6 +481,7 @@ public String toString() { public static final Element BYTE_ELEMENT = new ByteElement(); public static final Element FLOAT_ELEMENT = new FloatElement(); + public static final Element BFLOAT16_ELEMENT = new BFloat16Element(); public static final Element BIT_ELEMENT = new BitElement(); public static final Map namesToElementType = Map.of( @@ -485,6 +489,8 @@ public String toString() { ElementType.BYTE, ElementType.FLOAT.toString(), ElementType.FLOAT, + ElementType.BFLOAT16.toString(), + ElementType.BFLOAT16, ElementType.BIT.toString(), ElementType.BIT ); @@ -494,6 +500,7 @@ public abstract static class Element { public static Element getElement(ElementType elementType) { return switch (elementType) { case FLOAT -> FLOAT_ELEMENT; + case BFLOAT16 -> BFLOAT16_ELEMENT; case BYTE -> BYTE_ELEMENT; case BIT -> BIT_ELEMENT; }; @@ -531,7 +538,7 @@ public static ElementType checkValidVector(float[] vector, ElementType... possib public abstract ElementType elementType(); - public abstract void writeValue(ByteBuffer byteBuffer, float value); + public abstract void writeValues(ByteBuffer byteBuffer, float[] values); public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException; @@ -550,6 +557,10 @@ public abstract VectorData parseKnnVector( public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes); + public boolean isUnitVector(float squaredMagnitude) { + return Math.abs(squaredMagnitude - 1.0f) < 1e-3f; + } + public void checkVectorBounds(float[] vector) { StringBuilder errors = checkVectorErrors(vector); if (errors != null) { @@ -630,8 +641,10 @@ public ElementType elementType() { } @Override - public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.put((byte) value); + public void writeValues(ByteBuffer byteBuffer, float[] values) { + for (float f : values) { + byteBuffer.put((byte) f); + } } @Override @@ -936,8 +949,9 @@ public ElementType elementType() { } @Override - public void writeValue(ByteBuffer byteBuffer, float value) { - byteBuffer.putFloat(value); + public void writeValues(ByteBuffer byteBuffer, float[] values) { + byteBuffer.asFloatBuffer().put(values); + byteBuffer.position(byteBuffer.position() + (values.length * Float.BYTES)); } @Override @@ -1003,7 +1017,7 @@ void checkVectorMagnitude(VectorSimilarity similarity, UnaryOperator { int index = 0; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { @@ -1072,7 +1085,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie vandm.squaredMagnitude ); float[] vector = vandm.vectorData.asFloatVector(); - if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(vandm.squaredMagnitude)) { + if (fieldMapper.fieldType().isNormalized() && isUnitVector(vandm.squaredMagnitude) == false) { float length = (float) Math.sqrt(vandm.squaredMagnitude); for (int i = 0; i < vector.length; i++) { vector[i] /= length; @@ -1141,20 +1154,26 @@ VectorDataAndMagnitude parseBase64EncodedVector(DocumentParserContext context, I throws IOException { // BIG_ENDIAN is the default, but just being explicit here ByteBuffer byteBuffer = ByteBuffer.wrap(Base64.getDecoder().decode(context.parser().text())).order(ByteOrder.BIG_ENDIAN); - if (byteBuffer.remaining() != dims * Float.BYTES) { + float[] decodedVector = new float[dims]; + if (byteBuffer.remaining() == dims * Float.BYTES) { + byteBuffer.asFloatBuffer().get(decodedVector); + } else if (byteBuffer.remaining() == dims * BFloat16.BYTES) { + BFloat16.bFloat16ToFloat(byteBuffer.asShortBuffer(), decodedVector); + } else { throw new ParsingException( context.parser().getTokenLocation(), "Failed to parse object: Base64 decoded vector byte length [" + byteBuffer.remaining() + "] does not match the expected length of [" + (dims * Float.BYTES) + + "] or [" + + (dims * BFloat16.BYTES) + "] for dimension count [" + dims + "]" ); } - float[] decodedVector = new float[dims]; - byteBuffer.asFloatBuffer().get(decodedVector); + dimChecker.accept(decodedVector.length, true); VectorData vectorData = VectorData.fromFloats(decodedVector); float squaredMagnitude = (float) computeSquaredMagnitude(vectorData); @@ -1196,6 +1215,36 @@ static UnaryOperator errorElementsAppender(float[] vector) { } } + private static class BFloat16Element extends FloatElement { + + @Override + public ElementType elementType() { + return ElementType.BFLOAT16; + } + + @Override + public void writeValues(ByteBuffer byteBuffer, float[] values) { + BFloat16.floatToBFloat16(values, byteBuffer.asShortBuffer()); + byteBuffer.position(byteBuffer.position() + (values.length * BFloat16.BYTES)); + } + + @Override + public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { + b.value(BFloat16.bFloat16ToFloat(byteBuffer.getShort())); + } + + @Override + public boolean isUnitVector(float squaredMagnitude) { + // bfloat16 needs to be more lenient + return Math.abs(squaredMagnitude - 1.0f) < 0.02f; + } + + @Override + public int getNumBytes(int dimensions) { + return dimensions * BFloat16.BYTES; + } + } + private static class BitElement extends ByteElement { @Override @@ -1267,7 +1316,7 @@ public enum VectorSimilarity { @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> 1f / (1f + similarity * similarity); + case BYTE, FLOAT, BFLOAT16 -> 1f / (1f + similarity * similarity); case BIT -> (dim - similarity) / dim; }; } @@ -1282,14 +1331,14 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { assert elementType != ElementType.BIT; return switch (elementType) { - case BYTE, FLOAT -> (1 + similarity) / 2f; + case BYTE, FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @Override public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType) { - return indexVersion.onOrAfter(NORMALIZE_COSINE) && ElementType.FLOAT.equals(elementType) + return indexVersion.onOrAfter(NORMALIZE_COSINE) && (elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16) ? VectorSimilarityFunction.DOT_PRODUCT : VectorSimilarityFunction.COSINE; } @@ -1299,7 +1348,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15)); - case FLOAT -> (1 + similarity) / 2f; + case FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1313,7 +1362,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; + case BYTE, FLOAT, BFLOAT16 -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1428,16 +1477,11 @@ public enum VectorIndexType { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - if (mNode == null) { - mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; - } - if (efConstructionNode == null) { - efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; - } - int m = XContentMapValues.nodeIntegerValue(mNode); - int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); + int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); + return new HnswIndexOptions(m, efConstruction); } @@ -1596,14 +1640,13 @@ public boolean supportsDimension(int dims) { public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - if (mNode == null) { - mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; - } - if (efConstructionNode == null) { - efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; - } - int m = XContentMapValues.nodeIntegerValue(mNode); - int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + Object onDiskRescoreNode = ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() + ? indexOptionsMap.remove("on_disk_rescore") + : false; + + int m = XContentMapValues.nodeIntegerValue(mNode, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN); + int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + boolean onDiskRescore = XContentMapValues.nodeBooleanValue(onDiskRescoreNode, false); RescoreVector rescoreVector = null; if (hasRescoreIndexVersion(indexVersion)) { @@ -1614,12 +1657,12 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map createExactKnnByteQuery(queryVector.asByteVector()); - case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + case FLOAT, BFLOAT16 -> createExactKnnFloatQuery(queryVector.asFloatVector()); case BIT -> createExactKnnBitQuery(queryVector.asByteVector()); }; if (vectorSimilarity != null) { @@ -2554,7 +2615,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); element.checkVectorMagnitude(similarity, FloatElement.errorElementsAppender(queryVector), squaredMagnitude); - if (isNormalized() && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && element.isUnitVector(squaredMagnitude) == false) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2598,7 +2659,7 @@ public Query createKnnQuery( knnSearchStrategy, hnswEarlyTermination ); - case FLOAT -> createKnnFloatQuery( + case FLOAT, BFLOAT16 -> createKnnFloatQuery( queryVector.asFloatVector(), k, numCands, @@ -2750,7 +2811,7 @@ private Query createKnnFloatQuery( if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); element.checkVectorMagnitude(similarity, FloatElement.errorElementsAppender(queryVector), squaredMagnitude); - if (isNormalized() && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && element.isUnitVector(squaredMagnitude) == false) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -3039,7 +3100,7 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t checkDimensionExceeded(i, context); } }, fieldType().similarity); - vectorData.addToBuffer(byteBuffer); + vectorData.addToBuffer(element, byteBuffer); if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) { // encode vector magnitude at the end double dotProduct = element.computeSquaredMagnitude(vectorData); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index 120682d185535..8363f32f8f27f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BFloat16BinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField; @@ -69,7 +70,8 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { if (indexed) { return switch (elementType) { case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); - case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); + // bfloat16 is hidden by the FloatVectorValues implementation + case FLOAT, BFLOAT16 -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; } else { @@ -77,6 +79,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case BFLOAT16 -> new BFloat16BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } @@ -85,105 +88,125 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { } } - @Override - public FormattedDocValues getFormattedValues(DocValueFormat format) { - int dims = elementType == ElementType.BIT ? this.dims / Byte.SIZE : this.dims; - return switch (elementType) { - case BYTE, BIT -> new FormattedDocValues() { - private byte[] vector = new byte[dims]; - private ByteVectorValues byteVectorValues; // use when indexed - private KnnVectorValues.DocIndexIterator iterator; // use when indexed - private BinaryDocValues binary; // use when not indexed - { - try { - if (indexed) { - byteVectorValues = reader.getByteVectorValues(field); - iterator = (byteVectorValues == null) ? null : byteVectorValues.iterator(); - } else { - binary = DocValues.getBinary(reader, field); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } - + private class ByteDocValues implements FormattedDocValues { + private final int dims; + private byte[] vector; + private ByteVectorValues byteVectorValues; // use when indexed + private KnnVectorValues.DocIndexIterator iterator; // use when indexed + private BinaryDocValues binary; // use when not indexed + + ByteDocValues(int dims) { + this.dims = dims; + this.vector = new byte[dims]; + try { + if (indexed) { + byteVectorValues = reader.getByteVectorValues(field); + iterator = (byteVectorValues == null) ? null : byteVectorValues.iterator(); + } else { + binary = DocValues.getBinary(reader, field); } + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } - @Override - public boolean advanceExact(int docId) throws IOException { - if (indexed) { - if (iteratorAdvanceExact(iterator, docId) == false) { - return false; - } - vector = byteVectorValues.vectorValue(iterator.index()); - } else { - if (binary == null || binary.advanceExact(docId) == false) { - return false; - } - BytesRef ref = binary.binaryValue(); - System.arraycopy(ref.bytes, ref.offset, vector, 0, dims); - } - return true; - } + } - @Override - public int docValueCount() { - return 1; + @Override + public boolean advanceExact(int docId) throws IOException { + if (indexed) { + if (iteratorAdvanceExact(iterator, docId) == false) { + return false; } - - public Object nextValue() { - Byte[] vectorValue = new Byte[dims]; - for (int i = 0; i < dims; i++) { - vectorValue[i] = vector[i]; - } - return vectorValue; + vector = byteVectorValues.vectorValue(iterator.index()); + } else { + if (binary == null || binary.advanceExact(docId) == false) { + return false; } - }; - case FLOAT -> new FormattedDocValues() { - float[] vector = new float[dims]; - private FloatVectorValues floatVectorValues; // use when indexed - private KnnVectorValues.DocIndexIterator iterator; // use when indexed - private BinaryDocValues binary; // use when not indexed - { - try { - if (indexed) { - floatVectorValues = reader.getFloatVectorValues(field); - iterator = (floatVectorValues == null) ? null : floatVectorValues.iterator(); - } else { - binary = DocValues.getBinary(reader, field); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values", e); - } + BytesRef ref = binary.binaryValue(); + System.arraycopy(ref.bytes, ref.offset, vector, 0, dims); + } + return true; + } - } + @Override + public int docValueCount() { + return 1; + } - @Override - public boolean advanceExact(int docId) throws IOException { - if (indexed) { - if (iteratorAdvanceExact(iterator, docId) == false) { - return false; - } - vector = floatVectorValues.vectorValue(iterator.index()); - } else { - if (binary == null || binary.advanceExact(docId) == false) { - return false; - } - BytesRef ref = binary.binaryValue(); - VectorEncoderDecoder.decodeDenseVector(indexVersion, ref, vector); - } - return true; - } + public Object nextValue() { + Byte[] vectorValue = new Byte[dims]; + for (int i = 0; i < dims; i++) { + vectorValue[i] = vector[i]; + } + return vectorValue; + } + } - @Override - public int docValueCount() { - return 1; + private class FloatDocValues implements FormattedDocValues { + private float[] vector = new float[dims]; + private FloatVectorValues floatVectorValues; // use when indexed + private KnnVectorValues.DocIndexIterator iterator; // use when indexed + private BinaryDocValues binary; // use when not indexed + + FloatDocValues() { + try { + if (indexed) { + floatVectorValues = reader.getFloatVectorValues(field); + iterator = (floatVectorValues == null) ? null : floatVectorValues.iterator(); + } else { + binary = DocValues.getBinary(reader, field); } + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } - @Override - public Object nextValue() { - return Arrays.copyOf(vector, vector.length); + @Override + public boolean advanceExact(int docId) throws IOException { + if (indexed) { + if (iteratorAdvanceExact(iterator, docId) == false) { + return false; + } + vector = floatVectorValues.vectorValue(iterator.index()); + } else { + if (binary == null || binary.advanceExact(docId) == false) { + return false; } - }; + BytesRef ref = binary.binaryValue(); + decodeDenseVector(indexVersion, ref, vector); + } + return true; + } + + void decodeDenseVector(IndexVersion indexVersion, BytesRef ref, float[] vector) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, ref, vector); + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public Object nextValue() { + return Arrays.copyOf(vector, vector.length); + } + } + + private class BFloat16DocValues extends FloatDocValues { + @Override + void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, float[] vector) { + VectorEncoderDecoder.decodeBFloat16DenseVector(vectorBR, vector); + } + } + + @Override + public FormattedDocValues getFormattedValues(DocValueFormat format) { + return switch (elementType) { + case BYTE -> new ByteDocValues(dims); + case BIT -> new ByteDocValues(dims / Byte.SIZE); + case FLOAT -> new FloatDocValues(); + case BFLOAT16 -> new BFloat16DocValues(); }; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 9dec4a4f2dd61..f60e7139cd3d9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -12,10 +12,12 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.codec.vectors.BFloat16; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.ShortBuffer; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION; @@ -84,6 +86,14 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB } } + public static void decodeBFloat16DenseVector(BytesRef vectorBR, float[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); + } + ShortBuffer sb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(); + BFloat16.bFloat16ToFloat(sb, vector); + } + /** * Decodes a BytesRef into the provided array of bytes * @param vectorBR - dense vector encoded in BytesRef diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index 13be089753fb7..6932b5c718dc6 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -210,7 +210,7 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL1Norm(scoreScript, field, (List) queryVector); } @@ -252,7 +252,7 @@ public static final class Hamming { @SuppressWarnings("unchecked") public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); - if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) { + if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT || field.getElementType() == ElementType.BFLOAT16) { throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors"); } if (queryVector instanceof List) { @@ -320,7 +320,7 @@ public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL2Norm(scoreScript, field, (List) queryVector); } @@ -478,7 +478,7 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatDotProduct(scoreScript, field, (List) queryVector); } @@ -547,7 +547,7 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatCosineSimilarity(scoreScript, field, (List) queryVector); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..4805b77250c00 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16BinaryDenseVectorDocValuesField.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; + +public class BFloat16BinaryDenseVectorDocValuesField extends BinaryDenseVectorDocValuesField { + public BFloat16BinaryDenseVectorDocValuesField( + BinaryDocValues input, + String name, + DenseVectorFieldMapper.ElementType elementType, + int dims, + IndexVersion indexVersion + ) { + super(input, name, elementType, dims, indexVersion); + } + + @Override + void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorBR, float[] vector) { + VectorEncoderDecoder.decodeBFloat16DenseVector(vectorBR, vector); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java index 0bb9d2a3a0b0d..376d04cbeb957 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java @@ -86,8 +86,12 @@ public DenseVector getInternal() { private void decodeVectorIfNecessary() { if (decoded == false && value != null) { - VectorEncoderDecoder.decodeDenseVector(indexVersion, value, vectorValue); + decodeDenseVector(indexVersion, value, vectorValue); decoded = true; } } + + void decodeDenseVector(IndexVersion indexVersion, BytesRef value, float[] vector) { + VectorEncoderDecoder.decodeDenseVector(indexVersion, value, vectorValue); + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java index a128130a71cfc..84cd638cceef2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java @@ -75,11 +75,9 @@ public float[] asFloatVector() { return vec; } - public void addToBuffer(ByteBuffer byteBuffer) { + public void addToBuffer(DenseVectorFieldMapper.Element element, ByteBuffer byteBuffer) { if (floatVector != null) { - for (float val : floatVector) { - byteBuffer.putFloat(val); - } + element.writeValues(byteBuffer, floatVector); } else { byteBuffer.put(byteVector); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..509d96b5159e1 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQBFloat16VectorsFormatTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.equalTo; + +public class ES920DiskBBQBFloat16VectorsFormatTests extends BaseBFloat16KnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + if (rarely()) { + format = new ES920DiskBBQVectorsFormat( + random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), + random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), + DenseVectorFieldMapper.ElementType.FLOAT, + random().nextBoolean() + ); + } else { + // run with low numbers to force many clusters with parents + format = new ES920DiskBBQVectorsFormat( + random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), + DenseVectorFieldMapper.ElementType.FLOAT, + random().nextBoolean() + ); + } + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return RandomPicks.randomFrom( + random(), + List.of( + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT + ) + ); + } + + @Override + public void testSearchWithVisitedLimit() { + throw new AssumptionViolatedException("ivf doesn't enforce visitation limit"); + } + + @Override + public void testAdvance() throws Exception { + // TODO re-enable with hierarchical IVF, clustering as it is is flaky + } + + @Override + protected void assertOffHeapByteSize(LeafReader r, String fieldName) throws IOException { + var fieldInfo = r.getFieldInfos().fieldInfo(fieldName); + + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader(fieldName); + } + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + long totalByteSize = offHeap.values().stream().mapToLong(Long::longValue).sum(); + // IVF doesn't report stats at the moment + assertThat(offHeap, anEmptyMap()); + assertThat(totalByteSize, equalTo(0L)); + } else { + throw new AssertionError("unexpected:" + r.getClass()); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java index 29e6c59d995be..9b452fe4fb9cb 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormatTests.java @@ -11,7 +11,6 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks; import org.apache.lucene.codecs.Codec; -import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -36,11 +35,12 @@ import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.junit.AssumptionViolatedException; import org.junit.Before; import java.io.IOException; import java.util.List; -import java.util.Locale; import java.util.concurrent.atomic.AtomicBoolean; import static java.lang.String.format; @@ -51,8 +51,7 @@ import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.oneOf; +import static org.hamcrest.Matchers.hasToString; public class ES920DiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -70,6 +69,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } else { @@ -77,6 +77,7 @@ public void setUp() throws Exception { format = new ES920DiskBBQVectorsFormat( random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } @@ -102,7 +103,7 @@ protected VectorEncoding randomVectorEncoding() { @Override public void testSearchWithVisitedLimit() { - // ivf doesn't enforce visitation limit + throw new AssumptionViolatedException("ivf doesn't enforce visitation limit"); } @Override @@ -135,17 +136,9 @@ public void testAdvance() throws Exception { } public void testToString() { - FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new ES920DiskBBQVectorsFormat(128, 4); - } - }; - String expectedPattern = "ES920DiskBBQVectorsFormat(vectorPerCluster=128)"; + KnnVectorsFormat format = new ES920DiskBBQVectorsFormat(128, 4); - var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); - var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); - assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + assertThat(format, hasToString("ES920DiskBBQVectorsFormat(vectorPerCluster=128)")); } public void testLimits() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java index 6cb1f7e61af5f..fb287e39b37d1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedBFloat16VectorsFormatTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -75,7 +76,7 @@ public class ES93BinaryQuantizedBFloat16VectorsFormatTests extends BaseBFloat16K @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean()); super.setUp(); } @@ -196,7 +197,7 @@ public void testToString() { var defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer"); var memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, false); + KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, false); assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java index b9d53944ac654..39439bac02a8c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java @@ -55,6 +55,7 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -84,7 +85,7 @@ public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatT @Override public void setUp() throws Exception { - format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean()); super.setUp(); } @@ -201,7 +202,7 @@ public void testToString() { var defaultScorer = expected.replaceAll("\\{}", "DefaultFlatVectorScorer"); var memSegScorer = expected.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer"); - KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false); + KnnVectorsFormat format = new ES93BinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, false); assertThat(format, hasToString(oneOf(defaultScorer, memSegScorer))); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java index e6c7ab2f256d4..0220ba831e75d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java @@ -13,6 +13,7 @@ import org.apache.lucene.store.Directory; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -30,24 +31,17 @@ public class ES93HnswBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsF @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.BFLOAT16); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat( - maxConn, - beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, - random().nextBoolean(), - numMergeWorkers, - service - ); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.BFLOAT16, numMergeWorkers, service); } public void testToString() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java index 49dd1ba7c64e4..8d09eafb81ca2 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -16,6 +16,7 @@ import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -34,7 +35,7 @@ public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends BaseHnswB @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, random().nextBoolean()); + return new ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean()); } @Override @@ -42,7 +43,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean() ); } @@ -52,7 +53,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.BFLOAT16, + DenseVectorFieldMapper.ElementType.BFLOAT16, random().nextBoolean(), numMergeWorkers, service diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java index 8afc974cfca7d..75235d894bb74 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java @@ -25,6 +25,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -43,7 +44,7 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseHnswVectorsFo @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswBinaryQuantizedVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean()); } @Override @@ -51,7 +52,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean() ); } @@ -61,7 +62,7 @@ protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMerge return new ES93HnswBinaryQuantizedVectorsFormat( maxConn, beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, + DenseVectorFieldMapper.ElementType.FLOAT, random().nextBoolean(), numMergeWorkers, service diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java index f5f15bb1f06df..b54db35d77273 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBitVectorsFormatTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.index.codec.vectors.BaseKnnBitVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.Before; import java.io.IOException; @@ -36,9 +37,7 @@ public class ES93HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCa @Override protected Codec getCodec() { - return TestUtil.alwaysKnnVectorsFormat( - new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.BIT, random().nextBoolean()) - ); + return TestUtil.alwaysKnnVectorsFormat(new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.BIT)); } @Before diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java index 5fa507c23d756..84057c7709063 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.store.Directory; import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import java.io.IOException; import java.util.Locale; @@ -29,24 +30,17 @@ public class ES93HnswVectorsFormatTests extends BaseHnswVectorsFormatTestCase { @Override protected KnnVectorsFormat createFormat() { - return new ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswVectorsFormat(DenseVectorFieldMapper.ElementType.FLOAT); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) { - return new ES93HnswVectorsFormat(maxConn, beamWidth, ES93GenericFlatVectorsFormat.ElementType.STANDARD, random().nextBoolean()); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT); } @Override protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) { - return new ES93HnswVectorsFormat( - maxConn, - beamWidth, - ES93GenericFlatVectorsFormat.ElementType.STANDARD, - random().nextBoolean(), - numMergeWorkers, - service - ); + return new ES93HnswVectorsFormat(maxConn, beamWidth, DenseVectorFieldMapper.ElementType.FLOAT, numMergeWorkers, service); } public void testToString() { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java index 9478508da88d0..c3e3ed0c93c98 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java @@ -22,14 +22,14 @@ private DenseVectorFieldMapperTestUtils() {} public static List getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) { return switch (elementType) { - case FLOAT, BYTE -> List.of(SimilarityMeasure.values()); + case FLOAT, BFLOAT16, BYTE -> List.of(SimilarityMeasure.values()); case BIT -> List.of(SimilarityMeasure.L2_NORM); }; } public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; @@ -43,7 +43,7 @@ public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType } return switch (elementType) { - case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); + case FLOAT, BFLOAT16, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); case BIT -> { if (max < 8) { throw new IllegalArgumentException("max must be at least 8 for bit vectors"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index e2a30e7df9f31..f919397170d42 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.index.codec.CodecService; import org.elasticsearch.index.codec.LegacyPerFieldMapperCodec; import org.elasticsearch.index.codec.PerFieldMapperCodec; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; @@ -76,6 +77,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -88,11 +90,14 @@ public class DenseVectorFieldMapperTests extends SyntheticVectorsMapperTestCase private final int dims; public DenseVectorFieldMapperTests() { - this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); this.indexed = usually(); this.indexOptionsSet = this.indexed && randomBoolean(); int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4; - int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1; + int randomMultiplier = switch (elementType) { + case FLOAT, BFLOAT16 -> randomIntBetween(1, 64); + case BYTE, BIT -> 1; + }; this.dims = baseDims * randomMultiplier; } @@ -153,22 +158,31 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I @Override protected Object getSampleValueForDocument(boolean binaryFormat) { if (binaryFormat) { - final byte[] toEncode; - if (elementType == ElementType.FLOAT) { - float[] array = randomNormalizedVector(this.dims); - final ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * array.length); - buffer.asFloatBuffer().put(array); - toEncode = buffer.array(); - } else { - toEncode = elementType == ElementType.BIT - ? randomByteArrayOfLength(this.dims / Byte.SIZE) - : randomByteArrayOfLength(this.dims); - } + byte[] toEncode = switch (elementType) { + case FLOAT -> { + float[] array = randomNormalizedVector(this.dims); + final ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * array.length); + buffer.asFloatBuffer().put(array); + yield buffer.array(); + } + case BFLOAT16 -> { + float[] array = randomNormalizedVector(this.dims); + final ByteBuffer buffer = ByteBuffer.allocate(BFloat16.BYTES * array.length); + BFloat16.floatToBFloat16(array, buffer.asShortBuffer()); + yield buffer.array(); + } + case BYTE -> randomByteArrayOfLength(dims); + case BIT -> randomByteArrayOfLength(this.dims / Byte.SIZE); + }; return Base64.getEncoder().encodeToString(toEncode); + } else { + return switch (elementType) { + case FLOAT -> convertToList(randomNormalizedVector(this.dims)); + case BFLOAT16 -> convertToBFloat16List(randomNormalizedVector(this.dims)); + case BYTE -> convertToList(randomByteArrayOfLength(dims)); + case BIT -> convertToList(randomByteArrayOfLength(this.dims / Byte.SIZE)); + }; } - return elementType == ElementType.FLOAT - ? convertToList(randomNormalizedVector(this.dims)) - : convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? this.dims / Byte.SIZE : dims)); } @Override @@ -184,6 +198,14 @@ public static List convertToList(float[] vector) { return list; } + public static List convertToBFloat16List(float[] vector) { + List list = new ArrayList<>(vector.length); + for (float v : vector) { + list.add(BFloat16.truncateToBFloat16(v)); + } + return list; + } + public static List convertToList(byte[] vector) { List list = new ArrayList<>(vector.length); for (byte v : vector) { @@ -1590,7 +1612,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) ft; return switch (vectorFieldType.getElementType()) { case BYTE -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions()); - case FLOAT -> randomNormalizedVector(vectorFieldType.getVectorDimensions()); + case FLOAT, BFLOAT16 -> randomNormalizedVector(vectorFieldType.getVectorDimensions()); case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8); }; } @@ -1904,12 +1926,12 @@ public void testKnnVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=" + String expectedString = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=" + (setM ? m : DEFAULT_MAX_CONN) + ", beamWidth=" + (setEfConstruction ? efConstruction : DEFAULT_BEAM_WIDTH) - + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())" - + ")"; + + ", flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat" + + ", format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=DefaultFlatVectorScorer())))"; assertEquals(expectedString, knnVectorsFormat.toString()); } @@ -2039,14 +2061,15 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedString = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction - + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat(" - + "name=ES818BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; - assertEquals(expectedString, knnVectorsFormat.toString()); + + ", flatVectorFormat=ES93BinaryQuantizedVectorsFormat(" + + "name=ES93BinaryQuantizedVectorsFormat, " + + "rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," + + " format=Lucene99FlatVectorsFormat"; + assertThat(knnVectorsFormat, hasToString(startsWith(expectedString))); } public void testKnnBBQIVFVectorsFormat() throws IOException { @@ -2176,24 +2199,23 @@ protected boolean supportsEmptyInputArray() { private static class DenseVectorSyntheticSourceSupport implements SyntheticSourceSupport { private final int dims = between(5, 1000); - private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); + private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT); private final boolean indexed = randomBoolean(); private final boolean indexOptionsSet = indexed && randomBoolean(); @Override public SyntheticSourceExample example(int maxValues) throws IOException { Object value = switch (elementType) { - case BYTE, BIT: - yield randomList(dims, dims, ESTestCase::randomByte); - case FLOAT: - yield randomList(dims, dims, ESTestCase::randomFloat); + case BYTE, BIT -> randomList(dims, dims, ESTestCase::randomByte); + case FLOAT -> randomList(dims, dims, ESTestCase::randomFloat); + case BFLOAT16 -> randomList(dims, dims, () -> BFloat16.truncateToBFloat16(randomFloat())); }; return new SyntheticSourceExample(value, value, this::mapping); } private void mapping(XContentBuilder b) throws IOException { b.field("type", "dense_vector"); - if (elementType == ElementType.BYTE || elementType == ElementType.BIT || randomBoolean()) { + if (elementType != ElementType.FLOAT || randomBoolean()) { b.field("element_type", elementType.toString()); } b.field("dims", elementType == ElementType.BIT ? dims * Byte.SIZE : dims); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 24db52cb10c0c..583472a6ff076 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER; import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BFLOAT16; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; @@ -133,6 +134,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsA new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomBoolean(), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQFlatIndexOptions( @@ -173,7 +175,12 @@ private DenseVectorFieldMapper.DenseVectorIndexOptions randomIndexOptionsHnswQua randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), rescoreVector ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), rescoreVector) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomBoolean(), + rescoreVector + ) ); } @@ -782,6 +789,7 @@ public void testRescoreOversampleQueryOverrides() { public void testFilterSearchThreshold() { List>> cases = List.of( Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), + Tuple.tuple(BFLOAT16, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) ); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 04731f193fb14..65644e0dc7f50 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -247,7 +247,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que approxFilterQuery, expectedStrategy ); - case FLOAT -> new ESKnnFloatVectorQuery( + case FLOAT, BFLOAT16 -> new ESKnnFloatVectorQuery( VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, @@ -268,7 +268,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (filterQuery != null) { yield new BooleanQuery.Builder().add( new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java index 5871c1d8f18c5..27728aaa550b1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java @@ -101,7 +101,7 @@ protected void doAssertLuceneQuery(ExactKnnQueryBuilder queryBuilder, Query quer float[] expected = Arrays.copyOf(queryBuilder.getQuery().asFloatVector(), queryBuilder.getQuery().asFloatVector().length); float magnitude = VectorUtil.dotProduct(expected, expected); if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE) - && DenseVectorFieldMapper.isNotUnitVector(magnitude)) { + && DenseVectorFieldMapper.FLOAT_ELEMENT.isUnitVector(magnitude) == false) { VectorUtil.l2normalize(expected); assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); } else { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index ebe62fa0cba98..38b3631862690 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -43,10 +43,10 @@ import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.test.ESTestCase; @@ -283,10 +283,20 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th IndexWriterConfig iwc = new IndexWriterConfig(); // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore KnnVectorsFormat format = randomFrom( - new ES920DiskBBQVectorsFormat(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, randomBoolean()), - new ES818BinaryQuantizedVectorsFormat(), - new ES818HnswBinaryQuantizedVectorsFormat(), - new ES93HnswBinaryQuantizedVectorsFormat(), + new ES920DiskBBQVectorsFormat( + DEFAULT_VECTORS_PER_CLUSTER, + DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), + randomBoolean() + ), + new ES93BinaryQuantizedVectorsFormat( + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), + randomBoolean() + ), + new ES93HnswBinaryQuantizedVectorsFormat( + randomFrom(DenseVectorFieldMapper.ElementType.FLOAT, DenseVectorFieldMapper.ElementType.BFLOAT16), + randomBoolean() + ), new ES813Int8FlatVectorFormat(), new ES813Int8FlatVectorFormat(), new ES814HnswScalarQuantizedVectorsFormat() diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index a1a0486aecfc8..527d961197e8a 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -27,7 +27,8 @@ public enum FeatureFlag { null ), RANDOM_SAMPLING("es.random_sampling_feature_flag_enabled=true", Version.fromString("9.2.0"), null), - INFERENCE_API_CCM("es.inference_api_ccm_feature_flag_enabled=true", Version.fromString("9.3.0"), null); + INFERENCE_API_CCM("es.inference_api_ccm_feature_flag_enabled=true", Version.fromString("9.3.0"), null), + GENERIC_VECTOR_FORMAT("es.generic_vector_format_feature_flag_enabled=true", Version.fromString("9.3.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index d852ddeaf1d8b..ff720fed2210c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.codec.vectors.BFloat16; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.script.field.vectors.DenseVector; @@ -250,6 +251,15 @@ public void setup() throws IOException { buffer.asFloatBuffer().put(array); yield Base64.getEncoder().encodeToString(buffer.array()); } + case BFLOAT16 -> { + float[] array = new float[numDims]; + for (int k = 0; k < numDims; k++) { + array[k] = vector.get(k).floatValue(); + } + final ByteBuffer buffer = ByteBuffer.allocate(BFloat16.BYTES * numDims); + BFloat16.floatToBFloat16(array, buffer.asShortBuffer()); + yield Base64.getEncoder().encodeToString(buffer.array()); + } case BYTE, BIT -> { byte[] array = new byte[numDims]; for (int k = 0; k < numDims; k++) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index dd63facfe7dda..6415467f536d1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -270,7 +270,7 @@ private static List generateEmbedding( // Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 390c32bb773f8..36465095effbf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1455,7 +1455,7 @@ public static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDense int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); } static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCreated, MinimalServiceSettings modelSettings) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 5fcb23e1fa23b..a0cb0ea1016be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -333,7 +333,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { /** * Returns the default similarity measure for the embedding type. * Cohere embeddings are expected to be normalized to unit vectors, but due to floating point precision issues, - * our check ({@link DenseVectorFieldMapper#isNotUnitVector(float)}) often fails. + * our check ({@link DenseVectorFieldMapper.Element#isUnitVector(float)}) often fails. * Therefore, we use cosine similarity to ensure compatibility. * * @return The default similarity measure. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index dbbf82f5703b9..78647304dcfa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -98,6 +98,7 @@ public DenseEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoin case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null); case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null); + case BFLOAT16 -> throw new UnsupportedOperationException("Bfloat16 not supported"); }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 175c3e90f798d..9bc1736a85c7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,7 +269,7 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 2c9c404fcb275..6beaa18cd9222 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -1861,7 +1861,7 @@ private static DenseVectorFieldMapper.DenseVectorIndexOptions defaultBbqHnswDens int m = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; int efConstruction = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; DenseVectorFieldMapper.RescoreVector rescoreVector = new DenseVectorFieldMapper.RescoreVector(DEFAULT_RESCORE_OVERSAMPLE); - return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new DenseVectorFieldMapper.BBQHnswIndexOptions(m, efConstruction, false, rescoreVector); } private static SemanticTextIndexOptions defaultBbqHnswSemanticTextIndexOptions() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 04f9dfbc4bebf..5eb64696b5917 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -192,6 +192,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model mo case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); + case BFLOAT16 -> throw new AssertionError(); }; default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; @@ -222,7 +223,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == DenseVectorFieldMapper.ElementType.FLOAT; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16; List chunks = new ArrayList<>(); for (String input : inputs) { @@ -272,7 +273,7 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); @@ -417,6 +418,7 @@ public static ChunkedInference toChunkedResult( EmbeddingResults.Embedding embedding = switch (elementType) { case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values)); + case BFLOAT16 -> throw new AssertionError(); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index f99e8ce562b42..5eb6d41be1bd0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -54,7 +54,12 @@ public static TestModel createRandomInstance(TaskType taskType, List excludedSimilarities, int maxDimensions) { if (taskType == TaskType.TEXT_EMBEDDING) { - var elementType = randomFrom(DenseVectorFieldMapper.ElementType.values()); + // TODO: bfloat16 + var elementType = randomFrom( + DenseVectorFieldMapper.ElementType.FLOAT, + DenseVectorFieldMapper.ElementType.BYTE, + DenseVectorFieldMapper.ElementType.BIT + ); var dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, maxDimensions); List supportedSimilarities = new ArrayList<>( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index cb5d1d40e2c2a..a475b8a1f4342 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -277,7 +277,7 @@ private void assertTextEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); Class expectedKnnQueryClass = switch (denseVectorElementType) { - case FLOAT -> KnnFloatVectorQuery.class; + case FLOAT, BFLOAT16 -> KnnFloatVectorQuery.class; case BYTE, BIT -> KnnByteVectorQuery.class; }; assertThat(innerQuery, instanceOf(expectedKnnQueryClass)); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index b858b935c1483..f56b974e0a95a 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -128,6 +128,7 @@ public Object nextValue() { return vectors; } }; + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } @@ -140,6 +141,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java index adb925757b6ca..7f2bd456db4da 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java @@ -77,6 +77,9 @@ public static class Builder extends FieldMapper.Builder { "invalid element_type [" + o + "]; available types are " + namesToElementType.keySet() ); } + if (elementType == ElementType.BFLOAT16) { + throw new MapperParsingException("Rank vectors does not support bfloat16"); + } return elementType; }, m -> toType(m).fieldType().element.elementType(), @@ -342,7 +345,7 @@ public void parse(DocumentParserContext context) throws IOException { ByteBuffer buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN); ByteBuffer magnitudeBuffer = ByteBuffer.allocate(vectors.size() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (VectorData vector : vectors) { - vector.addToBuffer(buffer); + vector.addToBuffer(element, buffer); magnitudeBuffer.putFloat((float) Math.sqrt(element.computeSquaredMagnitude(vector))); } String vectorFieldName = fieldType().name(); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java index 1c533e9ec88cd..bd1c06f7c1dd1 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java @@ -357,6 +357,7 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel } throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName()); } + case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16"); }; } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 32863cbd96d25..87a7e302a7fba 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -439,6 +439,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } + case BFLOAT16 -> throw new AssertionError(); }; } @@ -477,10 +478,9 @@ private static class DenseVectorSyntheticSourceSupport implements SyntheticSourc @Override public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { - case BYTE, BIT: - yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); - case FLOAT: - yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); + case BYTE, BIT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); + case FLOAT -> randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); + case BFLOAT16 -> throw new AssertionError(); }; return new SyntheticSourceExample(value, value, this::mapping); }