Skip to content

Commit 38a58f9

Browse files
Adding some initial tests
1 parent be588f4 commit 38a58f9

File tree

4 files changed

+142
-4
lines changed

4 files changed

+142
-4
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException {
111111
out.writeOptionalFloat(topP);
112112
}
113113

114-
public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
115-
implements
116-
Writeable {
114+
public record Message(
115+
Content content,
116+
String role,
117+
@Nullable String name,
118+
@Nullable String toolCallId,
119+
@Nullable List<ToolCall> toolCalls
120+
) implements Writeable {
117121

118122
@SuppressWarnings("unchecked")
119123
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public void testFromMap_InvalidElserModelId() {
7171
assertThat(validationException.getMessage(), containsString(Strings.format("unknown ELSER model id [%s]", invalidModelId)));
7272
}
7373

74-
public void testToXContent_WritesAlLFields() throws IOException {
74+
public void testToXContent_WritesAllFields() throws IOException {
7575
var modelId = ElserModels.ELSER_V1_MODEL;
7676
var maxInputTokens = 10;
7777
var serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.completion;
9+
10+
import org.elasticsearch.inference.EmptySecretSettings;
11+
import org.elasticsearch.inference.EmptyTaskSettings;
12+
import org.elasticsearch.inference.TaskType;
13+
import org.elasticsearch.inference.UnifiedCompletionRequest;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
16+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
17+
18+
import java.util.List;
19+
20+
import static org.hamcrest.Matchers.is;
21+
22+
// TODO determine if we need the model id
23+
public class EISCompletionModelTests extends ESTestCase {
24+
25+
public void testOverridingModelId() {
26+
var originalModel = new ElasticInferenceServiceCompletionModel(
27+
"id",
28+
TaskType.COMPLETION,
29+
"elastic",
30+
new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
31+
EmptyTaskSettings.INSTANCE,
32+
EmptySecretSettings.INSTANCE,
33+
new ElasticInferenceServiceComponents("url")
34+
);
35+
36+
var request = new UnifiedCompletionRequest(
37+
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
38+
"new_model_id",
39+
null,
40+
null,
41+
null,
42+
null,
43+
null,
44+
null
45+
);
46+
47+
var overriddenModel = ElasticInferenceServiceCompletionModel.of(originalModel, request);
48+
49+
assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id"));
50+
assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
51+
}
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.completion;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.ValidationException;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
18+
import org.elasticsearch.xpack.inference.services.ServiceFields;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
20+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
21+
22+
import java.io.IOException;
23+
import java.util.HashMap;
24+
import java.util.Map;
25+
26+
import static org.hamcrest.Matchers.containsString;
27+
import static org.hamcrest.Matchers.is;
28+
29+
public class EISCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<ElasticInferenceServiceCompletionServiceSettings> {
30+
31+
@Override
32+
protected Writeable.Reader<ElasticInferenceServiceCompletionServiceSettings> instanceReader() {
33+
return ElasticInferenceServiceCompletionServiceSettings::new;
34+
}
35+
36+
@Override
37+
protected ElasticInferenceServiceCompletionServiceSettings createTestInstance() {
38+
return createRandom();
39+
}
40+
41+
@Override
42+
protected ElasticInferenceServiceCompletionServiceSettings mutateInstance(ElasticInferenceServiceCompletionServiceSettings instance)
43+
throws IOException {
44+
return randomValueOtherThan(instance, EISCompletionServiceSettingsTests::createRandom);
45+
}
46+
47+
public void testFromMap() {
48+
var modelId = "model_id";
49+
50+
var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap(
51+
new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)),
52+
ConfigurationParseContext.REQUEST
53+
);
54+
55+
assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000))));
56+
}
57+
58+
public void testFromMap_MissingModelId_ThrowsException() {
59+
ValidationException validationException = expectThrows(
60+
ValidationException.class,
61+
() -> ElasticInferenceServiceCompletionServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST)
62+
);
63+
64+
assertThat(validationException.getMessage(), containsString("does not contain the required setting [model_id]"));
65+
}
66+
67+
public void testToXContent_WritesAllFields() throws IOException {
68+
var modelId = "model_id";
69+
var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000));
70+
71+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
72+
serviceSettings.toXContent(builder, null);
73+
String xContentResult = Strings.toString(builder);
74+
75+
assertThat(xContentResult, is(Strings.format("""
76+
{"model_id":"%s","rate_limit":{"requests_per_minute":1000}}""", modelId)));
77+
}
78+
79+
public static ElasticInferenceServiceCompletionServiceSettings createRandom() {
80+
return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4), RateLimitSettingsTests.createRandom());
81+
}
82+
}

0 commit comments

Comments
 (0)