Skip to content

Commit de1f8fc

Browse files
committed
Everything compiles
1 parent e6e877e commit de1f8fc

File tree

66 files changed

+413
-787
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+413
-787
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.inference.InputType;
1415
import org.elasticsearch.threadpool.ThreadPool;
1516
import org.elasticsearch.xpack.inference.common.Truncator;
1617
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -46,9 +47,12 @@ public void execute(
4647
Supplier<Boolean> hasRequestCompletedFunction,
4748
ActionListener<InferenceServiceResults> listener
4849
) {
49-
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getInputs();
50+
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
51+
List<String> docsInput = input.getInputs();
52+
InputType inputType = input.getInputType();
53+
5054
var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
51-
AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
55+
AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, inputType, model);
5256
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5357
}
5458

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequest.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.http.client.methods.HttpPost;
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.inference.InputType;
1415
import org.elasticsearch.xcontent.XContentType;
1516
import org.elasticsearch.xpack.inference.common.Truncator;
1617
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -23,13 +24,20 @@ public class AzureAiStudioEmbeddingsRequest extends AzureAiStudioRequest {
2324

2425
private final AzureAiStudioEmbeddingsModel embeddingsModel;
2526
private final Truncator.TruncationResult truncationResult;
27+
private final InputType inputType;
2628
private final Truncator truncator;
2729

28-
public AzureAiStudioEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, AzureAiStudioEmbeddingsModel model) {
30+
public AzureAiStudioEmbeddingsRequest(
31+
Truncator truncator,
32+
Truncator.TruncationResult input,
33+
InputType inputType,
34+
AzureAiStudioEmbeddingsModel model
35+
) {
2936
super(model);
3037
this.embeddingsModel = model;
3138
this.truncator = truncator;
3239
this.truncationResult = input;
40+
this.inputType = inputType;
3341
}
3442

3543
@Override
@@ -41,8 +49,9 @@ public HttpRequest createHttpRequest() {
4149
var dimensionsSetByUser = embeddingsModel.getServiceSettings().dimensionsSetByUser();
4250

4351
ByteArrayEntity byteEntity = new ByteArrayEntity(
44-
Strings.toString(new AzureAiStudioEmbeddingsRequestEntity(truncationResult.input(), user, dimensions, dimensionsSetByUser))
45-
.getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(
53+
new AzureAiStudioEmbeddingsRequestEntity(truncationResult.input(), inputType, user, dimensions, dimensionsSetByUser)
54+
).getBytes(StandardCharsets.UTF_8)
4655
);
4756
httpPost.setEntity(byteEntity);
4857

@@ -55,7 +64,7 @@ public HttpRequest createHttpRequest() {
5564
@Override
5665
public Request truncate() {
5766
var truncatedInput = truncator.truncate(truncationResult.input());
58-
return new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, embeddingsModel);
67+
return new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, inputType, embeddingsModel);
5968
}
6069

6170
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestEntity.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.external.request.azureaistudio;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.InputType;
1112
import org.elasticsearch.xcontent.ToXContentObject;
1213
import org.elasticsearch.xcontent.XContentBuilder;
1314

@@ -21,6 +22,7 @@
2122

2223
public record AzureAiStudioEmbeddingsRequestEntity(
2324
List<String> input,
25+
InputType inputType,
2426
@Nullable String user,
2527
@Nullable Integer dimensions,
2628
boolean dimensionsSetByUser

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.services.openai.embeddings;
99

1010
import org.apache.http.client.utils.URIBuilder;
11-
import org.elasticsearch.common.ValidationException;
1211
import org.elasticsearch.core.Nullable;
1312
import org.elasticsearch.inference.ChunkingSettings;
1413
import org.elasticsearch.inference.ModelConfigurations;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ public void testEmbeddingsRequestAction() throws IOException {
7070
"accesskey",
7171
"secretkey"
7272
);
73-
var action = creator.create(model, Map.of(), null);
73+
var action = creator.create(model, Map.of());
7474
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
75-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
75+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
7676
var result = listener.actionGet(TIMEOUT);
7777

7878
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
@@ -98,9 +98,9 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException {
9898
"accesskey",
9999
"secretkey"
100100
);
101-
var action = creator.create(model, Map.of(), null);
101+
var action = creator.create(model, Map.of());
102102
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
103-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
103+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
104104
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
105105

106106
assertThat(sender.sendCount(), is(1));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ public void testEmbeddingsRequestAction() throws IOException {
109109
model.setURI(getUrl(webServer));
110110

111111
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
112-
var action = creator.create(model, Map.of(), null);
112+
var action = creator.create(model, Map.of());
113113
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
114-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
114+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
115115

116116
var result = listener.actionGet(TIMEOUT);
117117

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,14 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
111111
""";
112112
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
113113

114-
var model = createModel("resource", "deployment", "apiversion", "orig_user", "apikey", null, "id", null);
114+
var model = createModel("resource", "deployment", "apiversion", "orig_user", "apikey", null, "id");
115115
model.setUri(new URI(getUrl(webServer)));
116116
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
117117
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
118-
var action = actionCreator.create(model, overriddenTaskSettings, null);
118+
var action = actionCreator.create(model, overriddenTaskSettings);
119119

120120
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
121-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
121+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
122122

123123
var result = listener.actionGet(TIMEOUT);
124124

@@ -161,15 +161,15 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti
161161
""";
162162
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
163163

164-
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id", null);
164+
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
165165
model.setUri(new URI(getUrl(webServer)));
166166
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
167167
var overriddenTaskSettings = createRequestTaskSettingsMap(null);
168168
var inputType = InputTypeTests.randomWithoutUnspecified();
169-
var action = actionCreator.create(model, overriddenTaskSettings, inputType);
169+
var action = actionCreator.create(model, overriddenTaskSettings);
170170

171171
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
172-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
172+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
173173

174174
var result = listener.actionGet(TIMEOUT);
175175

@@ -213,15 +213,15 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat
213213
""";
214214
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
215215

216-
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id", null);
216+
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
217217
model.setUri(new URI(getUrl(webServer)));
218218
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
219219
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
220220
var inputType = InputTypeTests.randomWithoutUnspecified();
221-
var action = actionCreator.create(model, overriddenTaskSettings, inputType);
221+
var action = actionCreator.create(model, overriddenTaskSettings);
222222

223223
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
224-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
224+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
225225

226226
var failureCauseMessage = "Failed to find required field [data] in OpenAI embeddings response";
227227
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
@@ -287,15 +287,15 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC
287287
webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge));
288288
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
289289

290-
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id", null);
290+
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
291291
model.setUri(new URI(getUrl(webServer)));
292292
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
293293
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
294294
var inputType = InputTypeTests.randomWithoutUnspecified();
295-
var action = actionCreator.create(model, overriddenTaskSettings, inputType);
295+
var action = actionCreator.create(model, overriddenTaskSettings);
296296

297297
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
298-
action.execute(new EmbeddingsInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
298+
action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
299299

300300
var result = listener.actionGet(TIMEOUT);
301301

@@ -364,14 +364,14 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC
364364
webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge));
365365
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
366366

367-
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id", null);
367+
var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
368368
model.setUri(new URI(getUrl(webServer)));
369369
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
370370
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
371-
var action = actionCreator.create(model, overriddenTaskSettings, null);
371+
var action = actionCreator.create(model, overriddenTaskSettings);
372372

373373
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
374-
action.execute(new EmbeddingsInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
374+
action.execute(new EmbeddingsInput(List.of("abcd"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
375375

376376
var result = listener.actionGet(TIMEOUT);
377377

@@ -423,15 +423,15 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
423423
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
424424

425425
// truncated to 1 token = 3 characters
426-
var model = createModel("resource", "deployment", "apiversion", null, false, 1, null, null, "apikey", null, "id", null);
426+
var model = createModel("resource", "deployment", "apiversion", null, false, 1, null, null, "apikey", null, "id");
427427
model.setUri(new URI(getUrl(webServer)));
428428
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
429429
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
430430
var inputType = InputTypeTests.randomWithoutUnspecified();
431-
var action = actionCreator.create(model, overriddenTaskSettings, inputType);
431+
var action = actionCreator.create(model, overriddenTaskSettings);
432432

433433
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
434-
action.execute(new EmbeddingsInput(List.of("super long input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
434+
action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
435435

436436
var result = listener.actionGet(TIMEOUT);
437437

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
113113
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
114114

115115
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
116-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
116+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
117117

118118
var result = listener.actionGet(TIMEOUT);
119119

@@ -137,7 +137,7 @@ public void testExecute_ThrowsElasticsearchException() {
137137
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
138138

139139
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
140-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
140+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
141141

142142
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
143143

@@ -157,7 +157,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
157157
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
158158

159159
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
160-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
160+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
161161

162162
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
163163

@@ -177,7 +177,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
177177
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
178178

179179
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
180-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
180+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
181181

182182
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
183183

@@ -191,7 +191,7 @@ public void testExecute_ThrowsException() {
191191
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
192192

193193
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
194-
action.execute(new EmbeddingsInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
194+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
195195

196196
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
197197

@@ -209,7 +209,7 @@ private ExecutableAction createAction(
209209
) {
210210
AzureOpenAiEmbeddingsModel model = null;
211211
try {
212-
model = createModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId, null);
212+
model = createModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId);
213213
model.setUri(new URI(getUrl(webServer)));
214214
var requestCreator = new AzureOpenAiEmbeddingsRequestManager(model, TruncatorTests.createTruncator(), threadPool);
215215
var errorMessage = constructFailedToSendRequestMessage("Azure OpenAI embeddings");

0 commit comments

Comments
 (0)