Skip to content

Commit c0a9732

Browse files
committed
fix additional tests
1 parent e3bbccc commit c0a9732

File tree

2 files changed

+29
-39
lines changed

2 files changed

+29
-39
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,8 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
452452
return;
453453
}
454454

455-
if (modelSettings == null) {
456-
throw new IllegalArgumentException(
457-
"Model settings must be set to validate index options for inference ID [" + inferenceId + "]"
458-
);
459-
}
460-
461455
if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR) {
462-
if (modelSettings.taskType() != SPARSE_EMBEDDING) {
456+
if (modelSettings != null && modelSettings.taskType() != SPARSE_EMBEDDING) {
463457
throw new IllegalArgumentException(
464458
"Invalid task type for index options, required ["
465459
+ SPARSE_EMBEDDING
@@ -471,6 +465,12 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
471465
return;
472466
}
473467

468+
if (modelSettings == null) {
469+
throw new IllegalArgumentException(
470+
"Model settings must be set to validate index options for inference ID [" + inferenceId + "]"
471+
);
472+
}
473+
474474
if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR) {
475475
if (modelSettings.taskType() != TEXT_EMBEDDING) {
476476
throw new IllegalArgumentException(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.common.lucene.search.Queries;
3737
import org.elasticsearch.common.settings.Settings;
3838
import org.elasticsearch.core.CheckedConsumer;
39-
import org.elasticsearch.core.Nullable;
4039
import org.elasticsearch.core.Tuple;
4140
import org.elasticsearch.index.IndexVersion;
4241
import org.elasticsearch.index.IndexVersions;
@@ -385,6 +384,18 @@ public void testInvalidInferenceEndpoints() {
385384
}
386385
}
387386

387+
private SemanticTextIndexOptions getDefaultIndexOptionsForMapper(MapperService mapperService) {
388+
var mapperIndexVersion = mapperService.getIndexSettings().getIndexVersionCreated();
389+
var defaultSparseVectorIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions
390+
.getDefaultIndexOptions(mapperIndexVersion);
391+
return defaultSparseVectorIndexOptions == null
392+
? null
393+
: new SemanticTextIndexOptions(
394+
SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
395+
defaultSparseVectorIndexOptions
396+
);
397+
}
398+
388399
public void testInvalidTaskTypes() {
389400
for (var taskType : TaskType.values()) {
390401
if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.SPARSE_EMBEDDING) {
@@ -446,7 +457,8 @@ public void testMultiFieldsSupport() throws IOException {
446457
b.endObject();
447458
b.endObject();
448459
}), useLegacyFormat);
449-
assertSemanticTextField(mapperService, "field", true, null, null);
460+
var expectedIndexOptions = getDefaultIndexOptionsForMapper(mapperService);
461+
assertSemanticTextField(mapperService, "field", true, null, expectedIndexOptions);
450462

451463
mapperService = createMapperService(fieldMapping(b -> {
452464
b.field("type", "semantic_text");
@@ -477,7 +489,6 @@ public void testMultiFieldsSupport() throws IOException {
477489
b.endObject();
478490
}), useLegacyFormat));
479491
assertThat(e.getMessage(), containsString("is already used by another field"));
480-
481492
}
482493
}
483494

@@ -719,13 +730,12 @@ public void testSparseVectorMappingUpdate() throws IOException {
719730
model.getInferenceEntityId(),
720731
null,
721732
chunkingSettings,
722-
indexOptions,
723-
TaskType.SPARSE_EMBEDDING
733+
indexOptions
724734
)
725735
),
726736
useLegacyFormat
727737
);
728-
assertSemanticTextField(mapperService, fieldName, true, chunkingSettings, indexOptions);
738+
assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, indexOptions);
729739

730740
final SemanticTextIndexOptions newIndexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING);
731741
ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings);
@@ -738,12 +748,11 @@ public void testSparseVectorMappingUpdate() throws IOException {
738748
model.getInferenceEntityId(),
739749
null,
740750
newChunkingSettings,
741-
newIndexOptions,
742-
TaskType.SPARSE_EMBEDDING
751+
newIndexOptions
743752
)
744753
)
745754
);
746-
assertSemanticTextField(mapperService, fieldName, true, newChunkingSettings, newIndexOptions);
755+
assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings, newIndexOptions);
747756
}
748757
}
749758

@@ -939,29 +948,27 @@ public void testSuccessfulParse() throws IOException {
939948
model1.getInferenceEntityId(),
940949
setSearchInferenceId ? searchInferenceId : null,
941950
chunkingSettings,
942-
indexOptions,
943-
TaskType.SPARSE_EMBEDDING
951+
indexOptions
944952
);
945953
addSemanticTextMapping(
946954
b,
947955
fieldName2,
948956
model2.getInferenceEntityId(),
949957
setSearchInferenceId ? searchInferenceId : null,
950958
chunkingSettings,
951-
indexOptions,
952-
TaskType.SPARSE_EMBEDDING
959+
indexOptions
953960
);
954961
});
955962

956963
MapperService mapperService = createMapperService(mapping, useLegacyFormat);
957-
assertSemanticTextField(mapperService, fieldName1, true, null, indexOptions);
964+
assertSemanticTextField(mapperService, fieldName1, false, null, indexOptions);
958965
assertInferenceEndpoints(
959966
mapperService,
960967
fieldName1,
961968
model1.getInferenceEntityId(),
962969
setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId()
963970
);
964-
assertSemanticTextField(mapperService, fieldName2, true, null, indexOptions);
971+
assertSemanticTextField(mapperService, fieldName2, false, null, indexOptions);
965972
assertInferenceEndpoints(
966973
mapperService,
967974
fieldName2,
@@ -1732,18 +1739,6 @@ private static void addSemanticTextMapping(
17321739
String searchInferenceId,
17331740
ChunkingSettings chunkingSettings,
17341741
SemanticTextIndexOptions indexOptions
1735-
) throws IOException {
1736-
addSemanticTextMapping(mappingBuilder, fieldName, inferenceId, searchInferenceId, chunkingSettings, indexOptions, null);
1737-
}
1738-
1739-
private static void addSemanticTextMapping(
1740-
XContentBuilder mappingBuilder,
1741-
String fieldName,
1742-
String inferenceId,
1743-
String searchInferenceId,
1744-
ChunkingSettings chunkingSettings,
1745-
SemanticTextIndexOptions indexOptions,
1746-
@Nullable TaskType modelSettingsTaskType
17471742
) throws IOException {
17481743
mappingBuilder.startObject(fieldName);
17491744
mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
@@ -1760,11 +1755,6 @@ private static void addSemanticTextMapping(
17601755
mappingBuilder.field(INDEX_OPTIONS_FIELD);
17611756
indexOptions.toXContent(mappingBuilder, null);
17621757
}
1763-
if (modelSettingsTaskType != null) {
1764-
mappingBuilder.startObject("model_settings");
1765-
mappingBuilder.field("task_type", modelSettingsTaskType.toString());
1766-
mappingBuilder.endObject();
1767-
}
17681758
mappingBuilder.endObject();
17691759
}
17701760

0 commit comments

Comments
 (0)