Skip to content

Commit b9d1222

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Adding elser default endpoint for EIS (#122066)
* Adding elser default endpoint * [CI] Auto commit changes from spotless * Fixing test and allowing duplicate calls * [CI] Auto commit changes from spotless * Update docs/changelog/122066.yaml --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 4b3acd4 commit b9d1222

File tree

7 files changed

+81
-13
lines changed

7 files changed

+81
-13
lines changed

docs/changelog/122066.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122066
2+
summary: Adding elser default endpoint for EIS
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,34 @@
1212
import org.elasticsearch.inference.TaskType;
1313

1414
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Map;
1517

1618
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
1719
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
1820
import static org.hamcrest.Matchers.hasSize;
21+
import static org.hamcrest.Matchers.is;
1922

2023
public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
2124

2225
public void testGetDefaultEndpoints() throws IOException {
2326
var allModels = getAllModels();
2427
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
2528

26-
assertThat(allModels, hasSize(4));
29+
assertThat(allModels, hasSize(5));
2730
assertThat(chatCompletionModels, hasSize(1));
2831

2932
for (var model : chatCompletionModels) {
3033
assertEquals("chat_completion", model.get("task_type"));
3134
}
3235

36+
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
37+
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
38+
}
39+
40+
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
41+
var model = models.stream().filter(m -> m.get("inference_id").equals(inferenceId)).findFirst();
42+
assertTrue("could not find inference id: " + inferenceId, model.isPresent());
43+
assertThat(model.get().get("task_type"), is(taskType.toString()));
3344
}
3445
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
204204
service.defaultConfigIds(),
205205
is(
206206
List.of(
207+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
207208
new InferenceService.DefaultConfigId(
208209
".rainbow-sprinkles-elastic",
209210
MinimalServiceSettings.chatCompletion(),
@@ -216,7 +217,8 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
216217

217218
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
218219
service.defaultConfigs(listener);
219-
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
220+
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
221+
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
220222

221223
var getModelListener = new PlainActionFuture<UnparsedModel>();
222224
// persists the default endpoints
@@ -244,12 +246,18 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
244246
try (var service = createElasticInferenceService()) {
245247
service.waitForAuthorizationToComplete(TIMEOUT);
246248
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
247-
assertTrue(service.defaultConfigIds().isEmpty());
249+
assertThat(
250+
service.defaultConfigIds(),
251+
is(
252+
List.of(
253+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service)
254+
)
255+
)
256+
);
248257
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
249258

250259
var getModelListener = new PlainActionFuture<UnparsedModel>();
251260
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
252-
253261
var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
254262
assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
255263
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,20 @@ public boolean containsDefaultConfigId(String inferenceEntityId) {
126126
return defaultConfigIds.containsKey(inferenceEntityId);
127127
}
128128

129+
/**
130+
* Adds the default configuration information if it does not already exist internally.
131+
* @param defaultConfigId the default endpoint information
132+
*/
133+
public synchronized void putDefaultIdIfAbsent(InferenceService.DefaultConfigId defaultConfigId) {
134+
defaultConfigIds.putIfAbsent(defaultConfigId.inferenceId(), defaultConfigId);
135+
}
136+
129137
/**
130138
* Set the default inference ids provided by the services
131-
* @param defaultConfigId The default
139+
* @param defaultConfigId The default endpoint information
140+
* @throws IllegalStateException if the {@link InferenceService.DefaultConfigId#inferenceId()} already exists internally
132141
*/
133-
public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) {
142+
public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) throws IllegalStateException {
134143
var config = defaultConfigIds.get(defaultConfigId.inferenceId());
135144
if (config != null) {
136145
throw new IllegalStateException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
5858

5959
import java.util.ArrayList;
60+
import java.util.Comparator;
6061
import java.util.EnumSet;
6162
import java.util.HashMap;
6263
import java.util.HashSet;
@@ -65,6 +66,7 @@
6566
import java.util.Map;
6667
import java.util.Objects;
6768
import java.util.Set;
69+
import java.util.TreeSet;
6870
import java.util.concurrent.CountDownLatch;
6971
import java.util.concurrent.TimeUnit;
7072
import java.util.concurrent.atomic.AtomicReference;
@@ -90,14 +92,24 @@ public class ElasticInferenceService extends SenderService {
9092
private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
9193
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
9294
private static final String SERVICE_NAME = "Elastic";
95+
96+
// rainbow-sprinkles
9397
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
94-
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
98+
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
99+
100+
// elser-v2
101+
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
102+
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
95103

96104
/**
97105
* The task types that the {@link InferenceAction.Request} can accept.
98106
*/
99107
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
100108

109+
private static String defaultEndpointId(String modelId) {
110+
return Strings.format(".%s-elastic", modelId);
111+
}
112+
101113
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
102114
private Configuration configuration;
103115
private final AtomicReference<AuthorizedContent> authRef = new AtomicReference<>(AuthorizedContent.empty());
@@ -142,6 +154,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
142154
elasticInferenceServiceComponents
143155
),
144156
MinimalServiceSettings.chatCompletion()
157+
),
158+
DEFAULT_ELSER_MODEL_ID_V2,
159+
new DefaultModelConfig(
160+
new ElasticInferenceServiceSparseEmbeddingsModel(
161+
DEFAULT_ELSER_ENDPOINT_ID_V2,
162+
TaskType.SPARSE_EMBEDDING,
163+
NAME,
164+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
165+
EmptyTaskSettings.INSTANCE,
166+
EmptySecretSettings.INSTANCE,
167+
elasticInferenceServiceComponents
168+
),
169+
MinimalServiceSettings.sparseEmbedding()
145170
)
146171
);
147172
}
@@ -184,13 +209,13 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat
184209

185210
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
186211

187-
defaultConfigIds().forEach(modelRegistry::addDefaultIds);
212+
defaultConfigIds().forEach(modelRegistry::putDefaultIdIfAbsent);
188213
handleRevokedDefaultConfigs(authorizedDefaultModelIds);
189214
}
190215

191216
private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
192217
var authorizedModels = auth.getAuthorizedModelIds();
193-
var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
218+
var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet());
194219
authorizedDefaultModelIds.retainAll(authorizedModels);
195220

196221
return authorizedDefaultModelIds;
@@ -218,6 +243,7 @@ private List<DefaultConfigId> getAuthorizedDefaultConfigIds(
218243
}
219244
}
220245

246+
authorizedConfigIds.sort(Comparator.comparing(DefaultConfigId::inferenceId));
221247
return authorizedConfigIds;
222248
}
223249

@@ -230,6 +256,7 @@ private List<DefaultModelConfig> getAuthorizedDefaultModelsObjects(Set<String> a
230256
}
231257
}
232258

259+
authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model.getInferenceEntityId()));
233260
return authorizedModels;
234261
}
235262

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap(
7575
public ElasticInferenceServiceSparseEmbeddingsServiceSettings(
7676
String modelId,
7777
@Nullable Integer maxInputTokens,
78-
RateLimitSettings rateLimitSettings
78+
@Nullable RateLimitSettings rateLimitSettings
7979
) {
8080
this.modelId = Objects.requireNonNull(modelId);
8181
this.maxInputTokens = maxInputTokens;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -934,13 +934,17 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
934934
}
935935
}
936936

937-
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
937+
public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception {
938938
String responseJson = """
939939
{
940940
"models": [
941941
{
942942
"model_name": "rainbow-sprinkles",
943943
"task_types": ["chat"]
944+
},
945+
{
946+
"model_name": "elser-v2",
947+
"task_types": ["embed/text/sparse"]
944948
}
945949
]
946950
}
@@ -957,15 +961,19 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
957961
service.defaultConfigIds(),
958962
is(
959963
List.of(
964+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
960965
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
961966
)
962967
)
963968
);
964-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
969+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
965970

966971
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
967972
service.defaultConfigs(listener);
968-
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
973+
var models = listener.actionGet(TIMEOUT);
974+
assertThat(models.size(), is(2));
975+
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
976+
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
969977
}
970978
}
971979

0 commit comments

Comments
 (0)