Skip to content

Commit c644dbb

Browse files
authored
[ML] Move InferenceInputs up a level (#112726) (#113564)
Refactor before streaming support is added - moving InferenceInputs up a level so that construction happens at the top level rather than in each individual implementation. UnsupportedOperationException will now be thrown as an IllegalStateException later in the call chain; both would go through the listener's onFailure method anyway. Backport of 6c1aaa4
1 parent 37ebafd commit c644dbb

File tree

21 files changed

+121
-406
lines changed

21 files changed

+121
-406
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
import org.elasticsearch.inference.InferenceServiceResults;
1818
import org.elasticsearch.inference.InputType;
1919
import org.elasticsearch.inference.Model;
20+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
2021
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
22+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
23+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
2124
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2225

2326
import java.io.IOException;
@@ -55,9 +58,9 @@ public void infer(
5558
) {
5659
init();
5760
if (query != null) {
58-
doInfer(model, query, input, taskSettings, inputType, timeout, listener);
61+
doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener);
5962
} else {
60-
doInfer(model, input, taskSettings, inputType, timeout, listener);
63+
doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener);
6164
}
6265
}
6366

@@ -86,22 +89,13 @@ public void chunkedInfer(
8689
ActionListener<List<ChunkedInferenceServiceResults>> listener
8790
) {
8891
init();
89-
doChunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener);
92+
// a non-null query is not supported and is dropped by all providers
93+
doChunkedInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, chunkingOptions, timeout, listener);
9094
}
9195

9296
protected abstract void doInfer(
9397
Model model,
94-
List<String> input,
95-
Map<String, Object> taskSettings,
96-
InputType inputType,
97-
TimeValue timeout,
98-
ActionListener<InferenceServiceResults> listener
99-
);
100-
101-
protected abstract void doInfer(
102-
Model model,
103-
String query,
104-
List<String> input,
98+
InferenceInputs inputs,
10599
Map<String, Object> taskSettings,
106100
InputType inputType,
107101
TimeValue timeout,
@@ -110,8 +104,7 @@ protected abstract void doInfer(
110104

111105
protected abstract void doChunkedInfer(
112106
Model model,
113-
@Nullable String query,
114-
List<String> input,
107+
DocumentsOnlyInput inputs,
115108
Map<String, Object> taskSettings,
116109
InputType inputType,
117110
ChunkingOptions chunkingOptions,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
2929
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
3030
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
31-
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
31+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3232
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
3333
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3434
import org.elasticsearch.xpack.inference.services.SenderService;
@@ -204,8 +204,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
204204
@Override
205205
public void doInfer(
206206
Model model,
207-
String query,
208-
List<String> input,
207+
InferenceInputs inputs,
209208
Map<String, Object> taskSettings,
210209
InputType inputType,
211210
TimeValue timeout,
@@ -220,35 +219,13 @@ public void doInfer(
220219
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
221220

222221
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
223-
action.execute(new QueryAndDocsInputs(query, input), timeout, listener);
224-
}
225-
226-
@Override
227-
public void doInfer(
228-
Model model,
229-
List<String> input,
230-
Map<String, Object> taskSettings,
231-
InputType inputType,
232-
TimeValue timeout,
233-
ActionListener<InferenceServiceResults> listener
234-
) {
235-
if (model instanceof AlibabaCloudSearchModel == false) {
236-
listener.onFailure(createInvalidModelException(model));
237-
return;
238-
}
239-
240-
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
241-
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
242-
243-
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
244-
action.execute(new DocumentsOnlyInput(input), timeout, listener);
222+
action.execute(inputs, timeout, listener);
245223
}
246224

247225
@Override
248226
protected void doChunkedInfer(
249227
Model model,
250-
@Nullable String query,
251-
List<String> input,
228+
DocumentsOnlyInput inputs,
252229
Map<String, Object> taskSettings,
253230
InputType inputType,
254231
ChunkingOptions chunkingOptions,
@@ -263,8 +240,11 @@ protected void doChunkedInfer(
263240
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
264241
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
265242

266-
var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
267-
.batchRequestsWithListeners(listener);
243+
var batchedRequests = new EmbeddingRequestChunker(
244+
inputs.getInputs(),
245+
EMBEDDING_MAX_BATCH_SIZE,
246+
EmbeddingRequestChunker.EmbeddingType.FLOAT
247+
).batchRequestsWithListeners(listener);
268248
for (var request : batchedRequests) {
269249
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
270250
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
2929
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
3030
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
31+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3132
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
3233
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3334
import org.elasticsearch.xpack.inference.services.SenderService;
@@ -71,7 +72,7 @@ public AmazonBedrockService(
7172
@Override
7273
protected void doInfer(
7374
Model model,
74-
List<String> input,
75+
InferenceInputs inputs,
7576
Map<String, Object> taskSettings,
7677
InputType inputType,
7778
TimeValue timeout,
@@ -80,30 +81,16 @@ protected void doInfer(
8081
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
8182
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
8283
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
83-
action.execute(new DocumentsOnlyInput(input), timeout, listener);
84+
action.execute(inputs, timeout, listener);
8485
} else {
8586
listener.onFailure(createInvalidModelException(model));
8687
}
8788
}
8889

89-
@Override
90-
protected void doInfer(
91-
Model model,
92-
String query,
93-
List<String> input,
94-
Map<String, Object> taskSettings,
95-
InputType inputType,
96-
TimeValue timeout,
97-
ActionListener<InferenceServiceResults> listener
98-
) {
99-
throw new UnsupportedOperationException("Amazon Bedrock service does not support inference with query input");
100-
}
101-
10290
@Override
10391
protected void doChunkedInfer(
10492
Model model,
105-
String query,
106-
List<String> input,
93+
DocumentsOnlyInput inputs,
10794
Map<String, Object> taskSettings,
10895
InputType inputType,
10996
ChunkingOptions chunkingOptions,
@@ -113,7 +100,7 @@ protected void doChunkedInfer(
113100
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
114101
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
115102
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
116-
var batchedRequests = new EmbeddingRequestChunker(input, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
103+
var batchedRequests = new EmbeddingRequestChunker(inputs.getInputs(), maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
117104
.batchRequestsWithListeners(listener);
118105
for (var request : batchedRequests) {
119106
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.xpack.inference.external.action.anthropic.AnthropicActionCreator;
2626
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
2727
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
28+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2829
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2930
import org.elasticsearch.xpack.inference.services.SenderService;
3031
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -165,7 +166,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
165166
@Override
166167
public void doInfer(
167168
Model model,
168-
List<String> input,
169+
InferenceInputs inputs,
169170
Map<String, Object> taskSettings,
170171
InputType inputType,
171172
TimeValue timeout,
@@ -180,27 +181,13 @@ public void doInfer(
180181
var actionCreator = new AnthropicActionCreator(getSender(), getServiceComponents());
181182

182183
var action = anthropicModel.accept(actionCreator, taskSettings);
183-
action.execute(new DocumentsOnlyInput(input), timeout, listener);
184-
}
185-
186-
@Override
187-
protected void doInfer(
188-
Model model,
189-
String query,
190-
List<String> input,
191-
Map<String, Object> taskSettings,
192-
InputType inputType,
193-
TimeValue timeout,
194-
ActionListener<InferenceServiceResults> listener
195-
) {
196-
throw new UnsupportedOperationException("Anthropic service does not support inference with query input");
184+
action.execute(inputs, timeout, listener);
197185
}
198186

199187
@Override
200188
protected void doChunkedInfer(
201189
Model model,
202-
@Nullable String query,
203-
List<String> input,
190+
DocumentsOnlyInput inputs,
204191
Map<String, Object> taskSettings,
205192
InputType inputType,
206193
ChunkingOptions chunkingOptions,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
2929
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
3030
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
31+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3132
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3233
import org.elasticsearch.xpack.inference.services.SenderService;
3334
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -62,7 +63,7 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents
6263
@Override
6364
protected void doInfer(
6465
Model model,
65-
List<String> input,
66+
InferenceInputs inputs,
6667
Map<String, Object> taskSettings,
6768
InputType inputType,
6869
TimeValue timeout,
@@ -72,30 +73,16 @@ protected void doInfer(
7273

7374
if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
7475
var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
75-
action.execute(new DocumentsOnlyInput(input), timeout, listener);
76+
action.execute(inputs, timeout, listener);
7677
} else {
7778
listener.onFailure(createInvalidModelException(model));
7879
}
7980
}
8081

81-
@Override
82-
protected void doInfer(
83-
Model model,
84-
String query,
85-
List<String> input,
86-
Map<String, Object> taskSettings,
87-
InputType inputType,
88-
TimeValue timeout,
89-
ActionListener<InferenceServiceResults> listener
90-
) {
91-
throw new UnsupportedOperationException("Azure AI Studio service does not support inference with query input");
92-
}
93-
9482
@Override
9583
protected void doChunkedInfer(
9684
Model model,
97-
String query,
98-
List<String> input,
85+
DocumentsOnlyInput inputs,
9986
Map<String, Object> taskSettings,
10087
InputType inputType,
10188
ChunkingOptions chunkingOptions,
@@ -104,8 +91,11 @@ protected void doChunkedInfer(
10491
) {
10592
if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
10693
var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
107-
var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
108-
.batchRequestsWithListeners(listener);
94+
var batchedRequests = new EmbeddingRequestChunker(
95+
inputs.getInputs(),
96+
EMBEDDING_MAX_BATCH_SIZE,
97+
EmbeddingRequestChunker.EmbeddingType.FLOAT
98+
).batchRequestsWithListeners(listener);
10999
for (var request : batchedRequests) {
110100
var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
111101
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
2929
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
3030
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
31+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3132
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3233
import org.elasticsearch.xpack.inference.services.SenderService;
3334
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -185,7 +186,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
185186
@Override
186187
protected void doInfer(
187188
Model model,
188-
List<String> input,
189+
InferenceInputs inputs,
189190
Map<String, Object> taskSettings,
190191
InputType inputType,
191192
TimeValue timeout,
@@ -200,27 +201,13 @@ protected void doInfer(
200201
var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
201202

202203
var action = azureOpenAiModel.accept(actionCreator, taskSettings);
203-
action.execute(new DocumentsOnlyInput(input), timeout, listener);
204-
}
205-
206-
@Override
207-
protected void doInfer(
208-
Model model,
209-
String query,
210-
List<String> input,
211-
Map<String, Object> taskSettings,
212-
InputType inputType,
213-
TimeValue timeout,
214-
ActionListener<InferenceServiceResults> listener
215-
) {
216-
throw new UnsupportedOperationException("Azure OpenAI service does not support inference with query input");
204+
action.execute(inputs, timeout, listener);
217205
}
218206

219207
@Override
220208
protected void doChunkedInfer(
221209
Model model,
222-
String query,
223-
List<String> input,
210+
DocumentsOnlyInput inputs,
224211
Map<String, Object> taskSettings,
225212
InputType inputType,
226213
ChunkingOptions chunkingOptions,
@@ -233,8 +220,11 @@ protected void doChunkedInfer(
233220
}
234221
AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
235222
var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
236-
var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
237-
.batchRequestsWithListeners(listener);
223+
var batchedRequests = new EmbeddingRequestChunker(
224+
inputs.getInputs(),
225+
EMBEDDING_MAX_BATCH_SIZE,
226+
EmbeddingRequestChunker.EmbeddingType.FLOAT
227+
).batchRequestsWithListeners(listener);
238228
for (var request : batchedRequests) {
239229
var action = azureOpenAiModel.accept(actionCreator, taskSettings);
240230
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());

0 commit comments

Comments
 (0)