Skip to content

Commit 39c5787

Browse files
Add unit tests for LlamaChatCompletionServiceSettings to validate configuration parsing and serialization
1 parent 4eade05 commit 39c5787

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.completion;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.common.xcontent.XContentHelper;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.XContentFactory;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
19+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20+
import org.elasticsearch.xpack.inference.services.ServiceFields;
21+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
22+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
23+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
24+
25+
import java.io.IOException;
26+
import java.util.HashMap;
27+
import java.util.Map;
28+
29+
import static org.hamcrest.Matchers.containsString;
30+
import static org.hamcrest.Matchers.is;
31+
32+
public class LlamaChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<LlamaChatCompletionServiceSettings> {
33+
34+
public static final String MODEL_ID = "some model";
35+
public static final String CORRECT_URL = "https://www.elastic.co";
36+
public static final int RATE_LIMIT = 2;
37+
38+
public void testFromMap_AllFields_Success() {
39+
var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
40+
new HashMap<>(
41+
Map.of(
42+
ServiceFields.MODEL_ID,
43+
MODEL_ID,
44+
ServiceFields.URL,
45+
CORRECT_URL,
46+
RateLimitSettings.FIELD_NAME,
47+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
48+
)
49+
),
50+
ConfigurationParseContext.PERSISTENT
51+
);
52+
53+
assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT))));
54+
}
55+
56+
public void testFromMap_MissingModelId_ThrowsException() {
57+
var thrownException = expectThrows(
58+
ValidationException.class,
59+
() -> LlamaChatCompletionServiceSettings.fromMap(
60+
new HashMap<>(
61+
Map.of(
62+
ServiceFields.URL,
63+
CORRECT_URL,
64+
RateLimitSettings.FIELD_NAME,
65+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
66+
)
67+
),
68+
ConfigurationParseContext.PERSISTENT
69+
)
70+
);
71+
72+
assertThat(
73+
thrownException.getMessage(),
74+
containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
75+
);
76+
}
77+
78+
public void testFromMap_MissingUrl_ThrowsException() {
79+
var thrownException = expectThrows(
80+
ValidationException.class,
81+
() -> LlamaChatCompletionServiceSettings.fromMap(
82+
new HashMap<>(
83+
Map.of(
84+
ServiceFields.MODEL_ID,
85+
MODEL_ID,
86+
RateLimitSettings.FIELD_NAME,
87+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
88+
)
89+
),
90+
ConfigurationParseContext.PERSISTENT
91+
)
92+
);
93+
94+
assertThat(
95+
thrownException.getMessage(),
96+
containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];")
97+
);
98+
}
99+
100+
public void testFromMap_MissingRateLimit_Success() {
101+
var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
102+
new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
103+
ConfigurationParseContext.PERSISTENT
104+
);
105+
106+
assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null)));
107+
}
108+
109+
public void testToXContent_WritesAllValues() throws IOException {
110+
var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
111+
new HashMap<>(
112+
Map.of(
113+
ServiceFields.MODEL_ID,
114+
MODEL_ID,
115+
ServiceFields.URL,
116+
CORRECT_URL,
117+
RateLimitSettings.FIELD_NAME,
118+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
119+
)
120+
),
121+
ConfigurationParseContext.PERSISTENT
122+
);
123+
124+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
125+
serviceSettings.toXContent(builder, null);
126+
String xContentResult = Strings.toString(builder);
127+
var expected = XContentHelper.stripWhitespace("""
128+
{
129+
"model_id": "some model",
130+
"url": "https://www.elastic.co",
131+
"rate_limit": {
132+
"requests_per_minute": 2
133+
}
134+
}
135+
""");
136+
137+
assertThat(xContentResult, is(expected));
138+
}
139+
140+
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
141+
var serviceSettings = LlamaChatCompletionServiceSettings.fromMap(
142+
new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
143+
ConfigurationParseContext.PERSISTENT
144+
);
145+
146+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
147+
serviceSettings.toXContent(builder, null);
148+
String xContentResult = Strings.toString(builder);
149+
var expected = XContentHelper.stripWhitespace("""
150+
{
151+
"model_id": "some model",
152+
"url": "https://www.elastic.co",
153+
"rate_limit": {
154+
"requests_per_minute": 3000
155+
}
156+
}
157+
""");
158+
assertThat(xContentResult, is(expected));
159+
}
160+
161+
@Override
162+
protected Writeable.Reader<LlamaChatCompletionServiceSettings> instanceReader() {
163+
return LlamaChatCompletionServiceSettings::new;
164+
}
165+
166+
@Override
167+
protected LlamaChatCompletionServiceSettings createTestInstance() {
168+
return createRandom();
169+
}
170+
171+
@Override
172+
protected LlamaChatCompletionServiceSettings mutateInstance(LlamaChatCompletionServiceSettings instance) throws IOException {
173+
return randomValueOtherThan(instance, LlamaChatCompletionServiceSettingsTests::createRandom);
174+
}
175+
176+
@Override
177+
protected LlamaChatCompletionServiceSettings mutateInstanceForVersion(
178+
LlamaChatCompletionServiceSettings instance,
179+
TransportVersion version
180+
) {
181+
return instance;
182+
}
183+
184+
private static LlamaChatCompletionServiceSettings createRandom() {
185+
var modelId = randomAlphaOfLength(8);
186+
var url = randomAlphaOfLength(15);
187+
return new LlamaChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom());
188+
}
189+
190+
public static Map<String, Object> getServiceSettingsMap(String model, String url) {
191+
var map = new HashMap<String, Object>();
192+
193+
map.put(ServiceFields.MODEL_ID, model);
194+
map.put(ServiceFields.URL, url);
195+
196+
return map;
197+
}
198+
}

0 commit comments

Comments
 (0)