Skip to content

Commit 473dee6

Browse files
Refactored tests to reduce duplication
1 parent e170b96 commit 473dee6

File tree

1 file changed

+32
-40
lines changed

1 file changed

+32
-40
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2828
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
2929
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
30+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
3031
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
3132
import org.elasticsearch.xpack.inference.services.ServiceComponents;
3233
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
@@ -462,31 +463,13 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro
462463
""";
463464
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
464465

465-
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
466-
var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool));
467-
var action = actionCreator.create(model);
468-
469-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
470-
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
466+
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
471467

472468
var result = listener.actionGet(TIMEOUT);
473469

474470
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
475471

476-
assertThat(webServer.requests(), hasSize(1));
477-
assertNull(webServer.requests().get(0).getUri().getQuery());
478-
assertThat(
479-
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
480-
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
481-
);
482-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
483-
484-
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
485-
assertThat(requestMap.size(), is(4));
486-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
487-
assertThat(requestMap.get("model"), is("model"));
488-
assertThat(requestMap.get("n"), is(1));
489-
assertThat(requestMap.get("stream"), is(false));
472+
assertChatCompletionRequest();
490473
}
491474
}
492475

@@ -508,36 +491,45 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th
508491
""";
509492
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
510493

511-
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
512-
var actionCreator = new HuggingFaceActionCreator(
494+
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
513495
sender,
514496
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
515497
);
516-
var action = actionCreator.create(model);
517-
518-
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
519-
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
520498

521499
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
522500
assertThat(
523501
thrownException.getMessage(),
524502
is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]")
525503
);
526504

527-
assertThat(webServer.requests(), hasSize(1));
528-
assertNull(webServer.requests().get(0).getUri().getQuery());
529-
assertThat(
530-
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
531-
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
532-
);
533-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
534-
535-
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
536-
assertThat(requestMap.size(), is(4));
537-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
538-
assertThat(requestMap.get("model"), is("model"));
539-
assertThat(requestMap.get("n"), is(1));
540-
assertThat(requestMap.get("stream"), is(false));
505+
assertChatCompletionRequest();
541506
}
542507
}
508+
509+
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
510+
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
511+
var actionCreator = new HuggingFaceActionCreator(sender, threadPool);
512+
var action = actionCreator.create(model);
513+
514+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
515+
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
516+
return listener;
517+
}
518+
519+
private void assertChatCompletionRequest() throws IOException {
520+
assertThat(webServer.requests(), hasSize(1));
521+
assertNull(webServer.requests().get(0).getUri().getQuery());
522+
assertThat(
523+
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
524+
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
525+
);
526+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
527+
528+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
529+
assertThat(requestMap.size(), is(4));
530+
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
531+
assertThat(requestMap.get("model"), is("model"));
532+
assertThat(requestMap.get("n"), is(1));
533+
assertThat(requestMap.get("stream"), is(false));
534+
}
543535
}

0 commit comments

Comments
 (0)