Skip to content

Commit 0821047

Browse files
Adding test for search functionality
1 parent af86ba5 commit 0821047

File tree

2 files changed

+59
-54
lines changed

2 files changed

+59
-54
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.integration;
99

10-
import org.elasticsearch.action.DocWriteResponse;
10+
import org.elasticsearch.action.DocWriteResponse;;
1111
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1212
import org.elasticsearch.action.search.SearchRequest;
13-
import org.elasticsearch.action.search.SearchResponse;
1413
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.core.TimeValue;
1615
import org.elasticsearch.index.IndexVersion;
17-
import org.elasticsearch.index.IndexVersions;
18-
import org.elasticsearch.index.query.QueryBuilders;
1916
import org.elasticsearch.license.LicenseSettings;
2017
import org.elasticsearch.plugins.Plugin;
2118
import org.elasticsearch.search.builder.SearchSourceBuilder;
19+
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
2220
import org.elasticsearch.test.ESIntegTestCase;
2321
import org.elasticsearch.test.index.IndexVersionUtils;
2422
import org.elasticsearch.xcontent.XContentBuilder;
@@ -38,21 +36,22 @@
3836
import java.util.stream.Collectors;
3937

4038
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
39+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHighlight;
40+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
41+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
4142
import static org.hamcrest.Matchers.equalTo;
4243

4344
public class SemanticTextIndexVersionIT extends ESIntegTestCase {
44-
private static final IndexVersion SEMANTIC_TEXT_INTRODUCED_VERSION = IndexVersions.SEMANTIC_TEXT_FIELD_TYPE;
45-
private static final IndexVersion SEMANTIC_TEXT_NEW_FORMAT = IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT;
45+
private static final IndexVersion SEMANTIC_TEXT_INTRODUCED_VERSION = IndexVersion.fromId(8512000);
4646

4747
private Set<IndexVersion> availableVersions;
48-
private static final int MIN_NUMBER_OF_TESTS_TO_RUN = 10;
48+
private static final int MIN_NUMBER_OF_TESTS_TO_RUN = 1;
4949

5050
@Before
5151
public void setup() throws Exception {
5252
Utils.storeSparseModel(client());
53-
availableVersions = IndexVersionUtils.allReleasedVersions()
54-
.stream()
55-
.filter((version -> version.onOrAfter(SEMANTIC_TEXT_INTRODUCED_VERSION)))
53+
availableVersions = IndexVersionUtils.allReleasedVersions().stream()
54+
.filter(indexVersion -> indexVersion.after(SEMANTIC_TEXT_INTRODUCED_VERSION))
5655
.collect(Collectors.toSet());
5756

5857
logger.info("Available versions for testing: {}", availableVersions);
@@ -100,11 +99,11 @@ protected Map<String, IndexVersion> createRandomVersionIndices() throws IOExcept
10099
return result;
101100
}
102101

103-
public void test() throws Exception {
102+
public void testSemanticText() throws Exception {
104103
Map<String, IndexVersion> indices = createRandomVersionIndices();
105104
for (String indexName : indices.keySet()) {
106105
IndexVersion version = indices.get(indexName);
107-
logger.info("Testing index [{}] with version [{}]", indexName, version);
106+
logger.info("Testing index [{}] with version [{}] [{}]", indexName, version, version.toReleaseVersion());
108107

109108
// Test index creation
110109
assertTrue("Index " + indexName + " should exist", indexExists(indexName));
@@ -136,51 +135,35 @@ public void test() throws Exception {
136135

137136
// Test data ingestion
138137
String[] text = new String[] { "inference test", "another inference test" };
138+
DocWriteResponse docWriteResponse = client().prepareIndex(indexName)
139+
.setSource(Map.of("semantic_field", text))
140+
.get();
139141

140-
DocWriteResponse response = client().prepareIndex(indexName).setSource(Map.of("semantic_field", text)).get();
141-
142-
assertEquals("Document should be created", "created", response.getResult().toString().toLowerCase());
142+
assertEquals("Document should be created", "created", docWriteResponse.getResult().toString().toLowerCase());
143143

144+
// Ensure index is ready
144145
client().admin().indices().refresh(new RefreshRequest(indexName)).get();
145-
146-
// Simple search
147-
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().trackTotalHits(true);
148-
SearchResponse searchResponse = client().search(new SearchRequest(indexName).source(sourceBuilder)).get();
149-
try {
150-
assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(1L));
151-
} finally {
152-
searchResponse.decRef();
153-
}
154-
155-
// Search with query
156-
SearchResponse searchWithQueryResponse = null;
157-
if (version.after(SEMANTIC_TEXT_NEW_FORMAT)) {
158-
searchWithQueryResponse = client().search(
159-
new SearchRequest(indexName).source(
160-
sourceBuilder.query(QueryBuilders.matchQuery("semantic_field", "another inference test"))
161-
)
162-
).get();
163-
} else {
164-
String semanticQuery = """
165-
{
166-
"semantic": {
167-
"field": "semantic_field",
168-
"query": "inference"
169-
}
170-
}
171-
""";
172-
searchWithQueryResponse = client().search(
173-
new SearchRequest(indexName).source(sourceBuilder.query(new SemanticQueryBuilder("semantic_field", "inference test")))
174-
).get();
175-
}
176-
177-
try {
178-
assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(1L));
179-
} finally {
180-
searchResponse.decRef();
181-
}
182-
146+
ensureGreen(indexName);
147+
148+
// Semantic Search
149+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
150+
.query(new SemanticQueryBuilder("semantic_field", "inference"))
151+
.trackTotalHits(true);
152+
153+
assertResponse(client().search(new SearchRequest(indexName).source(sourceBuilder)), response -> {
154+
assertHitCount(response, 1L);
155+
});
156+
157+
//Semantic Search with highlighter
158+
SearchSourceBuilder sourceHighlighterBuilder = new SearchSourceBuilder()
159+
.query(new SemanticQueryBuilder("semantic_field", "inference"))
160+
.highlighter(new HighlightBuilder().field(new HighlightBuilder.Field("semantic_field").numOfFragments(1)))
161+
.trackTotalHits(true);
162+
163+
assertResponse(client().search(new SearchRequest(indexName).source(sourceBuilder)), response -> {
164+
assertHitCount(response, 1L);
165+
assertHighlight(response, 0, "semantic_field", 0, 2, equalTo("inference"));
166+
});
183167
}
184168
}
185-
186169
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.action.support.MappedActionFilter;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.index.mapper.Mapper;
14+
import org.elasticsearch.index.query.QueryBuilder;
1315
import org.elasticsearch.inference.InferenceServiceExtension;
1416
import org.elasticsearch.license.XPackLicenseState;
1517
import org.elasticsearch.plugins.SearchPlugin;
1618
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
19+
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
20+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
1721
import org.elasticsearch.xpack.core.ssl.SSLService;
1822
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
1923
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
2024

2125
import java.nio.file.Path;
26+
import java.util.ArrayList;
2227
import java.util.Collection;
2328
import java.util.List;
2429
import java.util.Map;
@@ -68,4 +73,21 @@ public Collection<MappedActionFilter> getMappedActionFilters() {
6873
return inferencePlugin.getMappedActionFilters();
6974
}
7075

76+
@Override
77+
public List<QuerySpec<?>> getQueries() {
78+
return inferencePlugin.getQueries();
79+
}
80+
81+
@Override
82+
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
83+
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(super.getNamedWriteables());
84+
namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
85+
namedWriteables.add(new NamedWriteableRegistry.Entry(
86+
QueryBuilder.class,
87+
SparseVectorQueryBuilder.NAME,
88+
SparseVectorQueryBuilder::new
89+
));
90+
91+
return namedWriteables;
92+
}
7193
}

0 commit comments

Comments
 (0)