Skip to content

Commit 1b15536

Browse files
committed
Added SemanticTextIndexOptionsIT
1 parent 137eac5 commit 1b15536

File tree

1 file changed

+232
-0
lines changed

1 file changed

+232
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.integration;
9+
10+
import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsAction;
11+
import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRequest;
12+
import org.elasticsearch.action.support.IndicesOptions;
13+
import org.elasticsearch.common.bytes.BytesReference;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.common.xcontent.XContentHelper;
16+
import org.elasticsearch.core.Nullable;
17+
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
19+
import org.elasticsearch.index.mapper.vectors.IndexOptions;
20+
import org.elasticsearch.inference.TaskType;
21+
import org.elasticsearch.license.License;
22+
import org.elasticsearch.license.LicenseSettings;
23+
import org.elasticsearch.license.PostStartBasicAction;
24+
import org.elasticsearch.license.PostStartBasicRequest;
25+
import org.elasticsearch.license.PutLicenseAction;
26+
import org.elasticsearch.license.PutLicenseRequest;
27+
import org.elasticsearch.license.TestUtils;
28+
import org.elasticsearch.plugins.Plugin;
29+
import org.elasticsearch.reindex.ReindexPlugin;
30+
import org.elasticsearch.test.ESIntegTestCase;
31+
import org.elasticsearch.xcontent.ToXContent;
32+
import org.elasticsearch.xcontent.XContentBuilder;
33+
import org.elasticsearch.xcontent.XContentFactory;
34+
import org.elasticsearch.xcontent.XContentType;
35+
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
36+
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
37+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
38+
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
39+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
40+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
41+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
42+
import org.junit.After;
43+
import org.junit.Before;
44+
45+
import java.io.IOException;
46+
import java.util.Collection;
47+
import java.util.HashMap;
48+
import java.util.List;
49+
import java.util.Map;
50+
51+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
52+
import static org.hamcrest.CoreMatchers.equalTo;
53+
54+
public class SemanticTextIndexOptionsIT extends ESIntegTestCase {
55+
private static final String INDEX_NAME = "test-index";
56+
private static final Map<String, Object> BBQ_COMPATIBLE_SERVICE_SETTINGS = Map.of(
57+
"model",
58+
"my_model",
59+
"dimensions",
60+
256,
61+
"similarity",
62+
"cosine",
63+
"api_key",
64+
"my_api_key"
65+
);
66+
67+
private final Map<String, TaskType> inferenceIds = new HashMap<>();
68+
69+
@Override
70+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
71+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
72+
}
73+
74+
@Override
75+
protected Collection<Class<? extends Plugin>> nodePlugins() {
76+
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class);
77+
}
78+
79+
@Before
80+
public void resetLicense() throws Exception {
81+
setLicense("trial");
82+
}
83+
84+
@After
85+
public void cleanUp() {
86+
assertAcked(
87+
safeGet(
88+
client().admin()
89+
.indices()
90+
.prepareDelete(INDEX_NAME)
91+
.setIndicesOptions(
92+
IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build()
93+
)
94+
.execute()
95+
)
96+
);
97+
98+
for (var entry : inferenceIds.entrySet()) {
99+
assertAcked(
100+
safeGet(
101+
client().execute(
102+
DeleteInferenceEndpointAction.INSTANCE,
103+
new DeleteInferenceEndpointAction.Request(entry.getKey(), entry.getValue(), true, false)
104+
)
105+
)
106+
);
107+
}
108+
}
109+
110+
public void testValidateIndexOptionsWithBasicLicense() throws Exception {
111+
final String inferenceId = "test-inference-id-1";
112+
final String inferenceFieldName = "inference_field";
113+
createInferenceEndpoint(TaskType.TEXT_EMBEDDING, inferenceId, BBQ_COMPATIBLE_SERVICE_SETTINGS);
114+
115+
setLicense("basic");
116+
IndexOptions indexOptions = new DenseVectorFieldMapper.Int8HnswIndexOptions(
117+
randomIntBetween(1, 100),
118+
randomIntBetween(1, 10_000),
119+
null,
120+
null
121+
);
122+
assertAcked(
123+
safeGet(prepareCreate(INDEX_NAME).setMapping(generateMapping(inferenceFieldName, inferenceId, indexOptions)).execute())
124+
);
125+
126+
final Map<String, Object> expectedFieldMapping = generateExpectedFieldMapping(inferenceId, inferenceFieldName, indexOptions);
127+
var getFieldMappingsResponse = safeGet(
128+
client().execute(GetFieldMappingsAction.INSTANCE, new GetFieldMappingsRequest().indices(INDEX_NAME).fields(inferenceFieldName))
129+
);
130+
assertThat(getFieldMappingsResponse.fieldMappings(INDEX_NAME, inferenceFieldName).sourceAsMap(), equalTo(expectedFieldMapping));
131+
}
132+
133+
private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
134+
final String service = switch (taskType) {
135+
case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
136+
case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
137+
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
138+
};
139+
140+
final BytesReference content;
141+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
142+
builder.startObject();
143+
builder.field("service", service);
144+
builder.field("service_settings", serviceSettings);
145+
builder.endObject();
146+
147+
content = BytesReference.bytes(builder);
148+
}
149+
150+
PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
151+
taskType,
152+
inferenceId,
153+
content,
154+
XContentType.JSON,
155+
TEST_REQUEST_TIMEOUT
156+
);
157+
var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request);
158+
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
159+
160+
inferenceIds.put(inferenceId, taskType);
161+
}
162+
163+
private static XContentBuilder generateMapping(String inferenceFieldName, String inferenceId, @Nullable IndexOptions indexOptions)
164+
throws IOException {
165+
XContentBuilder mapping = XContentFactory.jsonBuilder();
166+
mapping.startObject();
167+
mapping.field("properties");
168+
generateFieldMapping(mapping, inferenceFieldName, inferenceId, indexOptions);
169+
mapping.endObject();
170+
171+
return mapping;
172+
}
173+
174+
private static void generateFieldMapping(
175+
XContentBuilder builder,
176+
String inferenceFieldName,
177+
String inferenceId,
178+
@Nullable IndexOptions indexOptions
179+
) throws IOException {
180+
builder.startObject();
181+
builder.startObject(inferenceFieldName);
182+
builder.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
183+
builder.field("inference_id", inferenceId);
184+
if (indexOptions != null) {
185+
builder.startObject("index_options");
186+
if (indexOptions instanceof DenseVectorFieldMapper.DenseVectorIndexOptions) {
187+
builder.field("dense_vector");
188+
indexOptions.toXContent(builder, ToXContent.EMPTY_PARAMS);
189+
}
190+
builder.endObject();
191+
}
192+
builder.endObject();
193+
builder.endObject();
194+
}
195+
196+
private static Map<String, Object> generateExpectedFieldMapping(
197+
String inferenceFieldName,
198+
String inferenceId,
199+
@Nullable IndexOptions indexOptions
200+
) throws IOException {
201+
Map<String, Object> expectedFieldMapping;
202+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
203+
generateFieldMapping(builder, inferenceFieldName, inferenceId, indexOptions);
204+
expectedFieldMapping = XContentHelper.convertToMap(BytesReference.bytes(builder), false, XContentType.JSON).v2();
205+
}
206+
207+
return expectedFieldMapping;
208+
}
209+
210+
private static void setLicense(String type) throws Exception {
211+
if (type.equals("basic")) {
212+
assertAcked(
213+
safeGet(
214+
client().execute(
215+
PostStartBasicAction.INSTANCE,
216+
new PostStartBasicRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).acknowledge(true)
217+
)
218+
)
219+
);
220+
} else {
221+
License license = TestUtils.generateSignedLicense(type, License.VERSION_CURRENT, -1, TimeValue.timeValueHours(24));
222+
assertAcked(
223+
safeGet(
224+
client().execute(
225+
PutLicenseAction.INSTANCE,
226+
new PutLicenseRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT).license(license)
227+
)
228+
)
229+
);
230+
}
231+
}
232+
}

0 commit comments

Comments
 (0)