Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/114683.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 114683
summary: Default inference endpoint for the multilingual-e5-small model
area: Machine Learning
type: enhancement
issues: []
7 changes: 6 additions & 1 deletion docs/reference/rest-api/usage.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,12 @@ GET /_xpack/usage
"service": "elasticsearch",
"task_type": "SPARSE_EMBEDDING",
"count": 1
}
},
{
"service": "elasticsearch",
"task_type": "TEXT_EMBEDDING",
"count": 1
},
]
},
"logstash" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;

public class DefaultElserIT extends InferenceBaseRestTest {
public class DefaultEndPointsIT extends InferenceBaseRestTest {

private TestThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
}

@After
Expand All @@ -38,7 +38,7 @@ public void tearDown() throws Exception {
}

@SuppressWarnings("unchecked")
public void testInferCreatesDefaultElser() throws IOException {
public void testInferDeploysDefaultElser() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
assertDefaultElserConfig(model);
Expand Down Expand Up @@ -67,4 +67,39 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
);
}

@SuppressWarnings("unchecked")
public void testInferDeploysDefaultE5() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
assertDefaultE5Config(model);

var inputs = List.of("Hello World", "Goodnight moon");
var queryParams = Map.of("timeout", "120s");
var results = infer(ElasticsearchInternalService.DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, inputs, queryParams);
var embeddings = (List<Map<String, Object>>) results.get("text_embedding");
assertThat(results.toString(), embeddings, hasSize(2));
}

@SuppressWarnings("unchecked")
private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
assertEquals(modelConfig.toString(), TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));

var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
assertThat(
modelConfig.toString(),
serviceSettings.get("model_id"),
is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64"))
);
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));

var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
assertThat(
modelConfig.toString(),
adaptiveAllocations,
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public static InferModelAction.Request buildInferenceRequest(
return request;
}

protected abstract boolean isDefaultId(String inferenceId);
abstract boolean isDefaultId(String inferenceId);

protected void maybeStartDeployment(
ElasticsearchInternalModel model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi

public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
public static final String DEFAULT_ELSER_ID = ".elser-2";
public static final String DEFAULT_E5_ID = ".multi-e5-small";

private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
Expand Down Expand Up @@ -815,20 +816,47 @@ public List<UnparsedModel> defaultConfigs() {
)
);

// TODO Chunking settings
Map<String, Object> e5Settings = Map.of(
ModelConfigurations.SERVICE_SETTINGS,
Map.of(
ElasticsearchInternalServiceSettings.MODEL_ID,
MULTILINGUAL_E5_SMALL_MODEL_ID, // TODO pick model depending on platform
ElasticsearchInternalServiceSettings.NUM_THREADS,
1,
ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
Map.of(
"enabled",
Boolean.TRUE,
"min_number_of_allocations",
1,
"max_number_of_allocations",
8 // no max?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah he's on vacation today

)
)
);

return List.of(
new UnparsedModel(
DEFAULT_ELSER_ID,
TaskType.SPARSE_EMBEDDING,
NAME,
elserSettings,
Map.of() // no secrets
),
new UnparsedModel(
DEFAULT_E5_ID,
TaskType.TEXT_EMBEDDING,
NAME,
e5Settings,
Map.of() // no secrets
)
);
}

@Override
protected boolean isDefaultId(String inferenceId) {
return DEFAULT_ELSER_ID.equals(inferenceId);
boolean isDefaultId(String inferenceId) {
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
}

static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,13 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() {
assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]"));
}

public void testIsDefaultId() {
var service = createService(mock(Client.class));
assertTrue(service.isDefaultId(".elser-2"));
assertTrue(service.isDefaultId(".multi-e5-small"));
assertFalse(service.isDefaultId("foo"));
}

private ElasticsearchInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool);
return new ElasticsearchInternalService(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ protected void masterOperation(
if (getModelResponse.getResources().results().size() > 1) {
listener.onFailure(
ExceptionsHelper.badRequestException(
"cannot deploy more than one models at the same time; [{}] matches [{}] models]",
"cannot deploy more than one model at the same time; [{}] matches models [{}]",
request.getModelId(),
getModelResponse.getResources().results().size()
getModelResponse.getResources().results().stream().map(TrainedModelConfig::getModelId).toList()
)
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,17 @@
- do:
inference.get:
inference_id: "*"
- length: { endpoints: 1}
- match: { endpoints.0.inference_id: ".elser-2" }
- length: { endpoints: 2}
- match: { endpoints.0.inference_id: ".multi-e5-small" }
- match: { endpoints.1.inference_id: ".elser-2" }

- do:
inference.get:
inference_id: _all
- length: { endpoints: 1}
- match: { endpoints.0.inference_id: ".elser-2" }
- length: { endpoints: 2}

- do:
inference.get:
inference_id: ""
- length: { endpoints: 1}
- match: { endpoints.0.inference_id: ".elser-2" }
- length: { endpoints: 2}