Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/113873.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 113873
summary: Default inference endpoint for ELSER
area: Machine Learning
type: enhancement
issues: []
7 changes: 6 additions & 1 deletion docs/reference/rest-api/usage.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,12 @@ GET /_xpack/usage
"inference": {
"available" : true,
"enabled" : true,
"models" : []
"models" : [{
"service": "elasticsearch",
"task_type": "SPARSE_EMBEDDING",
"count": 1
}
]
},
"logstash" : {
"available" : true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,13 @@ default Set<TaskType> supportedStreamingTasks() {
/**
 * Checks whether streaming is supported for the given task type.
 *
 * @param taskType the task type to check
 * @return {@code true} if {@code taskType} is contained in {@link #supportedStreamingTasks()}
 */
default boolean canStream(TaskType taskType) {
return supportedStreamingTasks().contains(taskType);
}

/**
 * A service can define default configurations that can be
 * used out of the box without creating an endpoint first.
 * <p>
 * The default implementation returns an empty list, meaning the
 * service provides no out-of-the-box configurations.
 *
 * @return Default configurations provided by this service; never {@code null}
 */
default List<UnparsedModel> defaultConfigs() {
return List.of();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import java.util.Map;

/**
 * Semi parsed model where inference entity id, task type and service
 * are known but the settings are not parsed.
 *
 * @param inferenceEntityId the unique identifier of the inference endpoint
 * @param taskType          the task type the model performs
 * @param service           the name of the service the model belongs to
 * @param settings          raw, unparsed (service and task) settings
 * @param secrets           raw, unparsed secret settings, kept separate from the
 *                          non-sensitive settings
 */
public record UnparsedModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> settings,
Map<String, Object> secrets
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,6 @@ tasks.named("precommit").configure {

tasks.named("yamlRestTestV7CompatTransform").configure({ task ->
task.skipTest("security/10_forbidden/Test bulk response with invalid credentials", "warning does not exist for compatibility")
task.skipTest("inference/inference_crud/Test get all", "Assertions on number of inference models break due to default configs")
})

Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ public int computeNumberOfAllocations() {
if (numberOfAllocations != null) {
return numberOfAllocations;
} else {
if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) {
if (adaptiveAllocationsSettings == null
|| adaptiveAllocationsSettings.getMinNumberOfAllocations() == null
|| adaptiveAllocationsSettings.getMinNumberOfAllocations() == 0) {
return DEFAULT_NUM_ALLOCATIONS;
} else {
return adaptiveAllocationsSettings.getMinNumberOfAllocations();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ public AdaptiveAllocationsSettings merge(AdaptiveAllocationsSettings updates) {
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = new ActionRequestValidationException();
boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1);
if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) {
validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null");
if (hasMinNumberOfAllocations && minNumberOfAllocations < 0) {
validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a non-negative integer or null");
}
boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1);
if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class AdaptiveAllocationSettingsTests extends AbstractWireSerializingTest
public static AdaptiveAllocationsSettings testInstance() {
return new AdaptiveAllocationsSettings(
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomIntBetween(1, 2),
randomBoolean() ? null : randomIntBetween(0, 2),
randomBoolean() ? null : randomIntBetween(2, 4)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public void testSparse() throws IOException {

var inferenceId = "sparse-inf";
putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
var results = inferOnMockService(inferenceId, List.of("washing", "machine"));
var results = infer(inferenceId, List.of("washing", "machine"));
deleteModel(inferenceId);
assertNotNull(results.get("sparse_embedding"));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;

/**
 * Integration test verifying that a default ELSER inference endpoint exists
 * out of the box (behind {@code DefaultElserFeatureFlag}) and can serve
 * sparse-embedding inference without the endpoint being created first.
 */
public class DefaultElserIT extends InferenceBaseRestTest {

// NOTE(review): this thread pool is created and torn down but never used by the
// test below — confirm whether it is needed (e.g. by a follow-up commit) or can be removed.
private TestThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
}

@After
public void tearDown() throws Exception {
// Close the pool before the superclass tears down the REST test fixtures.
threadPool.close();
super.tearDown();
}

@SuppressWarnings("unchecked")
public void testInferCreatesDefaultElser() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
// The default endpoint must be retrievable without having been created explicitly.
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
assertDefaultElserConfig(model);

var inputs = List.of("Hello World", "Goodnight moon");
// Generous timeout — presumably to allow for model download/deployment on
// first use; confirm against the deployment behavior.
var queryParams = Map.of("timeout", "120s");
var results = infer(ElasticsearchInternalService.DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, inputs, queryParams);
// Expect exactly one sparse embedding per input string.
var embeddings = (List<Map<String, Object>>) results.get("sparse_embedding");
assertThat(results.toString(), embeddings, hasSize(2));
}

// Asserts the expected default-endpoint configuration: the well-known id,
// the elasticsearch service, sparse_embedding task type, an ELSER v2 model
// variant (platform-agnostic or linux-x86_64), one thread, and adaptive
// allocations enabled with bounds [1, 8].
@SuppressWarnings("unchecked")
private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_ELSER_ID, modelConfig.get("inference_id"));
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
assertEquals(modelConfig.toString(), TaskType.SPARSE_EMBEDDING.toString(), modelConfig.get("task_type"));

var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(oneOf(".elser_model_2", ".elser_model_2_linux-x86_64")));
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));

var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
assertThat(
modelConfig.toString(),
adaptiveAllocations,
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ protected Map<String, Object> deployE5TrainedModels() throws IOException {

@SuppressWarnings("unchecked")
protected Map<String, Object> getModel(String modelId) throws IOException {
var endpoint = Strings.format("_inference/%s", modelId);
var endpoint = Strings.format("_inference/%s?error_trace", modelId);
return ((List<Map<String, Object>>) getInternal(endpoint).get("endpoints")).get(0);
}

Expand All @@ -293,9 +293,9 @@ private Map<String, Object> getInternal(String endpoint) throws IOException {
return entityAsMap(response);
}

protected Map<String, Object> inferOnMockService(String modelId, List<String> input) throws IOException {
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
var endpoint = Strings.format("_inference/%s", modelId);
return inferOnMockServiceInternal(endpoint, input);
return inferInternal(endpoint, input, Map.of());
}

protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
Expand Down Expand Up @@ -324,14 +324,23 @@ public void onFailure(Exception exception) {
return responseConsumer.events();
}

protected Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
return inferOnMockServiceInternal(endpoint, input);
return inferInternal(endpoint, input, Map.of());
}

protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
throws IOException {
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
return inferInternal(endpoint, input, queryParameters);
}

private Map<String, Object> inferOnMockServiceInternal(String endpoint, List<String> input) throws IOException {
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input));
if (queryParameters.isEmpty() == false) {
request.addParameters(queryParameters);
}
var response = client().performRequest(request);
assertOkOrCreated(response);
return entityAsMap(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ public void testGet() throws IOException {
}

var getAllModels = getAllModels();
assertThat(getAllModels, hasSize(9));
int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
assertThat(getSparseModels, hasSize(5));
int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5;
assertThat(getSparseModels, hasSize(numSparseModels));
for (var sparseModel : getSparseModels) {
assertEquals("sparse_embedding", sparseModel.get("task_type"));
}
Expand Down Expand Up @@ -99,7 +101,7 @@ public void testApisWithoutTaskType() throws IOException {
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));

var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));
var inference = infer(modelId, List.of(randomAlphaOfLength(10)));
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
deleteModel(modelId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,20 @@ public void testMockService() throws IOException {
}

List<String> input = List.of(randomAlphaOfLength(10));
var inference = inferOnMockService(inferenceEntityId, input);
var inference = infer(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING);
// Same input should return the same result
assertEquals(inference, inferOnMockService(inferenceEntityId, input));
assertEquals(inference, infer(inferenceEntityId, input));
// Different input values should not
assertNotEquals(
inference,
inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
);
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
}

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(
var inference = infer(
inferenceEntityId,
TaskType.TEXT_EMBEDDING,
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,20 @@ public void testMockService() throws IOException {
}

List<String> input = List.of(randomAlphaOfLength(10));
var inference = inferOnMockService(inferenceEntityId, input);
var inference = infer(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
// Same input should return the same result
assertEquals(inference, inferOnMockService(inferenceEntityId, input));
assertEquals(inference, infer(inferenceEntityId, input));
// Different input values should not
assertNotEquals(
inference,
inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
);
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
}

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(
var inference = infer(
inferenceEntityId,
TaskType.SPARSE_EMBEDDING,
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
Expand Down Expand Up @@ -84,7 +81,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I
}

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10)));
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
}

Expand All @@ -102,7 +99,7 @@ public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOEx
}

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10)));
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
var models = getTrainedModel("_all");
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));

Map<String, Object> results = inferOnMockService(
Map<String, Object> results = infer(
inferenceEntityId,
TaskType.TEXT_EMBEDDING,
List.of("hello world", "this is the second document")
Expand All @@ -57,7 +57,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
var models = getTrainedModel("_all");
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));

Map<String, Object> results = inferOnMockService(
Map<String, Object> results = infer(
inferenceEntityId,
TaskType.TEXT_EMBEDDING,
List.of("hello world", "this is the second document")
Expand Down
Loading