Skip to content

Commit b7cb871

Browse files
committed
Further tests
1 parent a89fdf1 commit b7cb871

File tree

8 files changed

+269
-20
lines changed

8 files changed

+269
-20
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject im
4242
null, null, null, null, null
4343
);
4444

45-
static final String EMBEDDING_TYPE = "embedding_type";
45+
public static final String EMBEDDING_TYPE = "embedding_type";
4646

4747
public static VoyageAIEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
4848
ValidationException validationException = new ValidationException();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntityTests.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,128 @@
99

1010
import org.elasticsearch.common.Strings;
1111
import org.elasticsearch.inference.InputType;
12+
import org.elasticsearch.inference.SimilarityMeasure;
1213
import org.elasticsearch.test.ESTestCase;
1314
import org.elasticsearch.xcontent.XContentBuilder;
1415
import org.elasticsearch.xcontent.XContentFactory;
1516
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
18+
import org.elasticsearch.xpack.inference.services.ServiceFields;
19+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
1620
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
1721
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
1822
import org.hamcrest.MatcherAssert;
1923

2024
import java.io.IOException;
25+
import java.util.HashMap;
2126
import java.util.List;
27+
import java.util.Map;
2228

2329
import static org.hamcrest.CoreMatchers.is;
2430

2531
public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase {
32+
public void testXContent_WritesAllFields_ServiceSettingsDefined() throws IOException {
33+
var entity = new VoyageAIEmbeddingsRequestEntity(
34+
List.of("abc"),
35+
VoyageAIEmbeddingsServiceSettings.fromMap(
36+
new HashMap<>(
37+
Map.of(
38+
ServiceFields.URL,
39+
"https://www.abc.com",
40+
ServiceFields.SIMILARITY,
41+
SimilarityMeasure.DOT_PRODUCT.toString(),
42+
ServiceFields.DIMENSIONS,
43+
2048,
44+
ServiceFields.MAX_INPUT_TOKENS,
45+
512,
46+
VoyageAIServiceSettings.MODEL_ID,
47+
"model",
48+
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
49+
"float"
50+
)
51+
),
52+
ConfigurationParseContext.PERSISTENT
53+
),
54+
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
55+
"model"
56+
);
57+
58+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
59+
entity.toXContent(builder, null);
60+
String xContentResult = Strings.toString(builder);
61+
62+
MatcherAssert.assertThat(xContentResult, is("""
63+
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"float"}"""));
64+
}
65+
66+
public void testXContent_WritesAllFields_ServiceSettingsDefined_Int8() throws IOException {
67+
var entity = new VoyageAIEmbeddingsRequestEntity(
68+
List.of("abc"),
69+
VoyageAIEmbeddingsServiceSettings.fromMap(
70+
new HashMap<>(
71+
Map.of(
72+
ServiceFields.URL,
73+
"https://www.abc.com",
74+
ServiceFields.SIMILARITY,
75+
SimilarityMeasure.DOT_PRODUCT.toString(),
76+
ServiceFields.DIMENSIONS,
77+
2048,
78+
ServiceFields.MAX_INPUT_TOKENS,
79+
512,
80+
VoyageAIServiceSettings.MODEL_ID,
81+
"model",
82+
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
83+
"int8"
84+
)
85+
),
86+
ConfigurationParseContext.PERSISTENT
87+
),
88+
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
89+
"model"
90+
);
91+
92+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
93+
entity.toXContent(builder, null);
94+
String xContentResult = Strings.toString(builder);
95+
96+
MatcherAssert.assertThat(xContentResult, is("""
97+
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"int8"}"""));
98+
}
99+
100+
public void testXContent_WritesAllFields_ServiceSettingsDefined_Binary() throws IOException {
101+
var entity = new VoyageAIEmbeddingsRequestEntity(
102+
List.of("abc"),
103+
VoyageAIEmbeddingsServiceSettings.fromMap(
104+
new HashMap<>(
105+
Map.of(
106+
ServiceFields.URL,
107+
"https://www.abc.com",
108+
ServiceFields.SIMILARITY,
109+
SimilarityMeasure.DOT_PRODUCT.toString(),
110+
ServiceFields.DIMENSIONS,
111+
2048,
112+
ServiceFields.MAX_INPUT_TOKENS,
113+
512,
114+
VoyageAIServiceSettings.MODEL_ID,
115+
"model",
116+
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
117+
"binary"
118+
)
119+
),
120+
ConfigurationParseContext.PERSISTENT
121+
),
122+
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
123+
"model"
124+
);
125+
126+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
127+
entity.toXContent(builder, null);
128+
String xContentResult = Strings.toString(builder);
129+
130+
MatcherAssert.assertThat(xContentResult, is("""
131+
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"binary"}"""));
132+
}
133+
26134
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
27135
var entity = new VoyageAIEmbeddingsRequestEntity(
28136
List.of("abc"),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestTests.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.inference.InputType;
1313
import org.elasticsearch.test.ESTestCase;
1414
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
1516
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
1617
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
1718
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
@@ -87,6 +88,77 @@ public void testCreateRequest_AllOptionsDefined() throws IOException {
8788
)));
8889
}
8990

91+
public void testCreateRequest_DimensionDefined() throws IOException {
92+
var request = createRequest(
93+
List.of("abc"),
94+
VoyageAIEmbeddingsModelTests.createModel(
95+
"url",
96+
"secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
97+
null,
98+
2048,
99+
"model"
100+
)
101+
);
102+
103+
var httpRequest = request.createHttpRequest();
104+
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
105+
106+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
107+
108+
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
109+
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
110+
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
111+
MatcherAssert.assertThat(
112+
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
113+
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
114+
);
115+
116+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
117+
MatcherAssert.assertThat(requestMap, is(Map.of(
118+
"input", List.of("abc"),
119+
"model", "model",
120+
"input_type", "document",
121+
"output_dtype", "float",
122+
"output_dimension", 2048
123+
)));
124+
}
125+
126+
public void testCreateRequest_EmbeddingTypeDefined() throws IOException {
127+
var request = createRequest(
128+
List.of("abc"),
129+
VoyageAIEmbeddingsModelTests.createModel(
130+
"url",
131+
"secret", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
132+
null,
133+
2048,
134+
"model",
135+
VoyageAIEmbeddingType.BYTE
136+
)
137+
);
138+
139+
var httpRequest = request.createHttpRequest();
140+
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
141+
142+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
143+
144+
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
145+
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
146+
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
147+
MatcherAssert.assertThat(
148+
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
149+
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
150+
);
151+
152+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
153+
MatcherAssert.assertThat(requestMap, is(Map.of(
154+
"input", List.of("abc"),
155+
"model", "model",
156+
"input_type", "document",
157+
"output_dtype", "int8",
158+
"output_dimension", 2048
159+
)));
160+
}
161+
90162
public void testCreateRequest_InputTypeSearch() throws IOException {
91163
var request = createRequest(
92164
List.of("abc"),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,48 @@ public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumen
7979
"""));
8080
}
8181

82+
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
83+
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
84+
85+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
86+
entity.toXContent(builder, null);
87+
String xContentResult = Strings.toString(builder);
88+
89+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
90+
{
91+
"model": "model",
92+
"query": "query",
93+
"documents": [
94+
"abc"
95+
],
96+
"return_documents": false,
97+
"top_k": 8,
98+
"truncation": true
99+
}
100+
"""));
101+
}
102+
103+
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
104+
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
105+
106+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
107+
entity.toXContent(builder, null);
108+
String xContentResult = Strings.toString(builder);
109+
110+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
111+
{
112+
"model": "model",
113+
"query": "query",
114+
"documents": [
115+
"abc"
116+
],
117+
"return_documents": false,
118+
"top_k": 8,
119+
"truncation": false
120+
}
121+
"""));
122+
}
123+
82124
public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
83125
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");
84126

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ public void testTruncate_DoesNotTruncate() {
102102
assertThat(truncatedRequest, sameInstance(request));
103103
}
104104

105-
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
106-
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topN);
105+
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
106+
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
107107
return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
108108

109109
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,18 +1830,18 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo
18301830
var results = listener.actionGet(TIMEOUT);
18311831
assertThat(results, hasSize(2));
18321832
{
1833-
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
1834-
var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
1833+
assertThat(results.getFirst(), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
1834+
var floatResult = (ChunkedInferenceEmbeddingFloat) results.getFirst();
18351835
assertThat(floatResult.chunks(), hasSize(1));
1836-
assertEquals("foo", floatResult.chunks().get(0).matchedText());
1837-
assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
1836+
assertEquals("foo", floatResult.chunks().getFirst().matchedText());
1837+
assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().getFirst().embedding(), 0.0f);
18381838
}
18391839
{
18401840
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
18411841
var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
18421842
assertThat(floatResult.chunks(), hasSize(1));
1843-
assertEquals("bar", floatResult.chunks().get(0).matchedText());
1844-
assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f);
1843+
assertEquals("bar", floatResult.chunks().getFirst().matchedText());
1844+
assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().getFirst().embedding(), 0.0f);
18451845
}
18461846

18471847
MatcherAssert.assertThat(webServer.requests(), hasSize(1));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModelTests.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,49 +14,58 @@
1414
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
1515

1616
public class VoyageAIRerankModelTests {
17+
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK, @Nullable Boolean truncation) {
18+
return new VoyageAIRerankModel(
19+
"id",
20+
"service",
21+
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)),
22+
new VoyageAIRerankTaskSettings(topK, null, truncation),
23+
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
24+
);
25+
}
1726

18-
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topN) {
27+
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK) {
1928
return new VoyageAIRerankModel(
2029
"id",
2130
"service",
2231
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)),
23-
new VoyageAIRerankTaskSettings(topN, null, null),
32+
new VoyageAIRerankTaskSettings(topK, null, null),
2433
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
2534
);
2635
}
2736

28-
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN) {
37+
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK) {
2938
return new VoyageAIRerankModel(
3039
"id",
3140
"service",
3241
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)),
33-
new VoyageAIRerankTaskSettings(topN, null, null),
42+
new VoyageAIRerankTaskSettings(topK, null, null),
3443
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
3544
);
3645
}
3746

38-
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments, Boolean truncation) {
47+
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK, Boolean returnDocuments, Boolean truncation) {
3948
return new VoyageAIRerankModel(
4049
"id",
4150
"service",
4251
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)),
43-
new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation),
52+
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
4453
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
4554
);
4655
}
4756

4857
public static VoyageAIRerankModel createModel(
4958
String url,
5059
String modelId,
51-
@Nullable Integer topN,
60+
@Nullable Integer topK,
5261
Boolean returnDocuments,
5362
Boolean truncation
5463
) {
5564
return new VoyageAIRerankModel(
5665
"id",
5766
"service",
5867
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)),
59-
new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation),
68+
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
6069
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
6170
);
6271
}
@@ -65,15 +74,15 @@ public static VoyageAIRerankModel createModel(
6574
String url,
6675
String apiKey,
6776
String modelId,
68-
@Nullable Integer topN,
77+
@Nullable Integer topK,
6978
Boolean returnDocuments,
7079
Boolean truncation
7180
) {
7281
return new VoyageAIRerankModel(
7382
"id",
7483
"service",
7584
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(url, modelId, null)),
76-
new VoyageAIRerankTaskSettings(topN, returnDocuments, truncation),
85+
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
7786
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
7887
);
7988
}

0 commit comments

Comments
 (0)