Skip to content

Commit fddfd9d

Browse files
committed
Add ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java
1 parent 5af7516 commit fddfd9d

File tree

2 files changed

+233
-5
lines changed

2 files changed

+233
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
146146
this.dimensions = dimensions;
147147
this.maxInputTokens = maxInputTokens;
148148
this.dimensionsSetByUser = dimensionsSetByUser;
149-
this.rateLimitSettings = rateLimitSettings;
149+
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
150150
}
151151

152152
public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException {
@@ -239,25 +239,29 @@ public TransportVersion getMinimalSupportedVersion() {
239239

240240
/**
 * Serializes these settings to {@code out} in a fixed order: the model id, the similarity
 * (passed through {@code SimilarityMeasure.translateSimilarity} together with the stream's
 * transport version), the optional dimensions, the optional max input tokens, the
 * {@code dimensionsSetByUser} flag and finally the rate limit settings.
 *
 * NOTE(review): this hunk adds {@code modelId} and {@code rateLimitSettings} to the written
 * fields. The matching {@code StreamInput} constructor is not visible here; confirm it reads
 * the same fields in the same order, and whether a transport-version guard is required.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
241241
public void writeTo(StreamOutput out) throws IOException {
242+
out.writeString(modelId);
242243
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
243244
out.writeOptionalVInt(dimensions);
244245
out.writeOptionalVInt(maxInputTokens);
245246
out.writeBoolean(dimensionsSetByUser);
247+
rateLimitSettings.writeTo(out);
246248
}
247249

248250
/**
 * Compares this instance with {@code o}: identical references are equal, and {@code null} or an
 * object of a different class is never equal.
 *
 * As shown by the added ('+') lines of this hunk, equality now covers {@code dimensionsSetByUser},
 * {@code modelId}, {@code similarity}, {@code dimensions}, {@code maxInputTokens} and
 * {@code rateLimitSettings}; the removed ('-') lines did not compare {@code modelId} or
 * {@code rateLimitSettings}.
 *
 * @param o the object to compare with
 * @return {@code true} when all compared fields match
 */
@Override
249251
public boolean equals(Object o) {
250252
if (this == o) return true;
251253
if (o == null || getClass() != o.getClass()) return false;
252254
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o;
253-
return Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
254-
&& Objects.equals(similarity, that.similarity)
255+
return dimensionsSetByUser == that.dimensionsSetByUser
256+
&& Objects.equals(modelId, that.modelId)
257+
&& similarity == that.similarity
255258
&& Objects.equals(dimensions, that.dimensions)
256-
&& Objects.equals(maxInputTokens, that.maxInputTokens);
259+
&& Objects.equals(maxInputTokens, that.maxInputTokens)
260+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
257261
}
258262

259263
/**
 * Computes the hash code with {@code Objects.hash}.
 *
 * The added ('+') line hashes {@code modelId}, {@code similarity}, {@code dimensions},
 * {@code maxInputTokens}, {@code dimensionsSetByUser} and {@code rateLimitSettings}, which is the
 * same field set compared by the added lines of {@code equals}; the removed ('-') line omitted
 * {@code modelId} and {@code rateLimitSettings}.
 *
 * @return the combined hash of the listed fields
 */
@Override
260264
public int hashCode() {
261-
return Objects.hash(similarity, dimensions, maxInputTokens, dimensionsSetByUser);
265+
return Objects.hash(modelId, similarity, dimensions, maxInputTokens, dimensionsSetByUser, rateLimitSettings);
262266
}
263267
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.io.stream.Writeable;
12+
import org.elasticsearch.inference.SimilarityMeasure;
13+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
18+
import org.elasticsearch.xpack.inference.services.ServiceFields;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
20+
21+
import java.io.IOException;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
import static org.hamcrest.Matchers.is;
26+
27+
public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<
28+
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> {
29+
30+
/**
 * Supplies the reader used to deserialize instances in this test: a constructor reference to
 * {@code ElasticInferenceServiceDenseTextEmbeddingsServiceSettings}.
 *
 * @return the constructor reference typed as a {@code Writeable.Reader}
 */
@Override
31+
protected Writeable.Reader<ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> instanceReader() {
32+
return ElasticInferenceServiceDenseTextEmbeddingsServiceSettings::new;
33+
}
34+
35+
/**
 * Creates the instance under test by delegating to {@code createRandom()}.
 *
 * @return a randomly populated settings instance
 */
@Override
36+
protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createTestInstance() {
37+
return createRandom();
38+
}
39+
40+
/**
 * Produces a mutated instance by passing {@code instance} and {@code createRandom} as the
 * supplier to {@code randomValueOtherThan}.
 *
 * NOTE(review): presumably the helper keeps drawing from the supplier until the result is not
 * equal to {@code instance}; confirm against the test framework.
 *
 * @param instance the instance the result is expected to differ from
 * @return the value returned by {@code randomValueOtherThan}
 * @throws IOException declared by the overridden signature
 */
@Override
41+
protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings mutateInstance(
42+
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings instance
43+
) throws IOException {
44+
return randomValueOtherThan(instance, ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests::createRandom);
45+
}
46+
47+
/**
 * Parses a map containing model id, similarity, dimensions and max input tokens with
 * {@code ConfigurationParseContext.REQUEST} and asserts that each value is returned unchanged by
 * the corresponding accessor, and that {@code dimensionsSetByUser()} is {@code true}.
 */
public void testFromMap_Request_WithAllSettings() {
48+
var modelId = "my-dense-model-id";
49+
var similarity = SimilarityMeasure.COSINE;
50+
var dimensions = 384;
51+
var maxInputTokens = 512;
52+
53+
var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(
54+
new HashMap<>(
55+
Map.of(
56+
ServiceFields.MODEL_ID,
57+
modelId,
58+
ServiceFields.SIMILARITY,
59+
similarity.toString(),
60+
ServiceFields.DIMENSIONS,
61+
dimensions,
62+
ServiceFields.MAX_INPUT_TOKENS,
63+
maxInputTokens
64+
)
65+
),
66+
ConfigurationParseContext.REQUEST
67+
);
68+
69+
assertThat(serviceSettings.modelId(), is(modelId));
70+
assertThat(serviceSettings.similarity(), is(similarity));
71+
assertThat(serviceSettings.dimensions(), is(dimensions));
72+
assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens));
73+
assertThat(serviceSettings.dimensionsSetByUser(), is(true)); // dimensions were provided
74+
}
75+
76+
/**
 * Parses a map containing model id, similarity, dimensions and an explicit
 * {@code DIMENSIONS_SET_BY_USER} entry with {@code ConfigurationParseContext.PERSISTENT} and
 * asserts that all four values, including the stored flag, are returned by the accessors.
 */
public void testFromMap_Persistent_WithDimensionsSetByUser() {
77+
var modelId = "my-dense-model-id";
78+
var similarity = SimilarityMeasure.DOT_PRODUCT;
79+
var dimensions = 768;
80+
var dimensionsSetByUser = true;
81+
82+
var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(
83+
new HashMap<>(
84+
Map.of(
85+
ServiceFields.MODEL_ID,
86+
modelId,
87+
ServiceFields.SIMILARITY,
88+
similarity.toString(),
89+
ServiceFields.DIMENSIONS,
90+
dimensions,
91+
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER,
92+
dimensionsSetByUser
93+
)
94+
),
95+
ConfigurationParseContext.PERSISTENT
96+
);
97+
98+
assertThat(serviceSettings.modelId(), is(modelId));
99+
assertThat(serviceSettings.similarity(), is(similarity));
100+
assertThat(serviceSettings.dimensions(), is(dimensions));
101+
assertThat(serviceSettings.dimensionsSetByUser(), is(dimensionsSetByUser));
102+
}
103+
104+
/**
 * Parses a map that contains only the model id with {@code ConfigurationParseContext.PERSISTENT}
 * and asserts that {@code dimensionsSetByUser()} is {@code false} when the flag is absent.
 * Only the flag is asserted; the parsed model id is not checked here.
 */
public void testFromMap_Persistent_WithoutDimensionsSetByUser_DefaultsToFalse() {
105+
var modelId = "my-dense-model-id";
106+
107+
var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(
108+
new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)),
109+
ConfigurationParseContext.PERSISTENT
110+
);
111+
112+
assertThat(serviceSettings.dimensionsSetByUser(), is(false));
113+
}
114+
115+
/**
 * Builds settings with every constructor argument populated, renders them with
 * {@code toXContent} into a JSON builder and asserts the exact resulting string, which contains
 * similarity, dimensions, max_input_tokens, model_id, rate_limit.requests_per_minute and
 * dimensions_set_by_user in that order.
 *
 * @throws IOException if building the XContent fails
 */
public void testToXContent_WritesAllFields() throws IOException {
116+
var modelId = "my-dense-model";
117+
var similarity = SimilarityMeasure.DOT_PRODUCT;
118+
var dimensions = 1024;
119+
var maxInputTokens = 256;
120+
var dimensionsSetByUser = true;
121+
var rateLimitSettings = new RateLimitSettings(5000);
122+
123+
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
124+
modelId,
125+
similarity,
126+
dimensions,
127+
maxInputTokens,
128+
dimensionsSetByUser,
129+
rateLimitSettings
130+
);
131+
132+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
133+
serviceSettings.toXContent(builder, null);
134+
String xContentResult = Strings.toString(builder);
135+
136+
assertThat(
137+
xContentResult,
138+
is(
139+
Strings.format(
140+
"""
141+
{"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":%s}""",
142+
similarity,
143+
dimensions,
144+
maxInputTokens,
145+
modelId,
146+
rateLimitSettings.requestsPerTimeUnit(),
147+
dimensionsSetByUser
148+
)
149+
)
150+
);
151+
}
152+
153+
/**
 * Builds settings with {@code null} similarity, dimensions and max input tokens, renders them
 * with {@code toXContent} and asserts the exact resulting JSON, which contains only model_id,
 * rate_limit.requests_per_minute and dimensions_set_by_user (false).
 *
 * @throws IOException if building the XContent fails
 */
public void testToXContent_WritesOnlyNonNullFields() throws IOException {
154+
var modelId = "my-dense-model";
155+
var rateLimitSettings = new RateLimitSettings(2000);
156+
157+
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
158+
modelId,
159+
null, // similarity
160+
null, // dimensions
161+
null, // maxInputTokens
162+
false, // dimensionsSetByUser
163+
rateLimitSettings
164+
);
165+
166+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
167+
serviceSettings.toXContent(builder, null);
168+
String xContentResult = Strings.toString(builder);
169+
170+
assertThat(
171+
xContentResult,
172+
is(
173+
Strings.format(
174+
"""
175+
{"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":false}""",
176+
modelId,
177+
rateLimitSettings.requestsPerTimeUnit()
178+
)
179+
)
180+
);
181+
}
182+
183+
/**
 * Builds fully populated settings, wraps a call to {@code toXContentFragmentOfExposedFields} in
 * an explicit start/end object on the builder, and asserts that the resulting JSON contains only
 * model_id and rate_limit.requests_per_minute, even though similarity, dimensions, max input
 * tokens and the dimensionsSetByUser flag were all set.
 *
 * @throws IOException if building the XContent fails
 */
public void testToXContentFragmentOfExposedFields() throws IOException {
184+
var modelId = "my-dense-model";
185+
var rateLimitSettings = new RateLimitSettings(1500);
186+
187+
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
188+
modelId,
189+
SimilarityMeasure.COSINE,
190+
512,
191+
128,
192+
true,
193+
rateLimitSettings
194+
);
195+
196+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
197+
builder.startObject();
198+
serviceSettings.toXContentFragmentOfExposedFields(builder, null);
199+
builder.endObject();
200+
String xContentResult = Strings.toString(builder);
201+
202+
// Only model_id and rate_limit should be in exposed fields
203+
assertThat(xContentResult, is(Strings.format("""
204+
{"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit())));
205+
}
206+
207+
/**
 * Creates a settings instance with randomized fields: a 10-character alphabetic model id,
 * dimensions that are either {@code null} or in [1, 1024], max input tokens that are either
 * {@code null} or in [128, 256], a random {@code dimensionsSetByUser} flag, and rate limit
 * settings that are either {@code null} or built from a value in [1, 10000].
 *
 * NOTE(review): similarity is always {@code SimilarityMeasure.COSINE}, so the wire round-trip
 * and mutation checks never exercise other similarity values or {@code null}; confirm this is
 * intentional.
 *
 * @return a randomly populated settings instance
 */
public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRandom() {
208+
var modelId = randomAlphaOfLength(10);
209+
var similarity = SimilarityMeasure.COSINE;
210+
var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null;
211+
var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null;
212+
var dimensionsSetByUser = randomBoolean();
213+
// A null rate limit is passed straight to the constructor in half of the cases.
var rateLimitSettings = randomBoolean() ? new RateLimitSettings(randomIntBetween(1, 10000)) : null;
214+
215+
return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
216+
modelId,
217+
similarity,
218+
dimensions,
219+
maxInputTokens,
220+
dimensionsSetByUser,
221+
rateLimitSettings
222+
);
223+
}
224+
}

0 commit comments

Comments
 (0)