Skip to content

Commit eb6c470

Browse files
authored
Fix semantic_text inference ID update tests that default to BFLOAT16 (elastic#145679)
1 parent d77a9f3 commit eb6c470

File tree

3 files changed: 45 additions (+45) and 37 deletions (-37).

muted-tests.yml

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -354,9 +354,6 @@ tests:
354354
- class: org.elasticsearch.xpack.esql.CsvTests
355355
method: test {csv-spec:dense_vector_aggs.sumDenseVectorSingleRow}
356356
issue: https://github.com/elastic/elasticsearch/issues/145549
357-
- class: org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests
358-
method: testUpdateInferenceId_GivenCurrentHasNoModelSettingsAndNewIsIncompatibleTaskType_ShouldSucceed {p0=false}
359-
issue: https://github.com/elastic/elasticsearch/issues/145550
360357
- class: org.elasticsearch.xpack.sql.qa.jdbc.single_node.SingleNodeJdbcResultSetIT
361358
method: testGettingValidByteWithCasting
362359
issue: https://github.com/elastic/elasticsearch/issues/145564
@@ -374,9 +371,6 @@ tests:
374371
- class: org.elasticsearch.xpack.esql.CsvTests
375372
method: test {csv-spec:stats_percentile.percentileOfLong}
376373
issue: https://github.com/elastic/elasticsearch/issues/145589
377-
- class: org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests
378-
method: testUpdateInferenceId_GivenCurrentHasNoModelSettingsAndNewIsIncompatibleTaskType_ShouldSucceed {p0=true}
379-
issue: https://github.com/elastic/elasticsearch/issues/145591
380374
- class: org.elasticsearch.xpack.sql.qa.jdbc.single_node.SingleNodeJdbcFetchSizeIT
381375
method: testScrollWithDatetimeAndTimezoneParam
382376
issue: https://github.com/elastic/elasticsearch/issues/145592

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1023,6 +1023,10 @@ public ChunkingSettings getChunkingSettings() {
10231023
return chunkingSettings;
10241024
}
10251025

1026+
public SemanticTextIndexOptions getIndexOptions() {
1027+
return indexOptions;
1028+
}
1029+
10261030
public ObjectMapper getInferenceField() {
10271031
return inferenceField;
10281032
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 41 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.common.lucene.search.Queries;
3838
import org.elasticsearch.common.settings.Settings;
3939
import org.elasticsearch.core.CheckedConsumer;
40+
import org.elasticsearch.core.Nullable;
4041
import org.elasticsearch.index.IndexVersion;
4142
import org.elasticsearch.index.IndexVersions;
4243
import org.elasticsearch.index.mapper.DocumentMapper;
@@ -103,7 +104,6 @@
103104
import java.util.Map;
104105
import java.util.Set;
105106
import java.util.function.BiConsumer;
106-
import java.util.function.Function;
107107
import java.util.function.Supplier;
108108

109109
import static org.elasticsearch.index.IndexVersions.NEW_SPARSE_VECTOR;
@@ -933,26 +933,30 @@ private static void assertEmbeddingsFieldMapperMatchesModel(MapperService mapper
933933
Mapper embeddingsFieldMapper = mapperService.mappingLookup().getMapper(getEmbeddingsFieldName(fieldName));
934934
switch (model.getTaskType()) {
935935
case SPARSE_EMBEDDING -> assertThat(embeddingsFieldMapper, is(instanceOf(SparseVectorFieldMapper.class)));
936-
case TEXT_EMBEDDING -> assertTextEmbeddingsFieldMapperMatchesModel(embeddingsFieldMapper, model);
936+
case TEXT_EMBEDDING -> {
937+
SemanticTextFieldMapper semanticFieldMapper = getSemanticFieldMapper(mapperService, fieldName);
938+
DenseVectorFieldMapper.ElementType expectedElementType = getExpectedElementType(
939+
mapperService.getIndexSettings().getIndexVersionCreated(),
940+
model.getServiceSettings().elementType(),
941+
semanticFieldMapper.fieldType().getIndexOptions()
942+
);
943+
assertTextEmbeddingsFieldMapperMatchesModel(embeddingsFieldMapper, model, expectedElementType);
944+
}
937945
default -> throw new AssertionError("Unexpected task type [" + model.getTaskType() + "]");
938946
}
939947
}
940948

941-
private static void assertTextEmbeddingsFieldMapperMatchesModel(Mapper embeddingsFieldMapper, Model model) {
942-
Function<SimilarityMeasure, DenseVectorFieldMapper.VectorSimilarity> convertToVectorSimilarity = s -> switch (s) {
943-
case COSINE -> DenseVectorFieldMapper.VectorSimilarity.COSINE;
944-
case DOT_PRODUCT -> DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT;
945-
case L2_NORM -> DenseVectorFieldMapper.VectorSimilarity.L2_NORM;
946-
};
949+
private static void assertTextEmbeddingsFieldMapperMatchesModel(
950+
Mapper embeddingsFieldMapper,
951+
Model model,
952+
DenseVectorFieldMapper.ElementType expectedElementType
953+
) {
947954
assertThat(embeddingsFieldMapper, is(instanceOf(DenseVectorFieldMapper.class)));
948955
DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) embeddingsFieldMapper;
949956
ServiceSettings modelServiceSettings = model.getConfigurations().getServiceSettings();
950957
assertThat(denseVectorFieldMapper.fieldType().getVectorDimensions(), equalTo(modelServiceSettings.dimensions()));
951-
assertThat(denseVectorFieldMapper.fieldType().getElementType(), equalTo(modelServiceSettings.elementType()));
952-
assertThat(
953-
denseVectorFieldMapper.fieldType().getSimilarity(),
954-
equalTo(convertToVectorSimilarity.apply(modelServiceSettings.similarity()))
955-
);
958+
assertThat(denseVectorFieldMapper.fieldType().getElementType(), equalTo(expectedElementType));
959+
assertThat(denseVectorFieldMapper.fieldType().getSimilarity(), equalTo(modelServiceSettings.similarity().vectorSimilarity()));
956960
}
957961

958962
private void testUpdateInferenceId_GivenDenseModelsWithDifferentSettings(
@@ -1325,22 +1329,23 @@ private static void assertSemanticTextField(
13251329
assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
13261330
DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) embeddingsMapper;
13271331

1328-
MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
1329-
DenseVectorFieldMapper.ElementType expectedElementType = getExpectedDefaultElementType(indexVersion, modelSettings);
13301332
if (expectedIndexOptions != null) {
13311333
IndexOptions expectedEmbeddingFieldIndexOptions = expectedIndexOptions.indexOptions();
13321334
if (expectedEmbeddingFieldIndexOptions instanceof ExtendedDenseVectorIndexOptions edvio) {
13331335
assertEquals(edvio.getBaseIndexOptions(), denseVectorFieldMapper.fieldType().getIndexOptions());
1334-
if (edvio.getElementType() != null) {
1335-
expectedElementType = edvio.getElementType();
1336-
}
13371336
} else {
13381337
assertEquals(expectedEmbeddingFieldIndexOptions, denseVectorFieldMapper.fieldType().getIndexOptions());
13391338
}
13401339
} else {
13411340
assertNull(denseVectorFieldMapper.fieldType().getIndexOptions());
13421341
}
13431342

1343+
MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
1344+
DenseVectorFieldMapper.ElementType expectedElementType = getExpectedElementType(
1345+
indexVersion,
1346+
modelSettings.elementType(),
1347+
expectedIndexOptions
1348+
);
13441349
assertEquals(expectedElementType, denseVectorFieldMapper.fieldType().getElementType());
13451350
assertEquals(modelSettings.dimensions().intValue(), denseVectorFieldMapper.fieldType().getVectorDimensions());
13461351
if (modelSettings.similarity() != null && indexVersion.onOrAfter(NEW_SPARSE_VECTOR)) {
@@ -1368,19 +1373,6 @@ private static SemanticTextFieldMapper getSemanticFieldMapper(MapperService mapp
13681373
return (SemanticTextFieldMapper) mapper;
13691374
}
13701375

1371-
private static DenseVectorFieldMapper.ElementType getExpectedDefaultElementType(
1372-
IndexVersion indexVersion,
1373-
MinimalServiceSettings modelSettings
1374-
) {
1375-
DenseVectorFieldMapper.ElementType expected = modelSettings.elementType();
1376-
if (indexVersion.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BFLOAT16)
1377-
&& modelSettings.elementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
1378-
expected = DenseVectorFieldMapper.ElementType.BFLOAT16;
1379-
}
1380-
1381-
return expected;
1382-
}
1383-
13841376
private static void assertInferenceEndpoints(
13851377
MapperService mapperService,
13861378
String fieldName,
@@ -2431,6 +2423,24 @@ public static SemanticTextIndexOptions randomSemanticTextIndexOptions(TaskType t
24312423
return null;
24322424
}
24332425

2426+
private static DenseVectorFieldMapper.ElementType getExpectedElementType(
2427+
IndexVersion indexVersion,
2428+
DenseVectorFieldMapper.ElementType modelElementType,
2429+
@Nullable SemanticTextIndexOptions semanticTextIndexOptions
2430+
) {
2431+
if (semanticTextIndexOptions != null && semanticTextIndexOptions.indexOptions() instanceof ExtendedDenseVectorIndexOptions edvio) {
2432+
if (edvio.getElementType() != null) {
2433+
return edvio.getElementType();
2434+
}
2435+
}
2436+
2437+
DenseVectorFieldMapper.ElementType expectedElementType = modelElementType;
2438+
if (indexVersion.onOrAfter(SEMANTIC_TEXT_DEFAULTS_TO_BFLOAT16) && expectedElementType == DenseVectorFieldMapper.ElementType.FLOAT) {
2439+
expectedElementType = DenseVectorFieldMapper.ElementType.BFLOAT16;
2440+
}
2441+
return expectedElementType;
2442+
}
2443+
24342444
@Override
24352445
protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) {
24362446
// Until a doc is indexed, the query is rewritten as match no docs

0 commit comments

Comments
 (0)