Skip to content

Commit f1bbcc6

Browse files
committed
external actions tests
1 parent 1946d49 commit f1bbcc6

20 files changed

+492
-172
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7171
}
7272

7373
// default for testing
74-
static String convertToString(InputType inputType) {
74+
public static String convertToString(InputType inputType) {
7575
return switch (inputType) {
7676
case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT;
7777
case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9090
}
9191

9292
// default for testing
93-
static String convertToString(InputType inputType) {
93+
public static String convertToString(InputType inputType) {
9494
return switch (inputType) {
9595
case INGEST, INTERNAL_INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
9696
case SEARCH, INTERNAL_SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7272
return builder;
7373
}
7474

75-
static String convertToString(InputType inputType) {
75+
public static String convertToString(InputType inputType) {
7676
return switch (inputType) {
7777
case null -> null;
7878
case INGEST, INTERNAL_INGEST -> DOCUMENT;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@ public static InputType randomWithoutUnspecified() {
2727
);
2828
}
2929

30+
public static InputType randomWithNull() {
31+
return randomBoolean()
32+
? null
33+
: randomFrom(
34+
InputType.UNSPECIFIED,
35+
InputType.INGEST,
36+
InputType.SEARCH,
37+
InputType.CLUSTERING,
38+
InputType.CLASSIFICATION,
39+
InputType.INTERNAL_SEARCH,
40+
InputType.INTERNAL_INGEST
41+
);
42+
}
43+
44+
public static InputType randomSearchAndIngestWithNull() {
45+
return randomBoolean()
46+
? null
47+
: randomFrom(InputType.UNSPECIFIED, InputType.INGEST, InputType.SEARCH, InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST);
48+
}
49+
3050
public static InputType randomWithIngestAndSearch() {
3151
return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST);
3252
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1717
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
1818
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
19+
import org.elasticsearch.xpack.inference.InputTypeTests;
1920
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender;
2021
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2122
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -50,7 +51,7 @@ public void shutdown() throws IOException {
5051
terminate(threadPool);
5152
}
5253

53-
public void testEmbeddingsRequestAction() throws IOException {
54+
public void testEmbeddingsRequestAction_Titan() throws IOException {
5455
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
5556
var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
5657
var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
@@ -72,7 +73,46 @@ public void testEmbeddingsRequestAction() throws IOException {
7273
);
7374
var action = creator.create(model, Map.of());
7475
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
75-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
76+
action.execute(
77+
new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
78+
InferenceAction.Request.DEFAULT_TIMEOUT,
79+
listener
80+
);
81+
var result = listener.actionGet(TIMEOUT);
82+
83+
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
84+
85+
assertThat(sender.sendCount(), is(1));
86+
var sentInputs = sender.getInputs();
87+
assertThat(sentInputs.size(), is(1));
88+
assertThat(sentInputs.get(0), is("abc"));
89+
}
90+
}
91+
92+
public void testEmbeddingsRequestAction_Cohere() throws IOException {
93+
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
94+
var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
95+
var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
96+
try (var sender = new AmazonBedrockMockRequestSender()) {
97+
sender.enqueue(mockedResult);
98+
var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT);
99+
var model = AmazonBedrockEmbeddingsModelTests.createModel(
100+
"test_id",
101+
"test_region",
102+
"test_model",
103+
AmazonBedrockProvider.COHERE,
104+
null,
105+
false,
106+
null,
107+
null,
108+
null,
109+
"accesskey",
110+
"secretkey"
111+
);
112+
var action = creator.create(model, Map.of());
113+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
114+
var inputType = InputTypeTests.randomWithNull();
115+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
76116
var result = listener.actionGet(TIMEOUT);
77117

78118
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
@@ -81,6 +121,11 @@ public void testEmbeddingsRequestAction() throws IOException {
81121
var sentInputs = sender.getInputs();
82122
assertThat(sentInputs.size(), is(1));
83123
assertThat(sentInputs.get(0), is("abc"));
124+
125+
if (inputType != null) {
126+
var sentInputType = sender.getInputType();
127+
assertThat(sentInputType, is(inputType));
128+
}
84129
}
85130
}
86131

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.threadpool.ThreadPool;
2020
import org.elasticsearch.xcontent.XContentType;
2121
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22+
import org.elasticsearch.xpack.inference.InputTypeTests;
2223
import org.elasticsearch.xpack.inference.common.TruncatorTests;
2324
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2425
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
@@ -111,7 +112,8 @@ public void testEmbeddingsRequestAction() throws IOException {
111112
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
112113
var action = creator.create(model, Map.of());
113114
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
114-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
115+
var inputType = InputTypeTests.randomWithNull();
116+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
115117

116118
var result = listener.actionGet(TIMEOUT);
117119

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
115115
model.setUri(new URI(getUrl(webServer)));
116116
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
117117
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
118+
var inputType = InputTypeTests.randomWithNull();
118119
var action = actionCreator.create(model, overriddenTaskSettings);
119120

120121
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
121-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
122+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
122123

123124
var result = listener.actionGet(TIMEOUT);
124125

@@ -127,7 +128,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
127128
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
128129

129130
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
130-
validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user", null);
131+
validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user", inputType);
131132
} catch (URISyntaxException e) {
132133
throw new RuntimeException(e);
133134
}
@@ -165,7 +166,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti
165166
model.setUri(new URI(getUrl(webServer)));
166167
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
167168
var overriddenTaskSettings = createRequestTaskSettingsMap(null);
168-
var inputType = InputTypeTests.randomWithoutUnspecified();
169+
var inputType = InputTypeTests.randomWithNull();
169170
var action = actionCreator.create(model, overriddenTaskSettings);
170171

171172
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -368,10 +369,11 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC
368369
model.setUri(new URI(getUrl(webServer)));
369370
var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
370371
var overriddenTaskSettings = createRequestTaskSettingsMap("overridden_user");
372+
var inputType = InputTypeTests.randomWithNull();
371373
var action = actionCreator.create(model, overriddenTaskSettings);
372374

373375
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
374-
action.execute(new EmbeddingsInput(List.of("abcd"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
376+
action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
375377

376378
var result = listener.actionGet(TIMEOUT);
377379

@@ -381,13 +383,13 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC
381383
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
382384

383385
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
384-
validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user", null);
386+
validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user", inputType);
385387
}
386388
{
387389
validateRequestWithApiKey(webServer.requests().get(1), "apikey");
388390

389391
var requestMap = entityAsMap(webServer.requests().get(1).getBody());
390-
validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user", null);
392+
validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user", inputType);
391393
}
392394
} catch (URISyntaxException e) {
393395
throw new RuntimeException(e);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.threadpool.ThreadPool;
2222
import org.elasticsearch.xcontent.XContentType;
2323
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
24+
import org.elasticsearch.xpack.inference.InputTypeTests;
2425
import org.elasticsearch.xpack.inference.common.TruncatorTests;
2526
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2627
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -113,7 +114,8 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
113114
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
114115

115116
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
116-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
117+
var inputType = InputTypeTests.randomWithNull();
118+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
117119

118120
var result = listener.actionGet(TIMEOUT);
119121

@@ -124,9 +126,12 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
124126
assertThat(webServer.requests().get(0).getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo("apikey"));
125127

126128
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
127-
assertThat(requestMap.size(), is(2));
129+
assertThat(requestMap.size(), is(inputType != null ? 3 : 2));
128130
assertThat(requestMap.get("input"), is(List.of("abc")));
129131
assertThat(requestMap.get("user"), is("user"));
132+
if (inputType != null) {
133+
assertThat(requestMap.get("input_type"), is(inputType.toString()));
134+
}
130135
}
131136
}
132137

@@ -137,7 +142,8 @@ public void testExecute_ThrowsElasticsearchException() {
137142
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
138143

139144
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
140-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
145+
var inputType = InputTypeTests.randomWithNull();
146+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
141147

142148
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
143149

@@ -157,7 +163,8 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
157163
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
158164

159165
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
160-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
166+
var inputType = InputTypeTests.randomWithNull();
167+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
161168

162169
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
163170

@@ -177,7 +184,8 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
177184
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
178185

179186
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
180-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
187+
var inputType = InputTypeTests.randomWithNull();
188+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
181189

182190
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
183191

@@ -191,7 +199,8 @@ public void testExecute_ThrowsException() {
191199
var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
192200

193201
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
194-
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
202+
var inputType = InputTypeTests.randomWithNull();
203+
action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
195204

196205
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
197206

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.threadpool.ThreadPool;
2020
import org.elasticsearch.xcontent.XContentType;
2121
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22-
import org.elasticsearch.xpack.inference.InputTypeTests;
2322
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2423
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2524
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -120,11 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException {
120119
var action = actionCreator.create(model, overriddenTaskSettings);
121120

122121
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
123-
action.execute(
124-
new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithoutUnspecified()),
125-
InferenceAction.Request.DEFAULT_TIMEOUT,
126-
listener
127-
);
122+
action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
128123

129124
var result = listener.actionGet(TIMEOUT);
130125

0 commit comments

Comments
 (0)