Skip to content

Commit 6a135c5

Browse files
Add unit tests for LlamaEmbeddingsServiceSettings to validate configuration parsing and serialization
1 parent 39c5787 commit 6a135c5

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.embeddings;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.ValidationException;
12+
import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
13+
import org.elasticsearch.common.io.stream.BytesStreamOutput;
14+
import org.elasticsearch.common.io.stream.Writeable;
15+
import org.elasticsearch.common.xcontent.XContentHelper;
16+
import org.elasticsearch.inference.SimilarityMeasure;
17+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentFactory;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
22+
import org.elasticsearch.xpack.inference.services.ServiceFields;
23+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
24+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
25+
import org.hamcrest.CoreMatchers;
26+
27+
import java.io.IOException;
28+
import java.util.HashMap;
29+
import java.util.Map;
30+
31+
import static org.hamcrest.Matchers.containsString;
32+
import static org.hamcrest.Matchers.is;
33+
34+
public class LlamaEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<LlamaEmbeddingsServiceSettings> {
35+
private static final String MODEL_ID = "some model";
36+
private static final String CORRECT_URL = "https://www.elastic.co";
37+
private static final int DIMENSIONS = 384;
38+
private static final SimilarityMeasure SIMILARITY_MEASURE = SimilarityMeasure.DOT_PRODUCT;
39+
private static final int MAX_INPUT_TOKENS = 128;
40+
private static final int RATE_LIMIT = 2;
41+
42+
public void testFromMap_AllFields_Success() {
43+
var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap(
44+
new HashMap<>(
45+
Map.of(
46+
ServiceFields.MODEL_ID,
47+
MODEL_ID,
48+
ServiceFields.URL,
49+
CORRECT_URL,
50+
ServiceFields.SIMILARITY,
51+
SIMILARITY_MEASURE.toString(),
52+
ServiceFields.DIMENSIONS,
53+
DIMENSIONS,
54+
ServiceFields.MAX_INPUT_TOKENS,
55+
MAX_INPUT_TOKENS,
56+
RateLimitSettings.FIELD_NAME,
57+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
58+
)
59+
),
60+
ConfigurationParseContext.PERSISTENT
61+
);
62+
63+
assertThat(
64+
serviceSettings,
65+
is(
66+
new LlamaEmbeddingsServiceSettings(
67+
MODEL_ID,
68+
CORRECT_URL,
69+
DIMENSIONS,
70+
SIMILARITY_MEASURE,
71+
MAX_INPUT_TOKENS,
72+
new RateLimitSettings(RATE_LIMIT)
73+
)
74+
)
75+
);
76+
}
77+
78+
public void testFromMap_NoModelId_Failure() {
79+
var thrownException = expectThrows(
80+
ValidationException.class,
81+
() -> LlamaEmbeddingsServiceSettings.fromMap(
82+
new HashMap<>(
83+
Map.of(
84+
ServiceFields.URL,
85+
CORRECT_URL,
86+
ServiceFields.SIMILARITY,
87+
SIMILARITY_MEASURE.toString(),
88+
ServiceFields.DIMENSIONS,
89+
DIMENSIONS,
90+
ServiceFields.MAX_INPUT_TOKENS,
91+
MAX_INPUT_TOKENS,
92+
RateLimitSettings.FIELD_NAME,
93+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
94+
)
95+
),
96+
ConfigurationParseContext.PERSISTENT
97+
)
98+
);
99+
assertThat(
100+
thrownException.getMessage(),
101+
containsString(Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];", 2))
102+
);
103+
}
104+
105+
public void testToXContent_WritesAllValues() throws IOException {
106+
var entity = new LlamaEmbeddingsServiceSettings(
107+
MODEL_ID,
108+
CORRECT_URL,
109+
DIMENSIONS,
110+
SIMILARITY_MEASURE,
111+
MAX_INPUT_TOKENS,
112+
new RateLimitSettings(3)
113+
);
114+
115+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
116+
entity.toXContent(builder, null);
117+
String xContentResult = Strings.toString(builder);
118+
119+
assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace("""
120+
{
121+
"model_id": "some model",
122+
"url": "https://www.elastic.co",
123+
"dimensions": 384,
124+
"similarity": "dot_product",
125+
"max_input_tokens": 128,
126+
"rate_limit": {
127+
"requests_per_minute": 3
128+
}
129+
}
130+
""")));
131+
}
132+
133+
public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
134+
var outputBuffer = new BytesStreamOutput();
135+
var settings = new LlamaEmbeddingsServiceSettings(
136+
MODEL_ID,
137+
CORRECT_URL,
138+
DIMENSIONS,
139+
SIMILARITY_MEASURE,
140+
MAX_INPUT_TOKENS,
141+
new RateLimitSettings(3)
142+
);
143+
settings.writeTo(outputBuffer);
144+
145+
var outputBufferRef = outputBuffer.bytes();
146+
var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array());
147+
148+
var settingsFromBuffer = new LlamaEmbeddingsServiceSettings(inputBuffer);
149+
150+
assertEquals(settings, settingsFromBuffer);
151+
}
152+
153+
@Override
154+
protected Writeable.Reader<LlamaEmbeddingsServiceSettings> instanceReader() {
155+
return LlamaEmbeddingsServiceSettings::new;
156+
}
157+
158+
@Override
159+
protected LlamaEmbeddingsServiceSettings createTestInstance() {
160+
return createRandom();
161+
}
162+
163+
@Override
164+
protected LlamaEmbeddingsServiceSettings mutateInstance(LlamaEmbeddingsServiceSettings instance) throws IOException {
165+
return randomValueOtherThan(instance, LlamaEmbeddingsServiceSettingsTests::createRandom);
166+
}
167+
168+
private static LlamaEmbeddingsServiceSettings createRandom() {
169+
var modelId = randomAlphaOfLength(8);
170+
var url = randomAlphaOfLength(15);
171+
var similarityMeasure = randomFrom(SimilarityMeasure.values());
172+
var dimensions = randomIntBetween(32, 256);
173+
var maxInputTokens = randomIntBetween(128, 256);
174+
return new LlamaEmbeddingsServiceSettings(
175+
modelId,
176+
url,
177+
dimensions,
178+
similarityMeasure,
179+
maxInputTokens,
180+
RateLimitSettingsTests.createRandom()
181+
);
182+
}
183+
}

0 commit comments

Comments
 (0)