Skip to content

Commit 6ad05ed

Browse files
authored
Use Suppliers To Get Inference Results In Semantic Queries (#136720) (#136868)
(cherry picked from commit e531d64) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
1 parent 421246d commit 6ad05ed

File tree

14 files changed

+524
-291
lines changed

14 files changed

+524
-291
lines changed

docs/changelog/136720.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 136720
2+
summary: Use Suppliers To Get Inference Results In Semantic Queries
3+
area: Vector Search
4+
type: bug
5+
issues:
6+
- 136621

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
2121
import org.elasticsearch.client.internal.Client;
2222
import org.elasticsearch.common.bytes.BytesReference;
23-
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
2423
import org.elasticsearch.common.settings.Settings;
2524
import org.elasticsearch.core.Nullable;
2625
import org.elasticsearch.core.TimeValue;
@@ -31,27 +30,16 @@
3130
import org.elasticsearch.inference.SimilarityMeasure;
3231
import org.elasticsearch.inference.TaskType;
3332
import org.elasticsearch.license.LicenseSettings;
34-
import org.elasticsearch.plugins.ActionPlugin;
3533
import org.elasticsearch.plugins.Plugin;
36-
import org.elasticsearch.plugins.SearchPlugin;
3734
import org.elasticsearch.rest.RestStatus;
3835
import org.elasticsearch.search.SearchHit;
3936
import org.elasticsearch.search.builder.SearchSourceBuilder;
4037
import org.elasticsearch.test.AbstractMultiClustersTestCase;
4138
import org.elasticsearch.transport.RemoteConnectionInfo;
42-
import org.elasticsearch.xcontent.XContentBuilder;
43-
import org.elasticsearch.xcontent.XContentFactory;
44-
import org.elasticsearch.xcontent.XContentType;
45-
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
46-
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
47-
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
48-
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
39+
import org.elasticsearch.xpack.inference.FakeMlPlugin;
4940
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
5041
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
51-
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
5242
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
53-
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
54-
import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction;
5543

5644
import java.io.IOException;
5745
import java.util.Collection;
@@ -66,6 +54,7 @@
6654

6755
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
6856
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
57+
import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint;
6958
import static org.hamcrest.Matchers.equalTo;
7059
import static org.hamcrest.Matchers.is;
7160

@@ -165,35 +154,6 @@ protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive)
165154
return response.getPointInTimeId();
166155
}
167156

168-
protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map<String, Object> serviceSettings)
169-
throws IOException {
170-
final String service = switch (taskType) {
171-
case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
172-
case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
173-
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
174-
};
175-
176-
final BytesReference content;
177-
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
178-
builder.startObject();
179-
builder.field("service", service);
180-
builder.field("service_settings", serviceSettings);
181-
builder.endObject();
182-
183-
content = BytesReference.bytes(builder);
184-
}
185-
186-
PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
187-
taskType,
188-
inferenceId,
189-
content,
190-
XContentType.JSON,
191-
TEST_REQUEST_TIMEOUT
192-
);
193-
var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request);
194-
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
195-
}
196-
197157
protected void assertSearchResponse(QueryBuilder queryBuilder, List<IndexWithBoost> indices, List<SearchResult> expectedSearchResults)
198158
throws Exception {
199159
assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null);
@@ -307,29 +267,6 @@ protected static String[] convertToArray(List<IndexWithBoost> indices) {
307267
return indices.stream().map(IndexWithBoost::index).toArray(String[]::new);
308268
}
309269

310-
public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin {
311-
@Override
312-
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
313-
return new MlInferenceNamedXContentProvider().getNamedWriteables();
314-
}
315-
316-
@Override
317-
public List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
318-
return List.of(
319-
new QueryVectorBuilderSpec<>(
320-
TextEmbeddingQueryVectorBuilder.NAME,
321-
TextEmbeddingQueryVectorBuilder::new,
322-
TextEmbeddingQueryVectorBuilder.PARSER
323-
)
324-
);
325-
}
326-
327-
@Override
328-
public Collection<ActionHandler> getActions() {
329-
return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class));
330-
}
331-
}
332-
333270
protected record TestIndexInfo(
334271
String name,
335272
Map<String, MinimalServiceSettings> inferenceEndpoints,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.action.support.IndicesOptions;
11+
import org.elasticsearch.client.internal.Client;
12+
import org.elasticsearch.common.bytes.BytesReference;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
18+
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
19+
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
20+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
21+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
22+
23+
import java.io.IOException;
24+
import java.util.Map;
25+
26+
import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT;
27+
import static org.elasticsearch.test.ESTestCase.safeGet;
28+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
29+
import static org.hamcrest.MatcherAssert.assertThat;
30+
import static org.hamcrest.Matchers.equalTo;
31+
32+
public class IntegrationTestUtils {
33+
private IntegrationTestUtils() {}
34+
35+
public static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map<String, Object> serviceSettings)
36+
throws IOException {
37+
final String service = switch (taskType) {
38+
case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
39+
case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
40+
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
41+
};
42+
43+
final BytesReference content;
44+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
45+
builder.startObject();
46+
builder.field("service", service);
47+
builder.field("service_settings", serviceSettings);
48+
builder.endObject();
49+
50+
content = BytesReference.bytes(builder);
51+
}
52+
53+
PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
54+
taskType,
55+
inferenceId,
56+
content,
57+
XContentType.JSON,
58+
TEST_REQUEST_TIMEOUT
59+
);
60+
var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request);
61+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
62+
}
63+
64+
public static void deleteInferenceEndpoint(Client client, TaskType taskType, String inferenceId) {
65+
assertAcked(
66+
safeGet(
67+
client.execute(
68+
DeleteInferenceEndpointAction.INSTANCE,
69+
new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false)
70+
)
71+
)
72+
);
73+
}
74+
75+
public static void deleteIndex(Client client, String indexName) {
76+
assertAcked(
77+
safeGet(
78+
client.admin()
79+
.indices()
80+
.prepareDelete(indexName)
81+
.setIndicesOptions(
82+
IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build()
83+
)
84+
.execute()
85+
)
86+
);
87+
}
88+
89+
public static XContentBuilder generateSemanticTextMapping(Map<String, String> semanticTextFields) throws IOException {
90+
XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");
91+
for (var entry : semanticTextFields.entrySet()) {
92+
mapping.startObject(entry.getKey());
93+
mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
94+
mapping.field("inference_id", entry.getValue());
95+
mapping.endObject();
96+
}
97+
mapping.endObject().endObject();
98+
99+
return mapping;
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.action.DocWriteResponse;
11+
import org.elasticsearch.action.search.SearchRequest;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.index.query.BoolQueryBuilder;
14+
import org.elasticsearch.index.query.MatchQueryBuilder;
15+
import org.elasticsearch.index.query.QueryBuilder;
16+
import org.elasticsearch.index.query.QueryBuilders;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.license.LicenseSettings;
19+
import org.elasticsearch.plugins.Plugin;
20+
import org.elasticsearch.reindex.ReindexPlugin;
21+
import org.elasticsearch.search.builder.SearchSourceBuilder;
22+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
23+
import org.elasticsearch.test.ESIntegTestCase;
24+
import org.elasticsearch.xcontent.XContentBuilder;
25+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
26+
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
27+
import org.elasticsearch.xpack.inference.FakeMlPlugin;
28+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
29+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
30+
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
31+
import org.junit.After;
32+
33+
import java.io.IOException;
34+
import java.util.Collection;
35+
import java.util.HashMap;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.function.BiFunction;
39+
40+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
41+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
42+
import static org.hamcrest.CoreMatchers.equalTo;
43+
import static org.hamcrest.CoreMatchers.is;
44+
45+
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 1)
46+
public class ManyInferenceQueryClausesIT extends ESIntegTestCase {
47+
private static final String INDEX_NAME = "test_index";
48+
49+
private static final Map<String, Object> SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key");
50+
private static final Map<String, Object> TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of(
51+
"model",
52+
"my_model",
53+
"dimensions",
54+
256,
55+
"similarity",
56+
"cosine",
57+
"api_key",
58+
"my_api_key"
59+
);
60+
61+
private final Map<String, TaskType> inferenceIds = new HashMap<>();
62+
63+
@Override
64+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
65+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
66+
}
67+
68+
@Override
69+
protected Collection<Class<? extends Plugin>> nodePlugins() {
70+
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, FakeMlPlugin.class);
71+
}
72+
73+
@After
74+
public void cleanUp() {
75+
IntegrationTestUtils.deleteIndex(client(), INDEX_NAME);
76+
for (var entry : inferenceIds.entrySet()) {
77+
IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey());
78+
}
79+
}
80+
81+
public void testManySemanticQueryClauses() throws Exception {
82+
manyQueryClausesTestCase(randomIntBetween(18, 24), SemanticQueryBuilder::new, TaskType.SPARSE_EMBEDDING);
83+
}
84+
85+
public void testManyMatchQueryClauses() throws Exception {
86+
manyQueryClausesTestCase(randomIntBetween(18, 24), MatchQueryBuilder::new, TaskType.SPARSE_EMBEDDING);
87+
}
88+
89+
public void testManySparseVectorQueryClauses() throws Exception {
90+
manyQueryClausesTestCase(randomIntBetween(18, 24), (f, q) -> new SparseVectorQueryBuilder(f, null, q), TaskType.SPARSE_EMBEDDING);
91+
}
92+
93+
public void testManyKnnQueryClauses() throws Exception {
94+
int clauseCount = randomIntBetween(18, 24);
95+
manyQueryClausesTestCase(
96+
clauseCount,
97+
(f, q) -> new KnnVectorQueryBuilder(f, new TextEmbeddingQueryVectorBuilder(null, q), clauseCount, clauseCount * 10, null, null),
98+
TaskType.TEXT_EMBEDDING
99+
);
100+
}
101+
102+
private void manyQueryClausesTestCase(
103+
int clauseCount,
104+
BiFunction<String, String, QueryBuilder> clauseGenerator,
105+
TaskType semanticTextFieldTaskType
106+
) throws Exception {
107+
Map<String, Object> inferenceEndpointServiceSettings = getServiceSettings(semanticTextFieldTaskType);
108+
Map<String, String> semanticTextFields = new HashMap<>(clauseCount);
109+
for (int i = 0; i < clauseCount; i++) {
110+
String fieldName = randomAlphaOfLength(10);
111+
String inferenceId = randomIdentifier();
112+
113+
createInferenceEndpoint(semanticTextFieldTaskType, inferenceId, inferenceEndpointServiceSettings);
114+
semanticTextFields.put(fieldName, inferenceId);
115+
}
116+
117+
XContentBuilder mapping = IntegrationTestUtils.generateSemanticTextMapping(semanticTextFields);
118+
assertAcked(prepareCreate(INDEX_NAME).setMapping(mapping));
119+
120+
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
121+
for (String semanticTextField : semanticTextFields.keySet()) {
122+
Map<String, Object> source = Map.of(semanticTextField, randomAlphaOfLength(10));
123+
DocWriteResponse docWriteResponse = client().prepareIndex(INDEX_NAME).setSource(source).get(TEST_REQUEST_TIMEOUT);
124+
assertThat(docWriteResponse.getResult(), is(DocWriteResponse.Result.CREATED));
125+
126+
boolQuery.should(clauseGenerator.apply(semanticTextField, randomAlphaOfLength(10)));
127+
}
128+
client().admin().indices().prepareRefresh(INDEX_NAME).get();
129+
130+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(clauseCount);
131+
SearchRequest searchRequest = new SearchRequest(new String[] { INDEX_NAME }, searchSourceBuilder);
132+
assertResponse(client().search(searchRequest), response -> {
133+
assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards()));
134+
assertThat(response.getHits().getTotalHits().value(), equalTo((long) clauseCount));
135+
});
136+
}
137+
138+
private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
139+
IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings);
140+
inferenceIds.put(inferenceId, taskType);
141+
}
142+
143+
private static Map<String, Object> getServiceSettings(TaskType taskType) {
144+
return switch (taskType) {
145+
case SPARSE_EMBEDDING -> SPARSE_EMBEDDING_SERVICE_SETTINGS;
146+
case TEXT_EMBEDDING -> TEXT_EMBEDDING_SERVICE_SETTINGS;
147+
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
148+
};
149+
}
150+
}

0 commit comments

Comments
 (0)