Skip to content

Commit ca426ec

Browse files
Adding elser default endpoint
1 parent cff329e commit ca426ec

File tree

4 files changed

+58
-8
lines changed

4 files changed

+58
-8
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,23 +12,34 @@
1212
import org.elasticsearch.inference.TaskType;
1313

1414
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Map;
1517

1618
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
1719
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
1820
import static org.hamcrest.Matchers.hasSize;
21+
import static org.hamcrest.Matchers.is;
1922

2023
public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
2124

2225
public void testGetDefaultEndpoints() throws IOException {
2326
var allModels = getAllModels();
2427
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
2528

26-
assertThat(allModels, hasSize(4));
29+
assertThat(allModels, hasSize(5));
2730
assertThat(chatCompletionModels, hasSize(1));
2831

2932
for (var model : chatCompletionModels) {
3033
assertEquals("chat_completion", model.get("task_type"));
3134
}
3235

36+
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
37+
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
38+
}
39+
40+
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
41+
var model = models.stream().filter(m -> m.get("inference_id").equals(inferenceId)).findFirst();
42+
assertTrue("could not find inference id: " + inferenceId, model.isPresent());
43+
assertThat(model.get().get("task_type"), is(taskType.toString()));
3344
}
3445
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@
5757
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
5858

5959
import java.util.ArrayList;
60+
import java.util.Comparator;
6061
import java.util.EnumSet;
6162
import java.util.HashMap;
6263
import java.util.HashSet;
@@ -65,6 +66,7 @@
6566
import java.util.Map;
6667
import java.util.Objects;
6768
import java.util.Set;
69+
import java.util.TreeSet;
6870
import java.util.concurrent.CountDownLatch;
6971
import java.util.concurrent.TimeUnit;
7072
import java.util.concurrent.atomic.AtomicReference;
@@ -90,14 +92,24 @@ public class ElasticInferenceService extends SenderService {
9092
private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
9193
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
9294
private static final String SERVICE_NAME = "Elastic";
95+
96+
// rainbow-sprinkles
9397
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
94-
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
98+
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
99+
100+
// elser-v2
101+
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
102+
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
95103

96104
/**
97105
* The task types that the {@link InferenceAction.Request} can accept.
98106
*/
99107
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
100108

109+
private static String defaultEndpointId(String modelId) {
110+
return Strings.format(".%s-elastic", modelId);
111+
}
112+
101113
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
102114
private Configuration configuration;
103115
private final AtomicReference<AuthorizedContent> authRef = new AtomicReference<>(AuthorizedContent.empty());
@@ -142,6 +154,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
142154
elasticInferenceServiceComponents
143155
),
144156
MinimalServiceSettings.chatCompletion()
157+
),
158+
DEFAULT_ELSER_MODEL_ID_V2,
159+
new DefaultModelConfig(
160+
new ElasticInferenceServiceSparseEmbeddingsModel(
161+
DEFAULT_ELSER_ENDPOINT_ID_V2,
162+
TaskType.SPARSE_EMBEDDING,
163+
NAME,
164+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
165+
EmptyTaskSettings.INSTANCE,
166+
EmptySecretSettings.INSTANCE,
167+
elasticInferenceServiceComponents
168+
),
169+
MinimalServiceSettings.sparseEmbedding()
145170
)
146171
);
147172
}
@@ -190,7 +215,7 @@ private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizat
190215

191216
private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
192217
var authorizedModels = auth.getAuthorizedModelIds();
193-
var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
218+
var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet());
194219
authorizedDefaultModelIds.retainAll(authorizedModels);
195220

196221
return authorizedDefaultModelIds;
@@ -218,6 +243,7 @@ private List<DefaultConfigId> getAuthorizedDefaultConfigIds(
218243
}
219244
}
220245

246+
authorizedConfigIds.sort(Comparator.comparing(DefaultConfigId::inferenceId));
221247
return authorizedConfigIds;
222248
}
223249

@@ -230,6 +256,7 @@ private List<DefaultModelConfig> getAuthorizedDefaultModelsObjects(Set<String> a
230256
}
231257
}
232258

259+
authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model.getInferenceEntityId()));
233260
return authorizedModels;
234261
}
235262

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap(
7575
public ElasticInferenceServiceSparseEmbeddingsServiceSettings(
7676
String modelId,
7777
@Nullable Integer maxInputTokens,
78-
RateLimitSettings rateLimitSettings
78+
@Nullable RateLimitSettings rateLimitSettings
7979
) {
8080
this.modelId = Objects.requireNonNull(modelId);
8181
this.maxInputTokens = maxInputTokens;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -934,13 +934,17 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIn
934934
}
935935
}
936936

937-
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
937+
public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception {
938938
String responseJson = """
939939
{
940940
"models": [
941941
{
942942
"model_name": "rainbow-sprinkles",
943943
"task_types": ["chat"]
944+
},
945+
{
946+
"model_name": "elser-v2",
947+
"task_types": ["embed/text/sparse"]
944948
}
945949
]
946950
}
@@ -957,15 +961,23 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
957961
service.defaultConfigIds(),
958962
is(
959963
List.of(
960-
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
964+
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
965+
new InferenceService.DefaultConfigId(
966+
".rainbow-sprinkles-elastic",
967+
MinimalServiceSettings.chatCompletion(),
968+
service
969+
)
961970
)
962971
)
963972
);
964-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
973+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
965974

966975
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
967976
service.defaultConfigs(listener);
968-
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
977+
var models = listener.actionGet(TIMEOUT);
978+
assertThat(models.size(), is(2));
979+
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
980+
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
969981
}
970982
}
971983

0 commit comments

Comments
 (0)