Skip to content

Commit 6ea41e2

Browse files
committed
fix tests; fix yaml tests;
1 parent 28ba8e1 commit 6ea41e2

File tree

5 files changed

+316
-135
lines changed

5 files changed

+316
-135
lines changed

server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,26 @@ public void testIsNotAggregatable() {
4242
}
4343

4444
public static SparseVectorFieldMapper.SparseVectorIndexOptions randomSparseVectorIndexOptions() {
45-
return randomFrom(
46-
new SparseVectorFieldMapper.SparseVectorIndexOptions(null, null),
47-
new SparseVectorFieldMapper.SparseVectorIndexOptions(false, null),
48-
new SparseVectorFieldMapper.SparseVectorIndexOptions(true, null),
49-
new SparseVectorFieldMapper.SparseVectorIndexOptions(
50-
true,
51-
new TokenPruningConfig(randomFloatBetween(1.0f, 100.0f, true), randomFloatBetween(0.0f, 1.0f, true), randomBoolean())
52-
),
53-
new SparseVectorFieldMapper.SparseVectorIndexOptions(
54-
true,
55-
new TokenPruningConfig(randomFloatBetween(1.0f, 100.0f, true), randomFloatBetween(0.0f, 1.0f, true), randomBoolean())
56-
),
57-
new SparseVectorFieldMapper.SparseVectorIndexOptions(
58-
true,
59-
new TokenPruningConfig(randomFloatBetween(1.0f, 100.0f, true), randomFloatBetween(0.0f, 1.0f, true), randomBoolean())
60-
)
45+
return randomSparseVectorIndexOptions(true);
46+
}
47+
48+
public static SparseVectorFieldMapper.SparseVectorIndexOptions randomSparseVectorIndexOptions(boolean includeNull) {
49+
if (includeNull && randomBoolean()) {
50+
return null;
51+
}
52+
53+
Boolean prune = randomBoolean() ? null : randomBoolean();
54+
if (prune == null) {
55+
new SparseVectorFieldMapper.SparseVectorIndexOptions(null, null);
56+
}
57+
58+
if (prune == Boolean.FALSE) {
59+
new SparseVectorFieldMapper.SparseVectorIndexOptions(false, null);
60+
}
61+
62+
return new SparseVectorFieldMapper.SparseVectorIndexOptions(
63+
true,
64+
new TokenPruningConfig(randomFloatBetween(1.0f, 100.0f, true), randomFloatBetween(0.0f, 1.0f, true), randomBoolean())
6165
);
6266
}
6367
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,14 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
455455
return;
456456
}
457457

458+
if (modelSettings == null) {
459+
throw new IllegalArgumentException(
460+
"Model settings must be set to validate index options for inference ID [" + inferenceId + "]"
461+
);
462+
}
463+
458464
if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR) {
459-
if (modelSettings != null && modelSettings.taskType() != SPARSE_EMBEDDING) {
465+
if (modelSettings.taskType() != SPARSE_EMBEDDING) {
460466
throw new IllegalArgumentException(
461467
"Invalid task type for index options, required ["
462468
+ SPARSE_EMBEDDING
@@ -468,12 +474,6 @@ private void validateIndexOptions(SemanticTextIndexOptions indexOptions, String
468474
return;
469475
}
470476

471-
if (modelSettings == null) {
472-
throw new IllegalArgumentException(
473-
"Model settings must be set to validate index options for inference ID [" + inferenceId + "]"
474-
);
475-
}
476-
477477
if (indexOptions.type() == SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR) {
478478
if (modelSettings.taskType() != TEXT_EMBEDDING) {
479479
throw new IllegalArgumentException(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 19 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.common.lucene.search.Queries;
3737
import org.elasticsearch.common.settings.Settings;
3838
import org.elasticsearch.core.CheckedConsumer;
39-
import org.elasticsearch.core.Tuple;
4039
import org.elasticsearch.index.IndexVersion;
4140
import org.elasticsearch.index.IndexVersions;
4241
import org.elasticsearch.index.mapper.DocumentMapper;
@@ -643,56 +642,6 @@ private void addSparseVectorModelSettingsToBuilder(XContentBuilder b) throws IOE
643642
b.endObject();
644643
}
645644

646-
private void setSparseVectorIndexOptionInMapper(XContentBuilder b, SparseVectorFieldMapper.SparseVectorIndexOptions indexOptions)
647-
throws IOException {
648-
setSparseVectorIndexOptionInMapper(b, indexOptions, null);
649-
}
650-
651-
private void setSparseVectorIndexOptionInMapper(
652-
XContentBuilder b,
653-
SparseVectorFieldMapper.SparseVectorIndexOptions indexOptions,
654-
Tuple<String, Object> injectExtraField
655-
) throws IOException {
656-
if (indexOptions == null) {
657-
return;
658-
}
659-
660-
b.startObject(INDEX_OPTIONS_FIELD);
661-
{
662-
b.startObject(SparseVectorFieldMapper.CONTENT_TYPE);
663-
{
664-
if (indexOptions.getPrune() != null) {
665-
b.field(SparseVectorFieldMapper.SparseVectorIndexOptions.PRUNE_FIELD_NAME.getPreferredName(), indexOptions.getPrune());
666-
}
667-
668-
if (indexOptions.getPruningConfig() != null) {
669-
b.startObject(SparseVectorFieldMapper.SparseVectorIndexOptions.PRUNING_CONFIG_FIELD_NAME.getPreferredName());
670-
{
671-
b.field(
672-
TokenPruningConfig.TOKENS_FREQ_RATIO_THRESHOLD.getPreferredName(),
673-
indexOptions.getPruningConfig().getTokensFreqRatioThreshold()
674-
);
675-
b.field(
676-
TokenPruningConfig.TOKENS_WEIGHT_THRESHOLD.getPreferredName(),
677-
indexOptions.getPruningConfig().getTokensWeightThreshold()
678-
);
679-
b.field(
680-
TokenPruningConfig.ONLY_SCORE_PRUNED_TOKENS_FIELD.getPreferredName(),
681-
indexOptions.getPruningConfig().isOnlyScorePrunedTokens()
682-
);
683-
}
684-
b.endObject();
685-
}
686-
687-
if (injectExtraField != null) {
688-
b.field(injectExtraField.v1(), injectExtraField.v2());
689-
}
690-
}
691-
b.endObject();
692-
}
693-
694-
}
695-
696645
public void testSparseVectorIndexOptionsValidationAndMapping() throws IOException {
697646
for (int depth = 1; depth < 5; depth++) {
698647
SparseVectorFieldMapper.SparseVectorIndexOptions indexOptions = SparseVectorFieldTypeTests.randomSparseVectorIndexOptions();
@@ -706,8 +655,14 @@ public void testSparseVectorIndexOptionsValidationAndMapping() throws IOExceptio
706655
b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
707656
b.field(INFERENCE_ID_FIELD, inferenceId);
708657
addSparseVectorModelSettingsToBuilder(b);
709-
setSparseVectorIndexOptionInMapper(b, indexOptions);
710-
b.endObject();
658+
if (indexOptions != null) {
659+
b.startObject(INDEX_OPTIONS_FIELD);
660+
{
661+
b.field(SparseVectorFieldMapper.CONTENT_TYPE);
662+
indexOptions.toXContent(b, null);
663+
}
664+
b.endObject();
665+
}
711666
}
712667
b.endObject();
713668
}));
@@ -717,7 +672,9 @@ public void testSparseVectorIndexOptionsValidationAndMapping() throws IOExceptio
717672
fieldName,
718673
true,
719674
null,
720-
new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, indexOptions)
675+
indexOptions == null
676+
? null
677+
: new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, indexOptions)
721678
);
722679
}
723680
}
@@ -770,26 +727,6 @@ public void testSparseVectorMappingUpdate() throws IOException {
770727
}
771728
}
772729

773-
public void testSparseVectorValidationWithUnknownParameter() throws IOException {
774-
for (int depth = 1; depth < 5; depth++) {
775-
SparseVectorFieldMapper.SparseVectorIndexOptions indexOptions = SparseVectorFieldTypeTests.randomSparseVectorIndexOptions();
776-
String inferenceId = "test_model";
777-
String fieldName = randomFieldName(depth);
778-
779-
Exception exc = expectThrows(MapperParsingException.class, () -> createMapperService(mapping(b -> {
780-
b.startObject(fieldName);
781-
{
782-
b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
783-
b.field(INFERENCE_ID_FIELD, inferenceId);
784-
setSparseVectorIndexOptionInMapper(b, indexOptions, new Tuple<>("unknown_parameter", "test"));
785-
b.endObject();
786-
}
787-
b.endObject();
788-
})));
789-
assertTrue(exc.getMessage().contains("[index_options] unknown field [unknown_parameter]"));
790-
}
791-
}
792-
793730
public void testUpdateSearchInferenceId() throws IOException {
794731
final String inferenceId = "test_inference_id";
795732
final String searchInferenceId1 = "test_search_inference_id_1";
@@ -1546,7 +1483,7 @@ public void testSparseVectorIndexOptionsDefaultsBeforeSupport() throws IOExcepti
15461483
"field",
15471484
true,
15481485
null,
1549-
defaultSparseVectorIndexOptions(mapperService.getIndexSettings().getIndexVersionCreated())
1486+
null
15501487
);
15511488
}
15521489

@@ -1662,12 +1599,16 @@ public void testSpecifiedDenseVectorIndexOptions() throws IOException {
16621599

16631600
public void testSpecificSparseVectorIndexOptions() throws IOException {
16641601
for (int i = 0; i < 10; i++) {
1665-
SparseVectorFieldMapper.SparseVectorIndexOptions testIndexOptions = randomSparseVectorIndexOptions();
1602+
SparseVectorFieldMapper.SparseVectorIndexOptions testIndexOptions = randomSparseVectorIndexOptions(false);
16661603
var mapperService = createMapperService(fieldMapping(b -> {
16671604
b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
16681605
b.field(INFERENCE_ID_FIELD, "test_inference_id");
16691606
addSparseVectorModelSettingsToBuilder(b);
1670-
setSparseVectorIndexOptionInMapper(b, testIndexOptions);
1607+
b.startObject(INDEX_OPTIONS_FIELD);
1608+
{
1609+
b.field(SparseVectorFieldMapper.CONTENT_TYPE);
1610+
testIndexOptions.toXContent(b, null);
1611+
}
16711612
b.endObject();
16721613
}), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT);
16731614

@@ -1741,7 +1682,7 @@ public static SemanticTextIndexOptions randomSemanticTextIndexOptions(TaskType t
17411682
? null
17421683
: new SemanticTextIndexOptions(
17431684
SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
1744-
randomSparseVectorIndexOptions()
1685+
randomSparseVectorIndexOptions(false)
17451686
);
17461687
}
17471688

0 commit comments

Comments
 (0)