Skip to content

Commit c8775f7

Browse files
Update tests
1 parent 8bc7ef9 commit c8775f7

File tree

4 files changed

+127
-26
lines changed

4 files changed

+127
-26
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe
650650
}
651651

652652
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
653-
var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key", "user");
653+
var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key", "user", null, false);
654654
model.setURI(getUrl(webServer));
655655

656656
testChunkedInfer(model);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,51 @@ public void shutdown() throws IOException {
7777
webServer.close();
7878
}
7979

80-
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
81-
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction("overridden_user");
80+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserFalse()
81+
throws IOException {
82+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction("overridden_user", 384, false, null);
8283
}
8384

84-
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithoutUser() throws IOException {
85-
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(null);
85+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserFalse()
86+
throws IOException {
87+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(null, 384, false, null);
8688
}
8789

88-
private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(String user) throws IOException {
90+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserFalse()
91+
throws IOException {
92+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction("overridden_user", null, false, null);
93+
}
94+
95+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserFalse()
96+
throws IOException {
97+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(null, null, false, null);
98+
}
99+
100+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserTrue()
101+
throws IOException {
102+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction("overridden_user", 384, true, 384);
103+
}
104+
105+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserTrue()
106+
throws IOException {
107+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(null, 384, true, 384);
108+
}
109+
110+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserTrue()
111+
throws IOException {
112+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction("overridden_user", null, true, null);
113+
}
114+
115+
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserTrue() throws IOException {
116+
testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(null, null, true, null);
117+
}
118+
119+
private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(
120+
String user,
121+
Integer dimensions,
122+
boolean dimensionsSetByUser,
123+
Integer expectedDimensions
124+
) throws IOException {
89125
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
90126

91127
try (var sender = createSender(senderFactory)) {
@@ -113,25 +149,68 @@ private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(String us
113149
""";
114150
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
115151

116-
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool), user);
152+
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(
153+
sender,
154+
createWithEmptySettings(threadPool),
155+
user,
156+
dimensions,
157+
dimensionsSetByUser
158+
);
117159

118160
var result = listener.actionGet(TIMEOUT);
119161

120162
assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.123F, 0.123F }))));
121163

122-
assertEmbeddingsRequest(user);
164+
assertEmbeddingsRequest(user, expectedDimensions);
123165
}
124166
}
125167

126-
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException {
127-
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction("overridden_user");
168+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserFalse()
169+
throws IOException {
170+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction("overridden_user", 384, false, null);
171+
}
172+
173+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserFalse()
174+
throws IOException {
175+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(null, 384, false, null);
128176
}
129177

130-
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithoutUser() throws IOException {
131-
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(null);
178+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserFalse()
179+
throws IOException {
180+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction("overridden_user", null, false, null);
132181
}
133182

134-
private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(String user) throws IOException {
183+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserFalse()
184+
throws IOException {
185+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(null, null, false, null);
186+
}
187+
188+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserTrue()
189+
throws IOException {
190+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction("overridden_user", 384, true, 384);
191+
}
192+
193+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserTrue()
194+
throws IOException {
195+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(null, 384, true, 384);
196+
}
197+
198+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserTrue()
199+
throws IOException {
200+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction("overridden_user", null, true, null);
201+
}
202+
203+
public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserTrue()
204+
throws IOException {
205+
testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(null, null, true, null);
206+
}
207+
208+
private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(
209+
String user,
210+
Integer dimensions,
211+
boolean dimensionsSetByUser,
212+
Integer expectedDimensions
213+
) throws IOException {
135214
var settings = buildSettingsWithRetryFields(
136215
TimeValue.timeValueMillis(1),
137216
TimeValue.timeValueMinutes(1),
@@ -149,15 +228,21 @@ private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(Stri
149228
""";
150229
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
151230

152-
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool), user);
231+
PlainActionFuture<InferenceServiceResults> listener = createEmbeddingsFuture(
232+
sender,
233+
createWithEmptySettings(threadPool),
234+
user,
235+
dimensions,
236+
dimensionsSetByUser
237+
);
153238

154239
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
155240
assertThat(
156241
thrownException.getMessage(),
157242
is("Failed to send Llama text_embedding request from inference entity id [id]. Cause: Required [data]")
158243
);
159244

160-
assertEmbeddingsRequest(user);
245+
assertEmbeddingsRequest(user, expectedDimensions);
161246
}
162247
}
163248

@@ -262,8 +347,21 @@ private void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction(Stri
262347
}
263348
}
264349

265-
private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(Sender sender, ServiceComponents threadPool, String user) {
266-
var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret", user);
350+
private PlainActionFuture<InferenceServiceResults> createEmbeddingsFuture(
351+
Sender sender,
352+
ServiceComponents threadPool,
353+
String user,
354+
Integer dimensions,
355+
boolean dimensionsSetByUser
356+
) {
357+
var model = LlamaEmbeddingsModelTests.createEmbeddingsModel(
358+
"model",
359+
getUrl(webServer),
360+
"secret",
361+
user,
362+
dimensions,
363+
dimensionsSetByUser
364+
);
267365
var actionCreator = new LlamaActionCreator(sender, threadPool);
268366
var overriddenTaskSettings = createRequestTaskSettingsMap(user);
269367
var action = actionCreator.create(model, overriddenTaskSettings);
@@ -305,19 +403,15 @@ private void assertCompletionRequest(String user) throws IOException {
305403
}
306404

307405
@SuppressWarnings("unchecked")
308-
private void assertEmbeddingsRequest(String user) throws IOException {
406+
private void assertEmbeddingsRequest(String user, Integer dimensions) throws IOException {
309407
assertCommonRequestProperties();
310408

311409
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
312-
if (user == null) {
313-
assertThat(requestMap.size(), is(2));
314-
} else {
315-
assertThat(requestMap.size(), is(3));
316-
}
317410
assertThat(requestMap.get("input"), instanceOf(List.class));
318411
var inputList = (List<String>) requestMap.get("input");
319412
assertThat(inputList, contains("abc"));
320413
assertThat(requestMap.get("user"), is(user));
414+
assertThat(requestMap.get("dimensions"), is(dimensions));
321415
}
322416

323417
private void assertCommonRequestProperties() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,19 @@
1717
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
1818

1919
public class LlamaEmbeddingsModelTests extends ESTestCase {
20-
public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey, String user) {
20+
public static LlamaEmbeddingsModel createEmbeddingsModel(
21+
String modelId,
22+
String url,
23+
String apiKey,
24+
String user,
25+
Integer dimensions,
26+
boolean dimensionsSetByUser
27+
) {
2128
return new LlamaEmbeddingsModel(
2229
"id",
2330
TaskType.TEXT_EMBEDDING,
2431
"llama",
25-
new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, false, null),
32+
new LlamaEmbeddingsServiceSettings(modelId, url, dimensions, null, null, dimensionsSetByUser, null),
2633
new OpenAiEmbeddingsTaskSettings(user),
2734
null,
2835
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private HttpPost validateRequestUrlAndContentType(HttpRequest request) {
120120
}
121121

122122
private static LlamaEmbeddingsRequest createRequest(String user) {
123-
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey", user);
123+
var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey", user, null, false);
124124
return new LlamaEmbeddingsRequest(
125125
TruncatorTests.createTruncator(),
126126
new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }),

0 commit comments

Comments
 (0)