Commit c193ecf

Add tests for LlamaService request configuration validation and error handling
1 parent 36ff4cd commit c193ecf

2 files changed: +205 -1 lines changed
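
For orientation, the new parseRequestConfig tests all follow one shape: build a request config map with the test helpers, plant an unrecognized key in it, and assert that validation fails with an ElasticsearchStatusException naming the unknown setting. Below is a minimal sketch of that shape, condensed from the test code in the diff; the method name is hypothetical, while createService, getEmbeddingsServiceSettingsMap, getRequestConfigMap, and getSecretSettingsMap are the helpers used or added in this commit.

    // Condensed sketch (not part of the commit): the validation pattern the new tests exercise.
    public void testParseRequestConfig_RejectsUnknownSettings() throws IOException {
        try (var service = createService()) {
            // Build a valid embeddings config with the test helpers, then add an unknown key.
            var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), getSecretSettingsMap("secret"));
            config.put("extra_key", "value");

            ActionListener<Model> listener = ActionListener.wrap(
                model -> fail("Expected exception, but got model: " + model),
                e -> {
                    assertThat(e, instanceOf(ElasticsearchStatusException.class));
                    assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service"));
                }
            );

            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
        }
    }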

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java

Lines changed: 204 additions & 0 deletions
@@ -21,6 +21,7 @@
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.rest.RestStatus;
@@ -39,12 +40,14 @@
 import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
 import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests;
 import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
@@ -63,6 +66,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests.createChatCompletionModel;
 import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.equalTo;
@@ -496,13 +500,198 @@ public void testInfer_StreamRequest_ErrorResponse() {
             }]""", getUrl(webServer))));
     }
 
+    public void testInfer_StreamRequestRetry() throws Exception {
+        webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
+            {
+              "error": {
+                "message": "server busy"
+              }
+            }"""));
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+            data: {\
+                "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\
+                "choices": [{\
+                        "delta": {\
+                            "content": "Deep",\
+                            "function_call": null,\
+                            "refusal": null,\
+                            "role": "assistant",\
+                            "tool_calls": null\
+                        },\
+                        "finish_reason": null,\
+                        "index": 0,\
+                        "logprobs": null\
+                    }\
+                ],\
+                "created": 1750158492,\
+                "model": "llama3.2:3b",\
+                "object": "chat.completion.chunk",\
+                "service_tier": null,\
+                "system_fingerprint": "fp_ollama",\
+                "usage": null\
+            }
+
+            """));
+
+        streamCompletion().hasNoErrors().hasEvent("""
+            {"completion":[{"delta":"Deep"}]}""");
+    }
+
     public void testSupportsStreaming() throws IOException {
         try (var service = new LlamaService(mock(), createWithEmptySettings(mock()))) {
             assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
             assertFalse(service.canStream(TaskType.ANY));
         }
     }
 
+    public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(exception.getMessage(), is("The [llama] service does not support task type [sparse_embedding]"));
+                }
+            );
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.SPARSE_EMBEDDING,
+                getRequestConfigMap(getEmbeddingsServiceSettingsMap(), getSecretSettingsMap("secret")),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
+            getRequestConfigMap(getEmbeddingsServiceSettingsMap(), getSecretSettingsMap("secret")),
+            TaskType.TEXT_EMBEDDING
+        );
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_Completion() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
+            getRequestConfigMap(
+                LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap("llama-completion", "url"),
+                getSecretSettingsMap("secret")
+            ),
+            TaskType.COMPLETION
+        );
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_ChatCompletion() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
+            getRequestConfigMap(
+                LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap("llama-chat-completion", "url"),
+                getSecretSettingsMap("secret")
+            ),
+            TaskType.CHAT_COMPLETION
+        );
+    }
+
+    private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(Map<String, Object> secret, TaskType chatCompletion)
+        throws IOException {
+        try (var service = createService()) {
+            secret.put("extra_key", "value");
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", chatCompletion, secret, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(getEmbeddingsServiceSettingsMap(), TaskType.TEXT_EMBEDDING);
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
+        Map<String, Object> serviceSettingsMap,
+        TaskType chatCompletion
+    ) throws IOException {
+        try (var service = createService()) {
+            var taskSettings = new HashMap<String, Object>();
+            taskSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(serviceSettingsMap, Collections.emptyMap(), getSecretSettingsMap("secret"), taskSettings);
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings);
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("llama-completion", TaskType.COMPLETION);
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
+        testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("llama-chat-completion", TaskType.CHAT_COMPLETION);
+    }
+
+    private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap(String modelId, TaskType chatCompletion)
+        throws IOException {
+        try (var service = createService()) {
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap(modelId, "url"), secretSettings);
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
+        }
+    }
+
     private InferenceEventsAssertion streamCompletion() throws Exception {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool))) {
@@ -529,6 +718,17 @@ private LlamaService createService() {
         return new LlamaService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
     }
 
+    private Map<String, Object> getRequestConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> chunkingSettings,
+        Map<String, Object> secretSettings,
+        Map<String, Object> taskSettings
+    ) {
+        var requestConfigMap = getRequestConfigMap(serviceSettings, chunkingSettings, secretSettings);
+        requestConfigMap.put(ModelConfigurations.TASK_SETTINGS, taskSettings);
+        return requestConfigMap;
+    }
+
     private Map<String, Object> getRequestConfigMap(
         Map<String, Object> serviceSettings,
         Map<String, Object> chunkingSettings,
@@ -547,4 +747,8 @@ private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSetti
 
         return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
     }
+
+    private static Map<String, Object> getEmbeddingsServiceSettingsMap() {
+        return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null);
+    }
 }

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java

Lines changed: 1 addition & 1 deletion
@@ -447,7 +447,7 @@ private static LlamaEmbeddingsServiceSettings createRandom() {
         );
     }
 
-    private static HashMap<String, Object> buildServiceSettingsMap(
+    public static HashMap<String, Object> buildServiceSettingsMap(
         @Nullable String modelId,
         @Nullable String url,
         @Nullable String similarity,
