|
37 | 37 | import org.elasticsearch.common.lucene.search.Queries; |
38 | 38 | import org.elasticsearch.common.settings.Settings; |
39 | 39 | import org.elasticsearch.core.CheckedConsumer; |
| 40 | +import org.elasticsearch.core.Nullable; |
40 | 41 | import org.elasticsearch.index.IndexVersion; |
41 | 42 | import org.elasticsearch.index.IndexVersions; |
42 | 43 | import org.elasticsearch.index.mapper.DocumentMapper; |
|
103 | 104 | import java.util.Map; |
104 | 105 | import java.util.Set; |
105 | 106 | import java.util.function.BiConsumer; |
106 | | -import java.util.function.Function; |
107 | 107 | import java.util.function.Supplier; |
108 | 108 |
|
109 | 109 | import static org.elasticsearch.index.IndexVersions.NEW_SPARSE_VECTOR; |
@@ -933,26 +933,30 @@ private static void assertEmbeddingsFieldMapperMatchesModel(MapperService mapper |
933 | 933 | Mapper embeddingsFieldMapper = mapperService.mappingLookup().getMapper(getEmbeddingsFieldName(fieldName)); |
934 | 934 | switch (model.getTaskType()) { |
935 | 935 | case SPARSE_EMBEDDING -> assertThat(embeddingsFieldMapper, is(instanceOf(SparseVectorFieldMapper.class))); |
936 | | - case TEXT_EMBEDDING -> assertTextEmbeddingsFieldMapperMatchesModel(embeddingsFieldMapper, model); |
| 936 | + case TEXT_EMBEDDING -> { |
| 937 | + SemanticTextFieldMapper semanticFieldMapper = getSemanticFieldMapper(mapperService, fieldName); |
| 938 | + DenseVectorFieldMapper.ElementType expectedElementType = getExpectedElementType( |
| 939 | + mapperService.getIndexSettings().getIndexVersionCreated(), |
| 940 | + model.getServiceSettings().elementType(), |
| 941 | + semanticFieldMapper.fieldType().getIndexOptions() |
| 942 | + ); |
| 943 | + assertTextEmbeddingsFieldMapperMatchesModel(embeddingsFieldMapper, model, expectedElementType); |
| 944 | + } |
937 | 945 | default -> throw new AssertionError("Unexpected task type [" + model.getTaskType() + "]"); |
938 | 946 | } |
939 | 947 | } |
940 | 948 |
|
941 | | - private static void assertTextEmbeddingsFieldMapperMatchesModel(Mapper embeddingsFieldMapper, Model model) { |
942 | | - Function<SimilarityMeasure, DenseVectorFieldMapper.VectorSimilarity> convertToVectorSimilarity = s -> switch (s) { |
943 | | - case COSINE -> DenseVectorFieldMapper.VectorSimilarity.COSINE; |
944 | | - case DOT_PRODUCT -> DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT; |
945 | | - case L2_NORM -> DenseVectorFieldMapper.VectorSimilarity.L2_NORM; |
946 | | - }; |
| 949 | + private static void assertTextEmbeddingsFieldMapperMatchesModel( |
| 950 | + Mapper embeddingsFieldMapper, |
| 951 | + Model model, |
| 952 | + DenseVectorFieldMapper.ElementType expectedElementType |
| 953 | + ) { |
947 | 954 | assertThat(embeddingsFieldMapper, is(instanceOf(DenseVectorFieldMapper.class))); |
948 | 955 | DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) embeddingsFieldMapper; |
949 | 956 | ServiceSettings modelServiceSettings = model.getConfigurations().getServiceSettings(); |
950 | 957 | assertThat(denseVectorFieldMapper.fieldType().getVectorDimensions(), equalTo(modelServiceSettings.dimensions())); |
951 | | - assertThat(denseVectorFieldMapper.fieldType().getElementType(), equalTo(modelServiceSettings.elementType())); |
952 | | - assertThat( |
953 | | - denseVectorFieldMapper.fieldType().getSimilarity(), |
954 | | - equalTo(convertToVectorSimilarity.apply(modelServiceSettings.similarity())) |
955 | | - ); |
| 958 | + assertThat(denseVectorFieldMapper.fieldType().getElementType(), equalTo(expectedElementType)); |
| 959 | + assertThat(denseVectorFieldMapper.fieldType().getSimilarity(), equalTo(modelServiceSettings.similarity().vectorSimilarity())); |
956 | 960 | } |
957 | 961 |
|
958 | 962 | private void testUpdateInferenceId_GivenDenseModelsWithDifferentSettings( |
@@ -1325,22 +1329,23 @@ private static void assertSemanticTextField( |
1325 | 1329 | assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class)); |
1326 | 1330 | DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) embeddingsMapper; |
1327 | 1331 |
|
1328 | | - MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings(); |
1329 | | - DenseVectorFieldMapper.ElementType expectedElementType = getExpectedDefaultElementType(indexVersion, modelSettings); |
1330 | 1332 | if (expectedIndexOptions != null) { |
1331 | 1333 | IndexOptions expectedEmbeddingFieldIndexOptions = expectedIndexOptions.indexOptions(); |
1332 | 1334 | if (expectedEmbeddingFieldIndexOptions instanceof ExtendedDenseVectorIndexOptions edvio) { |
1333 | 1335 | assertEquals(edvio.getBaseIndexOptions(), denseVectorFieldMapper.fieldType().getIndexOptions()); |
1334 | | - if (edvio.getElementType() != null) { |
1335 | | - expectedElementType = edvio.getElementType(); |
1336 | | - } |
1337 | 1336 | } else { |
1338 | 1337 | assertEquals(expectedEmbeddingFieldIndexOptions, denseVectorFieldMapper.fieldType().getIndexOptions()); |
1339 | 1338 | } |
1340 | 1339 | } else { |
1341 | 1340 | assertNull(denseVectorFieldMapper.fieldType().getIndexOptions()); |
1342 | 1341 | } |
1343 | 1342 |
|
| 1343 | + MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings(); |
| 1344 | + DenseVectorFieldMapper.ElementType expectedElementType = getExpectedElementType( |
| 1345 | + indexVersion, |
| 1346 | + modelSettings.elementType(), |
| 1347 | + expectedIndexOptions |
| 1348 | + ); |
1344 | 1349 | assertEquals(expectedElementType, denseVectorFieldMapper.fieldType().getElementType()); |
1345 | 1350 | assertEquals(modelSettings.dimensions().intValue(), denseVectorFieldMapper.fieldType().getVectorDimensions()); |
1346 | 1351 | if (modelSettings.similarity() != null && indexVersion.onOrAfter(NEW_SPARSE_VECTOR)) { |
@@ -1368,19 +1373,6 @@ private static SemanticTextFieldMapper getSemanticFieldMapper(MapperService mapp |
1368 | 1373 | return (SemanticTextFieldMapper) mapper; |
1369 | 1374 | } |
1370 | 1375 |
|
1371 | | - private static DenseVectorFieldMapper.ElementType getExpectedDefaultElementType( |
1372 | | - IndexVersion indexVersion, |
1373 | | - MinimalServiceSettings modelSettings |
1374 | | - ) { |
1375 | | - DenseVectorFieldMapper.ElementType expected = modelSettings.elementType(); |
1376 | | - if (indexVersion.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BFLOAT16) |
1377 | | - && modelSettings.elementType() == DenseVectorFieldMapper.ElementType.FLOAT) { |
1378 | | - expected = DenseVectorFieldMapper.ElementType.BFLOAT16; |
1379 | | - } |
1380 | | - |
1381 | | - return expected; |
1382 | | - } |
1383 | | - |
1384 | 1376 | private static void assertInferenceEndpoints( |
1385 | 1377 | MapperService mapperService, |
1386 | 1378 | String fieldName, |
@@ -2431,6 +2423,24 @@ public static SemanticTextIndexOptions randomSemanticTextIndexOptions(TaskType t |
2431 | 2423 | return null; |
2432 | 2424 | } |
2433 | 2425 |
|
| 2426 | + private static DenseVectorFieldMapper.ElementType getExpectedElementType( |
| 2427 | + IndexVersion indexVersion, |
| 2428 | + DenseVectorFieldMapper.ElementType modelElementType, |
| 2429 | + @Nullable SemanticTextIndexOptions semanticTextIndexOptions |
| 2430 | + ) { |
| 2431 | + if (semanticTextIndexOptions != null && semanticTextIndexOptions.indexOptions() instanceof ExtendedDenseVectorIndexOptions edvio) { |
| 2432 | + if (edvio.getElementType() != null) { |
| 2433 | + return edvio.getElementType(); |
| 2434 | + } |
| 2435 | + } |
| 2436 | + |
| 2437 | + DenseVectorFieldMapper.ElementType expectedElementType = modelElementType; |
| 2438 | + if (indexVersion.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BFLOAT16) && expectedElementType == DenseVectorFieldMapper.ElementType.FLOAT) { |
| 2439 | + expectedElementType = DenseVectorFieldMapper.ElementType.BFLOAT16; |
| 2440 | + } |
| 2441 | + return expectedElementType; |
| 2442 | + } |
| 2443 | + |
2434 | 2444 | @Override |
2435 | 2445 | protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) { |
2436 | 2446 | // Until a doc is indexed, the query is rewritten as match no docs |
|
0 commit comments