Pull request #128584 (merged): `InferenceService` support aliases
Changes from 3 commits
5 changes: 5 additions & 0 deletions docs/changelog/128584.yaml
@@ -0,0 +1,5 @@
pr: 128584
summary: '`InferenceService` support aliases'
area: Machine Learning
type: enhancement
issues: []
InferenceService.java:
@@ -27,6 +27,14 @@ default void init(Client client) {}

String name();

/**
* The aliases that map to {@link #name()}. {@link InferenceServiceRegistry} allows users to create and use inference services by one
* of their aliases.
*/
default List<String> aliases() {
return List.of();
}

/**
* Parse model configuration from the {@code config map} from a request and return
* the parsed {@link Model}. This requires that both the secrets and service settings be contained in the
InferenceServiceRegistry.java:
@@ -24,17 +24,22 @@
public class InferenceServiceRegistry implements Closeable {

private final Map<String, InferenceService> services;
private final Map<String, String> aliases;
private final List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();

public InferenceServiceRegistry(
List<InferenceServiceExtension> inferenceServicePlugins,
InferenceServiceExtension.InferenceServiceFactoryContext factoryContext
) {
// TODO check names are unique
// toMap verifies that the names and aliases are unique
services = inferenceServicePlugins.stream()
.flatMap(r -> r.getInferenceServiceFactories().stream())
.map(factory -> factory.create(factoryContext))
.collect(Collectors.toMap(InferenceService::name, Function.identity()));
aliases = services.values()
.stream()
.flatMap(service -> service.aliases().stream().distinct().map(alias -> Map.entry(alias, service.name())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

public void init(Client client) {
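
The uniqueness claim in the new comment is a property of Collectors.toMap itself: the two-argument overload has no merge function, so it throws IllegalStateException on a duplicate key, meaning a colliding service name or alias fails registry construction instead of silently shadowing another service (the .distinct() call keeps a service that lists the same alias twice from tripping this). A standalone sketch of that behavior, with invented names:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ToMapUniquenessDemo {
    public static void main(String[] args) {
        // Distinct keys collect cleanly.
        Map<String, String> ok = List.of("elasticsearch", "openai")
            .stream()
            .collect(Collectors.toMap(Function.identity(), Function.identity()));
        System.out.println(ok);

        // A duplicate key makes toMap throw IllegalStateException, so a name
        // or alias collision surfaces when the registry is built.
        try {
            List.of("sagemaker", "sagemaker")
                .stream()
                .collect(Collectors.toMap(Function.identity(), Function.identity()));
        } catch (IllegalStateException e) {
            System.out.println("duplicate rejected: " + e.getMessage());
        }
    }
}
```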
@@ -56,13 +61,8 @@ public Map<String, InferenceService> getServices() {
}

public Optional<InferenceService> getService(String serviceName) {

if ("elser".equals(serviceName)) { // ElserService.NAME before removal
// here we are aliasing the elser service to use the elasticsearch service instead
return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME
} else {
return Optional.ofNullable(services.get(serviceName));
}
var serviceKey = aliases.getOrDefault(serviceName, serviceName);
return Optional.ofNullable(services.get(serviceKey));
}

public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
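The replacement getService is a two-step lookup that generalizes the deleted "elser" special case: canonicalize the requested name through the alias table, then hit the services map. A minimal, self-contained sketch of the pattern (the maps below are invented stand-ins for the real registry state):

```java
import java.util.Map;
import java.util.Optional;

public class AliasLookupDemo {
    public static void main(String[] args) {
        // alias -> canonical service name (stand-in for the registry's aliases map)
        Map<String, String> aliases = Map.of("elser", "elasticsearch", "sagemaker", "amazon_sagemaker");
        // canonical service name -> service (stand-in for the registry's services map)
        Map<String, String> services = Map.of(
            "elasticsearch", "ElasticsearchInternalService",
            "amazon_sagemaker", "SageMakerService"
        );

        for (String requested : new String[] { "elser", "amazon_sagemaker", "unknown" }) {
            // Aliases redirect to the canonical name; anything else passes through unchanged.
            String key = aliases.getOrDefault(requested, requested);
            Optional<String> service = Optional.ofNullable(services.get(key));
            System.out.println(requested + " -> " + service.orElse("<none>"));
        }
    }
}
```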
Another file:
@@ -73,7 +73,7 @@ public void testDefaultModels() throws IOException {
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
assertDefaultRerankConfig(rerankModel);

putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
var registeredModels = getMinimalConfigs();
assertThat(registeredModels.size(), equalTo(1));
assertTrue(registeredModels.containsKey("my-model"));
Another file:
@@ -119,12 +119,12 @@ static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int
""", taskType, apiKey, temperature);
}

static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody, String service) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
{
%s
"service": "streaming_completion_test_service",
"service": "%s",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
@@ -133,7 +133,7 @@ static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody
"temperature": 3
}
}
""", taskType);
""", taskType, service);
}

static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
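With the service name now a parameter, the same helper can exercise the canonical name or an alias. A usage sketch (assuming the helper is in scope, as it is in these test classes):

```java
// Same request body shape, different "service" value.
String canonical = mockCompletionServiceModelConfig(TaskType.COMPLETION, "streaming_completion_test_service");
String aliased   = mockCompletionServiceModelConfig(TaskType.COMPLETION, "streaming_completion_test_service_alias");
```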
Another file:
@@ -305,7 +305,7 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws

public void testUnsupportedStream() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
@@ -326,8 +326,16 @@ }
}

public void testSupportedStream() throws Exception {
testSupportedStream("streaming_completion_test_service");
}

public void testSupportedStreamForAlias() throws Exception {
testSupportedStream("streaming_completion_test_service_alias");
}

public void testSupportedStream(String serviceName) throws Exception {
Review comment (Contributor): nit: Can we make this private and potentially static?

String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION, serviceName));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
@@ -352,7 +360,7 @@ public void testSupportedStream() throws Exception {

public void testUnifiedCompletionInference() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION, "streaming_completion_test_service"));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("inference_id"));
assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));
Another file:
@@ -54,7 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
@@ -93,7 +93,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
@@ -143,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"openai",
"streaming_completion_test_service",
"hugging_face",
"sagemaker"
"amazon_sagemaker"
).toArray()
)
);
@@ -158,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
assertThat(
providers,
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
)
);
}
Another file:
@@ -60,6 +60,7 @@ public List<Factory> getInferenceServiceFactories() {

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "streaming_completion_test_service";
private static final String ALIAS = "streaming_completion_test_service_alias";
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
@@ -75,6 +76,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
return List.of(ALIAS);
}

@Override
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
ElasticsearchInternalService.java:
@@ -803,6 +803,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
return List.of(OLD_ELSER_SERVICE_NAME);
}

private RankedDocsResults textSimilarityResultsToRankedDocs(
List<? extends InferenceResults> results,
Function<Integer, String> inputSupplier,
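This override replaces the hard-coded "elser" branch deleted from InferenceServiceRegistry above. Assuming OLD_ELSER_SERVICE_NAME is "elser" (consistent with that deleted branch), both lookups should now resolve to the same service; a usage sketch against the registry API shown earlier (registry setup elided):

```java
Optional<InferenceService> byName  = registry.getService("elasticsearch");
Optional<InferenceService> byAlias = registry.getService("elser"); // OLD_ELSER_SERVICE_NAME
// Both optionals hold the same ElasticsearchInternalService instance.
```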
SageMakerService.java:
@@ -45,7 +45,9 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;

public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
public static final String NAME = "amazon_sagemaker";
private static final String DISPLAY_NAME = "Amazon SageMaker";
private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
private static final int DEFAULT_BATCH_SIZE = 256;
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
private final SageMakerModelBuilder modelBuilder;
@@ -67,7 +69,7 @@ public SageMakerService(
this.threadPool = threadPool;
this.configuration = new LazyInitializable<>(
() -> new InferenceServiceConfiguration.Builder().setService(NAME)
.setName("Amazon SageMaker")
.setName(DISPLAY_NAME)
.setTaskTypes(supportedTaskTypes())
.setConfigurations(configurationMap.get())
.build()
@@ -79,6 +81,11 @@ public String name() {
return NAME;
}

@Override
public List<String> aliases() {
return ALIASES;
}

@Override
public void parseRequestConfig(
String modelId,
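The rename from "sagemaker" to "amazon_sagemaker" stays backward compatible because the old spelling and the underscore-free variant are registered as aliases. A usage sketch against the registry API above (registry setup elided):

```java
// All three should resolve to the same SageMakerService instance.
registry.getService("amazon_sagemaker"); // canonical NAME
registry.getService("sagemaker");        // pre-rename alias
registry.getService("amazonsagemaker");  // alias without the underscore
```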