Skip to content

Commit 3e8c70a

Browse files
committed
Add ElasticInferenceServiceDenseTextEmbeddingsRequestTests
1 parent 9d47176 commit 3e8c70a

File tree

1 file changed

+165
-0
lines changed

1 file changed

+165
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.request;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.elasticsearch.inference.InputType;
13+
import org.elasticsearch.tasks.Task;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
17+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
22+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
23+
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
24+
import static org.hamcrest.Matchers.aMapWithSize;
25+
import static org.hamcrest.Matchers.equalTo;
26+
import static org.hamcrest.Matchers.instanceOf;
27+
import static org.hamcrest.Matchers.is;
28+
import static org.hamcrest.Matchers.nullValue;
29+
30+
public class ElasticInferenceServiceDenseTextEmbeddingsRequestTests extends ESTestCase {
31+
32+
public void testCreateHttpRequest_UsageContextSearch() throws IOException {
33+
var url = "http://eis-gateway.com";
34+
var input = List.of("input text");
35+
var modelId = "my-dense-model-id";
36+
37+
var request = createRequest(url, modelId, input, InputType.SEARCH);
38+
var httpRequest = request.createHttpRequest();
39+
40+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
41+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
42+
43+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
44+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
45+
assertThat(requestMap.size(), equalTo(3));
46+
assertThat(requestMap.get("input"), is(input));
47+
assertThat(requestMap.get("model"), is(modelId));
48+
assertThat(requestMap.get("usage_context"), equalTo("search"));
49+
}
50+
51+
public void testCreateHttpRequest_UsageContextIngest() throws IOException {
52+
var url = "http://eis-gateway.com";
53+
var input = List.of("ingest text");
54+
var modelId = "my-dense-model-id";
55+
56+
var request = createRequest(url, modelId, input, InputType.INGEST);
57+
var httpRequest = request.createHttpRequest();
58+
59+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
60+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
61+
62+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
63+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
64+
assertThat(requestMap.size(), equalTo(3));
65+
assertThat(requestMap.get("input"), is(input));
66+
assertThat(requestMap.get("model"), is(modelId));
67+
assertThat(requestMap.get("usage_context"), equalTo("ingest"));
68+
}
69+
70+
public void testCreateHttpRequest_UsageContextUnspecified() throws IOException {
71+
var url = "http://eis-gateway.com";
72+
var input = List.of("unspecified text");
73+
var modelId = "my-dense-model-id";
74+
75+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
76+
var httpRequest = request.createHttpRequest();
77+
78+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
79+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
80+
81+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
82+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
83+
assertThat(requestMap, aMapWithSize(2));
84+
assertThat(requestMap.get("input"), is(input));
85+
assertThat(requestMap.get("model"), is(modelId));
86+
// usage_context should not be present for UNSPECIFIED
87+
}
88+
89+
public void testCreateHttpRequest_MultipleInputs() throws IOException {
90+
var url = "http://eis-gateway.com";
91+
var inputs = List.of("first input", "second input", "third input");
92+
var modelId = "my-dense-model-id";
93+
94+
var request = createRequest(url, modelId, inputs, InputType.SEARCH);
95+
var httpRequest = request.createHttpRequest();
96+
97+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
98+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
99+
100+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
101+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
102+
assertThat(requestMap.size(), equalTo(3));
103+
assertThat(requestMap.get("input"), is(inputs));
104+
assertThat(requestMap.get("model"), is(modelId));
105+
assertThat(requestMap.get("usage_context"), equalTo("search"));
106+
}
107+
108+
public void testTraceContextPropagatedThroughHTTPHeaders() {
109+
var url = "http://eis-gateway.com";
110+
var input = List.of("input text");
111+
var modelId = "my-dense-model-id";
112+
113+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
114+
var httpRequest = request.createHttpRequest();
115+
116+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
117+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
118+
119+
var traceParent = request.getTraceContext().traceParent();
120+
var traceState = request.getTraceContext().traceState();
121+
122+
assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent));
123+
assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState));
124+
}
125+
126+
public void testTruncate_ReturnsSameInstance() {
127+
var url = "http://eis-gateway.com";
128+
var input = List.of("input text");
129+
var modelId = "my-dense-model-id";
130+
131+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
132+
var truncatedRequest = request.truncate();
133+
134+
// Dense text embeddings request doesn't support truncation, should return same instance
135+
assertThat(truncatedRequest, is(request));
136+
}
137+
138+
public void testGetTruncationInfo_ReturnsNull() {
139+
var url = "http://eis-gateway.com";
140+
var input = List.of("input text");
141+
var modelId = "my-dense-model-id";
142+
143+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
144+
145+
// Dense text embeddings request doesn't support truncation info
146+
assertThat(request.getTruncationInfo(), is(nullValue()));
147+
}
148+
149+
private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest(
150+
String url,
151+
String modelId,
152+
List<String> inputs,
153+
InputType inputType
154+
) {
155+
var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId, null);
156+
157+
return new ElasticInferenceServiceDenseTextEmbeddingsRequest(
158+
embeddingsModel,
159+
inputs,
160+
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
161+
randomElasticInferenceServiceRequestMetadata(),
162+
inputType
163+
);
164+
}
165+
}

0 commit comments

Comments
 (0)