Skip to content

Commit 4dcb691

Browse files
Fix SemanticQueryBuilder dependencies
1 parent 0dbca60 commit 4dcb691

File tree

3 files changed

+61
-46
lines changed

3 files changed

+61
-46
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
285285
return entries;
286286
}
287287

288+
@Override
289+
public List<QuerySpec<?>> getQueries() {
290+
List<QuerySpec<?>> querySpecs = new ArrayList<>(super.getQueries());
291+
filterPlugins(SearchPlugin.class).stream().flatMap(p -> p.getQueries().stream()).forEach(querySpecs::add);
292+
return querySpecs;
293+
}
294+
288295
@Override
289296
public List<NamedXContentRegistry.Entry> getNamedXContent() {
290297
List<NamedXContentRegistry.Entry> entries = new ArrayList<>(super.getNamedXContent());

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.index.IndexVersion;
1818
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1919
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
20-
import org.elasticsearch.index.query.QueryBuilder;
2120
import org.elasticsearch.inference.SimilarityMeasure;
2221
import org.elasticsearch.license.LicenseSettings;
2322
import org.elasticsearch.plugins.Plugin;
@@ -28,16 +27,16 @@
2827
import org.elasticsearch.xcontent.XContentBuilder;
2928
import org.elasticsearch.xcontent.XContentFactory;
3029
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
31-
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
3230
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
3331
import org.elasticsearch.xpack.inference.Utils;
32+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
3433
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3534
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
3635
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3736
import org.junit.Before;
3837

39-
import java.util.ArrayList;
4038
import java.util.Collection;
39+
import java.util.HashMap;
4140
import java.util.List;
4241
import java.util.Locale;
4342
import java.util.Map;
@@ -55,6 +54,8 @@ public class SemanticTextIndexVersionIT extends ESIntegTestCase {
5554
private static final IndexVersion SEMANTIC_TEXT_INTRODUCED_VERSION = IndexVersion.fromId(8512000);
5655
private static final double PERCENTAGE_TO_TEST = 0.5;
5756
private static final int MAXIMUM_NUMBER_OF_VERSIONS_TO_TEST = 25;
57+
private static final String SPARSE_SEMANTIC_FIELD = "sparse_field";
58+
private static final String DENSE_SEMANTIC_FIELD = "dense_field";
5859
private List<IndexVersion> selectedVersions;
5960

6061
@Before
@@ -100,7 +101,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
100101

101102
@Override
102103
protected Collection<Class<? extends Plugin>> nodePlugins() {
103-
return List.of(LocalStateInferencePlugin.class, FakeMlPlugin.class, FakeSemanticQueryBuilderPlugin.class);
104+
return List.of(LocalStateInferencePlugin.class, FakeMlPlugin.class);
104105
}
105106

106107
/**
@@ -117,9 +118,23 @@ private Settings getIndexSettingsWithVersion(IndexVersion version) {
117118
public void testSemanticText() throws Exception {
118119
for (IndexVersion version : selectedVersions) {
119120
String indexName = "test_semantic_" + randomAlphaOfLength(5).toLowerCase(Locale.ROOT);
120-
createIndex(indexName, getIndexSettingsWithVersion(version));
121+
XContentBuilder mapping = XContentFactory.jsonBuilder()
122+
.startObject()
123+
.startObject("properties")
124+
.startObject(SPARSE_SEMANTIC_FIELD)
125+
.field("type", "semantic_text")
126+
.field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME)
127+
.endObject()
128+
.startObject(DENSE_SEMANTIC_FIELD)
129+
.field("type", "semantic_text")
130+
.field("inference_id", TestDenseInferenceServiceExtension.TestInferenceService.NAME)
131+
.endObject()
132+
.endObject()
133+
.endObject();
134+
135+
assertAcked(prepareCreate(indexName).setSettings(getIndexSettingsWithVersion(version)).setMapping(mapping).get());
121136

122-
// Test index creation
137+
// Test index creation with expected version id
123138
assertTrue("Index " + indexName + " should exist", indexExists(indexName));
124139
assertEquals(
125140
"Index version should match",
@@ -134,46 +149,55 @@ public void testSemanticText() throws Exception {
134149
.id()
135150
);
136151

137-
// Test update mapping
138-
XContentBuilder mapping = XContentFactory.jsonBuilder()
139-
.startObject()
140-
.startObject("properties")
141-
.startObject("semantic_field")
142-
.field("type", "semantic_text")
143-
.field("inference_id", TestSparseInferenceServiceExtension.TestInferenceService.NAME)
144-
.endObject()
145-
.endObject()
146-
.endObject();
147-
148-
assertAcked(client().admin().indices().preparePutMapping(indexName).setSource(mapping).get());
149-
150152
// Test data ingestion
151153
String[] text = new String[] { "inference test", "another inference test" };
152-
DocWriteResponse docWriteResponse = client().prepareIndex(indexName).setSource(Map.of("semantic_field", text)).get();
154+
Map<String, String[]> sourceMap = new HashMap<>();
155+
sourceMap.put(SPARSE_SEMANTIC_FIELD, text);
156+
sourceMap.put(DENSE_SEMANTIC_FIELD, text);
157+
DocWriteResponse docWriteResponse = client().prepareIndex(indexName).setSource(sourceMap).get();
153158

154159
assertEquals("Document should be created", "created", docWriteResponse.getResult().toString().toLowerCase(Locale.ROOT));
155160

156161
// Ensure index is ready
157162
client().admin().indices().refresh(new RefreshRequest(indexName)).get();
158163
ensureGreen(indexName);
159164

160-
// Semantic Search
161-
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(new SemanticQueryBuilder("semantic_field", "inference"))
165+
// Semantic search with sparse embedding
166+
SearchSourceBuilder sparseSourceBuilder = new SearchSourceBuilder().query(new SemanticQueryBuilder(SPARSE_SEMANTIC_FIELD, "inference"))
167+
.trackTotalHits(true);
168+
169+
assertResponse(
170+
client().search(new SearchRequest(indexName).source(sparseSourceBuilder)),
171+
response -> { assertHitCount(response, 1L); }
172+
);
173+
174+
// Highlighting semantic search with sparse embedding
175+
SearchSourceBuilder sparseSourceHighlighterBuilder = new SearchSourceBuilder().query(
176+
new SemanticQueryBuilder(SPARSE_SEMANTIC_FIELD, "inference")
177+
).highlighter(new HighlightBuilder().field(SPARSE_SEMANTIC_FIELD)).trackTotalHits(true);
178+
179+
assertResponse(client().search(new SearchRequest(indexName).source(sparseSourceHighlighterBuilder)), response -> {
180+
assertHighlight(response, 0, SPARSE_SEMANTIC_FIELD, 0, 2, equalTo("inference test"));
181+
assertHighlight(response, 0, SPARSE_SEMANTIC_FIELD, 1, 2, equalTo("another inference test"));
182+
});
183+
184+
// Semantic search with text embedding
185+
SearchSourceBuilder textSourceBuilder = new SearchSourceBuilder().query(new SemanticQueryBuilder(DENSE_SEMANTIC_FIELD, "inference"))
162186
.trackTotalHits(true);
163187

164188
assertResponse(
165-
client().search(new SearchRequest(indexName).source(sourceBuilder)),
189+
client().search(new SearchRequest(indexName).source(textSourceBuilder)),
166190
response -> { assertHitCount(response, 1L); }
167191
);
168192

169-
// Semantic Search with highlighter
170-
SearchSourceBuilder sourceHighlighterBuilder = new SearchSourceBuilder().query(
171-
new SemanticQueryBuilder("semantic_field", "inference")
172-
).highlighter(new HighlightBuilder().field("semantic_field")).trackTotalHits(true);
193+
// Highlighting semantic search with text embedding
194+
SearchSourceBuilder textSourceHighlighterBuilder = new SearchSourceBuilder().query(
195+
new SemanticQueryBuilder(DENSE_SEMANTIC_FIELD, "inference")
196+
).highlighter(new HighlightBuilder().field(DENSE_SEMANTIC_FIELD)).trackTotalHits(true);
173197

174-
assertResponse(client().search(new SearchRequest(indexName).source(sourceHighlighterBuilder)), response -> {
175-
assertHighlight(response, 0, "semantic_field", 0, 2, equalTo("inference test"));
176-
assertHighlight(response, 0, "semantic_field", 1, 2, equalTo("another inference test"));
198+
assertResponse(client().search(new SearchRequest(indexName).source(textSourceHighlighterBuilder)), response -> {
199+
assertHighlight(response, 0, DENSE_SEMANTIC_FIELD, 0, 2, equalTo("inference test"));
200+
assertHighlight(response, 0, DENSE_SEMANTIC_FIELD, 1, 2, equalTo("another inference test"));
177201
});
178202

179203
beforeIndexDeletion();
@@ -187,15 +211,4 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
187211
return new MlInferenceNamedXContentProvider().getNamedWriteables();
188212
}
189213
}
190-
191-
public static class FakeSemanticQueryBuilderPlugin extends Plugin {
192-
@Override
193-
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
194-
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
195-
namedWriteables.add(
196-
new NamedWriteableRegistry.Entry(QueryBuilder.class, SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new)
197-
);
198-
return namedWriteables;
199-
}
200-
}
201214
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@ public Collection<MappedActionFilter> getMappedActionFilters() {
7575
return inferencePlugin.getMappedActionFilters();
7676
}
7777

78-
@Override
79-
public List<QuerySpec<?>> getQueries() {
80-
return inferencePlugin.getQueries();
81-
}
82-
8378
@Override
8479
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
8580
return super.getNamedWriteables();

0 commit comments

Comments
 (0)