Skip to content

Commit 6052331

Browse files
authored
[Inference API] Add Completion Inference API for Alibaba Cloud AI Search Model (#112512)
1 parent 893ea68 commit 6052331

20 files changed

+1233
-0
lines changed

docs/changelog/112512.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 112512
2+
summary: Add Completion Inference API for Alibaba Cloud AI Search Model
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
2828
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
2929
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
30+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
31+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
3032
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings;
3133
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
3234
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings;
@@ -543,6 +545,20 @@ private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegi
543545
AlibabaCloudSearchRerankTaskSettings::new
544546
)
545547
);
548+
namedWriteables.add(
549+
new NamedWriteableRegistry.Entry(
550+
ServiceSettings.class,
551+
AlibabaCloudSearchCompletionServiceSettings.NAME,
552+
AlibabaCloudSearchCompletionServiceSettings::new
553+
)
554+
);
555+
namedWriteables.add(
556+
new NamedWriteableRegistry.Entry(
557+
TaskSettings.class,
558+
AlibabaCloudSearchCompletionTaskSettings.NAME,
559+
AlibabaCloudSearchCompletionTaskSettings::new
560+
)
561+
);
546562

547563
}
548564

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1313
import org.elasticsearch.xpack.inference.services.ServiceComponents;
14+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
1415
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
1516
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
1617
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
@@ -50,4 +51,11 @@ public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String,
5051

5152
return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
5253
}
54+
55+
@Override
56+
public ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings) {
57+
var overriddenModel = AlibabaCloudSearchCompletionModel.of(model, taskSettings);
58+
59+
return new AlibabaCloudSearchCompletionAction(sender, overriddenModel, serviceComponents);
60+
}
5361
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.inference.InputType;
1111
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
1213
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
1314
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
1415
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
@@ -21,4 +22,6 @@ public interface AlibabaCloudSearchActionVisitor {
2122
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);
2223

2324
ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);
25+
26+
ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings);
2427
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchException;
13+
import org.elasticsearch.ElasticsearchStatusException;
14+
import org.elasticsearch.action.ActionListener;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.inference.InferenceServiceResults;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.rest.RestStatus;
19+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
20+
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
21+
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
22+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
23+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
24+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
25+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
26+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
27+
28+
import java.util.Objects;
29+
30+
import static org.elasticsearch.core.Strings.format;
31+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
32+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
33+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
34+
35+
public class AlibabaCloudSearchCompletionAction implements ExecutableAction {
36+
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionAction.class);
37+
38+
private final AlibabaCloudSearchAccount account;
39+
private final AlibabaCloudSearchCompletionModel model;
40+
private final String failedToSendRequestErrorMessage;
41+
private final Sender sender;
42+
private final AlibabaCloudSearchCompletionRequestManager requestCreator;
43+
44+
public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompletionModel model, ServiceComponents serviceComponents) {
45+
this.model = Objects.requireNonNull(model);
46+
this.sender = Objects.requireNonNull(sender);
47+
this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
48+
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion");
49+
this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, serviceComponents.threadPool());
50+
}
51+
52+
@Override
53+
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
54+
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
55+
listener.onFailure(
56+
new ElasticsearchStatusException(
57+
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
58+
RestStatus.INTERNAL_SERVER_ERROR
59+
)
60+
);
61+
return;
62+
}
63+
64+
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
65+
if (docsOnlyInput.getInputs().size() % 2 == 0) {
66+
listener.onFailure(
67+
new ElasticsearchStatusException(
68+
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
69+
+ "all preceding inputs are the completion history as pairs of user input and the assistant's response.",
70+
RestStatus.BAD_REQUEST
71+
)
72+
);
73+
return;
74+
}
75+
76+
try {
77+
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
78+
failedToSendRequestErrorMessage,
79+
listener
80+
);
81+
sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
82+
} catch (ElasticsearchException e) {
83+
listener.onFailure(e);
84+
} catch (Exception e) {
85+
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
86+
}
87+
}
88+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
16+
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
18+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
19+
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequest;
20+
import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchCompletionResponseEntity;
21+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
22+
23+
import java.util.List;
24+
import java.util.Objects;
25+
import java.util.function.Supplier;
26+
27+
public class AlibabaCloudSearchCompletionRequestManager extends AlibabaCloudSearchRequestManager {
28+
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionRequestManager.class);
29+
30+
private static final ResponseHandler HANDLER = createCompletionHandler();
31+
32+
private static ResponseHandler createCompletionHandler() {
33+
return new AlibabaCloudSearchResponseHandler(
34+
"alibaba cloud search completion",
35+
AlibabaCloudSearchCompletionResponseEntity::fromResponse
36+
);
37+
}
38+
39+
public static AlibabaCloudSearchCompletionRequestManager of(
40+
AlibabaCloudSearchAccount account,
41+
AlibabaCloudSearchCompletionModel model,
42+
ThreadPool threadPool
43+
) {
44+
return new AlibabaCloudSearchCompletionRequestManager(
45+
Objects.requireNonNull(account),
46+
Objects.requireNonNull(model),
47+
Objects.requireNonNull(threadPool)
48+
);
49+
}
50+
51+
private final AlibabaCloudSearchCompletionModel model;
52+
53+
private final AlibabaCloudSearchAccount account;
54+
55+
private AlibabaCloudSearchCompletionRequestManager(
56+
AlibabaCloudSearchAccount account,
57+
AlibabaCloudSearchCompletionModel model,
58+
ThreadPool threadPool
59+
) {
60+
super(threadPool, model);
61+
this.account = Objects.requireNonNull(account);
62+
this.model = Objects.requireNonNull(model);
63+
}
64+
65+
@Override
66+
public void execute(
67+
InferenceInputs inferenceInputs,
68+
RequestSender requestSender,
69+
Supplier<Boolean> hasRequestCompletedFunction,
70+
ActionListener<InferenceServiceResults> listener
71+
) {
72+
List<String> input = DocumentsOnlyInput.of(inferenceInputs).getInputs();
73+
AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model);
74+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
75+
}
76+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ public class AlibabaCloudSearchUtils {
1515
public static final String TEXT_EMBEDDING_PATH = "text-embedding";
1616
public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding";
1717
public static final String RERANK_PATH = "ranker";
18+
public static final String COMPLETION_PATH = "text-generation";
1819
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.utils.URIBuilder;
13+
import org.apache.http.entity.ByteArrayEntity;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
17+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
20+
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
21+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
22+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
23+
24+
import java.net.URI;
25+
import java.net.URISyntaxException;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.List;
28+
import java.util.Objects;
29+
30+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
31+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
32+
33+
public class AlibabaCloudSearchCompletionRequest extends AlibabaCloudSearchRequest {
34+
private final AlibabaCloudSearchAccount account;
35+
private final List<String> input;
36+
private final URI uri;
37+
private final AlibabaCloudSearchCompletionTaskSettings taskSettings;
38+
private final String model;
39+
private final String host;
40+
private final String workspaceName;
41+
private final String httpSchema;
42+
private final String inferenceEntityId;
43+
44+
public AlibabaCloudSearchCompletionRequest(
45+
AlibabaCloudSearchAccount account,
46+
List<String> input,
47+
AlibabaCloudSearchCompletionModel completionModel
48+
) {
49+
Objects.requireNonNull(completionModel);
50+
51+
this.account = Objects.requireNonNull(account);
52+
this.input = Objects.requireNonNull(input);
53+
taskSettings = completionModel.getTaskSettings();
54+
model = completionModel.getServiceSettings().getCommonSettings().modelId();
55+
host = completionModel.getServiceSettings().getCommonSettings().getHost();
56+
workspaceName = completionModel.getServiceSettings().getCommonSettings().getWorkspaceName();
57+
httpSchema = completionModel.getServiceSettings().getCommonSettings().getHttpSchema() != null
58+
? completionModel.getServiceSettings().getCommonSettings().getHttpSchema()
59+
: "https";
60+
uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
61+
inferenceEntityId = completionModel.getInferenceEntityId();
62+
}
63+
64+
@Override
65+
public HttpRequest createHttpRequest() {
66+
HttpPost httpPost = new HttpPost(uri);
67+
68+
ByteArrayEntity byteEntity = new ByteArrayEntity(
69+
Strings.toString(new AlibabaCloudSearchCompletionRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
70+
);
71+
httpPost.setEntity(byteEntity);
72+
73+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
74+
httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
75+
76+
return new HttpRequest(httpPost, getInferenceEntityId());
77+
}
78+
79+
@Override
80+
public String getInferenceEntityId() {
81+
return inferenceEntityId;
82+
}
83+
84+
@Override
85+
public URI getURI() {
86+
return uri;
87+
}
88+
89+
@Override
90+
public Request truncate() {
91+
return this;
92+
}
93+
94+
@Override
95+
public boolean[] getTruncationInfo() {
96+
return null;
97+
}
98+
99+
URI buildDefaultUri() throws URISyntaxException {
100+
return new URIBuilder().setScheme(httpSchema)
101+
.setHost(host)
102+
.setPathSegments(
103+
AlibabaCloudSearchUtils.VERSION_3,
104+
AlibabaCloudSearchUtils.OPENAPI_PATH,
105+
AlibabaCloudSearchUtils.WORKSPACE_PATH,
106+
workspaceName,
107+
AlibabaCloudSearchUtils.COMPLETION_PATH,
108+
model
109+
)
110+
.build();
111+
}
112+
}

0 commit comments

Comments
 (0)