Skip to content

Commit d7f2ee3

Browse files
Decoupled namedWritables to use separate fake plugin and simplified other override methods
1 parent 84f7ae5 commit d7f2ee3

File tree

2 files changed

+71
-54
lines changed

2 files changed

+71
-54
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
import org.elasticsearch.action.DocWriteResponse;
1111
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1212
import org.elasticsearch.action.search.SearchRequest;
13+
import org.elasticsearch.cluster.metadata.IndexMetadata;
14+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1315
import org.elasticsearch.common.settings.Settings;
1416
import org.elasticsearch.core.TimeValue;
1517
import org.elasticsearch.index.IndexVersion;
16-
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
18+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
19+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
20+
import org.elasticsearch.index.query.QueryBuilder;
21+
import org.elasticsearch.inference.SimilarityMeasure;
1722
import org.elasticsearch.license.LicenseSettings;
1823
import org.elasticsearch.plugins.Plugin;
1924
import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -22,20 +27,22 @@
2227
import org.elasticsearch.test.index.IndexVersionUtils;
2328
import org.elasticsearch.xcontent.XContentBuilder;
2429
import org.elasticsearch.xcontent.XContentFactory;
30+
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
31+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
2532
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
2633
import org.elasticsearch.xpack.inference.Utils;
2734
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
2835
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
2936
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3037
import org.junit.Before;
3138

32-
import java.io.IOException;
39+
import java.util.ArrayList;
3340
import java.util.Collection;
34-
import java.util.HashMap;
3541
import java.util.List;
3642
import java.util.Locale;
3743
import java.util.Map;
3844
import java.util.Set;
45+
import java.util.function.Function;
3946
import java.util.stream.Collectors;
4047

4148
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
@@ -47,16 +54,38 @@
4754
public class SemanticTextIndexVersionIT extends ESIntegTestCase {
4855
private static final IndexVersion SEMANTIC_TEXT_INTRODUCED_VERSION = IndexVersion.fromId(8512000);
4956
private static final double PERCENTAGE_TO_TEST = 0.5;
50-
private Set<IndexVersion> availableVersions;
57+
private static final int MAXIMUM_NUMBER_OF_VERSIONS_TO_TEST = 25;
58+
private List<IndexVersion> selectedVersions;
5159

5260
@Before
5361
public void setup() throws Exception {
5462
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
63+
DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
64+
// dot product means that we need normalized vectors; it's not worth doing that in this test
65+
SimilarityMeasure similarity = randomValueOtherThan(
66+
SimilarityMeasure.DOT_PRODUCT,
67+
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
68+
);
69+
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
5570
Utils.storeSparseModel(modelRegistry);
56-
availableVersions = IndexVersionUtils.allReleasedVersions()
71+
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
72+
73+
Set<IndexVersion> availableVersions = IndexVersionUtils.allReleasedVersions()
5774
.stream()
5875
.filter(indexVersion -> indexVersion.after(SEMANTIC_TEXT_INTRODUCED_VERSION))
5976
.collect(Collectors.toSet());
77+
78+
Function<Set<IndexVersion>, Integer> determineNumberOfVersionsToTest = versions -> {
79+
int totalVersions = versions.size();
80+
int percentageTestSize = (int) Math.ceil(totalVersions * PERCENTAGE_TO_TEST);
81+
82+
return totalVersions < MAXIMUM_NUMBER_OF_VERSIONS_TO_TEST
83+
? totalVersions
84+
: Math.min(percentageTestSize, MAXIMUM_NUMBER_OF_VERSIONS_TO_TEST);
85+
};
86+
87+
int versionsCount = determineNumberOfVersionsToTest.apply(availableVersions);
88+
selectedVersions = randomSubsetOf(versionsCount, availableVersions);
6089
}
6190

6291
@Override
@@ -71,44 +100,24 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
71100

72101
@Override
73102
protected Collection<Class<? extends Plugin>> nodePlugins() {
74-
return List.of(LocalStateInferencePlugin.class);
103+
return List.of(LocalStateInferencePlugin.class, FakeMlPlugin.class, FakeSemanticQueryBuilderPlugin.class);
75104
}
76105

77106
/**
78107
* Generate settings for an index with a specific version.
79108
*/
80109
private Settings getIndexSettingsWithVersion(IndexVersion version) {
81-
return Settings.builder().put(indexSettings()).put("index.version.created", version).build();
82-
}
83-
84-
/**
85-
* Creates a subset of indices with different versions for testing.
86-
*
87-
* @return Map of created indices with their versions
88-
*/
89-
protected Map<String, IndexVersion> createRandomVersionIndices() throws IOException {
90-
int versionsCount = (int) Math.ceil(availableVersions.size() * PERCENTAGE_TO_TEST);
91-
List<IndexVersion> selectedVersions = randomSubsetOf(versionsCount, availableVersions);
92-
Map<String, IndexVersion> result = new HashMap<>();
93-
94-
for (int i = 0; i < selectedVersions.size(); i++) {
95-
String indexName = "test_semantic" + "_" + i;
96-
IndexVersion version = selectedVersions.get(i);
97-
createIndex(indexName, getIndexSettingsWithVersion(version));
98-
result.put(indexName, version);
99-
}
100-
101-
return result;
110+
return Settings.builder().put(indexSettings()).put(IndexMetadata.SETTING_VERSION_CREATED, version).build();
102111
}
103112

104113
/**
105114
* This test creates an index, ingests data, and performs searches (including highlighting when applicable)
106115
* for a selected subset of index versions.
107116
*/
108117
public void testSemanticText() throws Exception {
109-
Map<String, IndexVersion> indices = createRandomVersionIndices();
110-
for (String indexName : indices.keySet()) {
111-
IndexVersion version = indices.get(indexName);
118+
for (IndexVersion version : selectedVersions) {
119+
String indexName = "test_semantic_" + randomAlphaOfLength(5).toLowerCase(Locale.ROOT);
120+
createIndex(indexName, getIndexSettingsWithVersion(version));
112121

113122
// Test index creation
114123
assertTrue("Index " + indexName + " should exist", indexExists(indexName));
@@ -164,17 +173,36 @@ public void testSemanticText() throws Exception {
164173
.getIndexToSettings()
165174
.get(indexName);
166175

167-
// Semantic Search with highlighter only available from 8.18 and 9.0
168-
if (InferenceMetadataFieldsMapper.isEnabled(settings)) {
169-
SearchSourceBuilder sourceHighlighterBuilder = new SearchSourceBuilder().query(
170-
new SemanticQueryBuilder("semantic_field", "inference")
171-
).highlighter(new HighlightBuilder().field("semantic_field")).trackTotalHits(true);
172-
173-
assertResponse(client().search(new SearchRequest(indexName).source(sourceHighlighterBuilder)), response -> {
174-
assertHighlight(response, 0, "semantic_field", 0, 2, equalTo("inference test"));
175-
assertHighlight(response, 0, "semantic_field", 1, 2, equalTo("another inference test"));
176-
});
177-
}
176+
// Semantic Search with highlighter
177+
SearchSourceBuilder sourceHighlighterBuilder = new SearchSourceBuilder().query(
178+
new SemanticQueryBuilder("semantic_field", "inference")
179+
).highlighter(new HighlightBuilder().field("semantic_field")).trackTotalHits(true);
180+
181+
assertResponse(client().search(new SearchRequest(indexName).source(sourceHighlighterBuilder)), response -> {
182+
assertHighlight(response, 0, "semantic_field", 0, 2, equalTo("inference test"));
183+
assertHighlight(response, 0, "semantic_field", 1, 2, equalTo("another inference test"));
184+
});
185+
186+
beforeIndexDeletion();
187+
assertAcked(client().admin().indices().prepareDelete(indexName));
188+
}
189+
}
190+
191+
public static class FakeMlPlugin extends Plugin {
192+
@Override
193+
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
194+
return new MlInferenceNamedXContentProvider().getNamedWriteables();
195+
}
196+
}
197+
198+
public static class FakeSemanticQueryBuilderPlugin extends Plugin {
199+
@Override
200+
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
201+
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
202+
namedWriteables.add(
203+
new NamedWriteableRegistry.Entry(QueryBuilder.class, SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new)
204+
);
205+
return namedWriteables;
178206
}
179207
}
180208
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,16 @@
1111
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.index.mapper.Mapper;
14-
import org.elasticsearch.index.query.QueryBuilder;
1514
import org.elasticsearch.inference.InferenceServiceExtension;
1615
import org.elasticsearch.license.XPackLicenseState;
1716
import org.elasticsearch.plugins.SearchPlugin;
1817
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
1918
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
20-
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
21-
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
2219
import org.elasticsearch.xpack.core.ssl.SSLService;
23-
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
2420
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
2521
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
2622

2723
import java.nio.file.Path;
28-
import java.util.ArrayList;
2924
import java.util.Collection;
3025
import java.util.List;
3126
import java.util.Map;
@@ -72,7 +67,7 @@ public Map<String, Mapper.TypeParser> getMappers() {
7267

7368
@Override
7469
public Map<String, Highlighter> getHighlighters() {
75-
return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
70+
return inferencePlugin.getHighlighters();
7671
}
7772

7873
@Override
@@ -87,12 +82,6 @@ public List<QuerySpec<?>> getQueries() {
8782

8883
@Override
8984
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
90-
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(super.getNamedWriteables());
91-
namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
92-
namedWriteables.add(
93-
new NamedWriteableRegistry.Entry(QueryBuilder.class, SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new)
94-
);
95-
96-
return namedWriteables;
85+
return super.getNamedWriteables();
9786
}
9887
}

0 commit comments

Comments
 (0)