Skip to content

Commit c05655f

Browse files
committed
Refactor to use GenericRequestManager instead of GoogleVertexAiCompletionRequestManager
1 parent 7b99b1d commit c05655f

File tree

5 files changed

+63
-79
lines changed

5 files changed

+63
-79
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiCompletionRequestManager.java

Lines changed: 0 additions & 69 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10-
import org.elasticsearch.inference.Model;
1110
import org.elasticsearch.inference.ModelConfigurations;
1211
import org.elasticsearch.inference.ModelSecrets;
1312
import org.elasticsearch.inference.ServiceSettings;
1413
import org.elasticsearch.inference.TaskSettings;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
17+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1718

1819
import java.net.URI;
1920
import java.util.Map;
2021
import java.util.Objects;
2122

22-
public abstract class GoogleVertexAiModel extends Model {
23+
public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {
2324

2425
private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
2526

@@ -58,4 +59,15 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
5859
public URI uri() {
5960
return uri;
6061
}
62+
63+
@Override
64+
public int rateLimitGroupingHash() {
65+
// In Vertex AI, rate limiting is scoped to the project and the model. The URI already carries both, so it alone is hashed to form the grouping key.
66+
return Objects.hash(uri);
67+
}
68+
69+
@Override
70+
public RateLimitSettings rateLimitSettings() {
71+
// Delegates to the rate limit settings held by this model's rate-limit service settings.
return rateLimitServiceSettings().rateLimitSettings();
72+
}
6173
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3131
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3232
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
33+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3334
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
35+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3436
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3537
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3638
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -42,6 +44,7 @@
4244
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
4345
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
4446
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
47+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
4548
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
4649
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4750

@@ -84,6 +87,9 @@ public class GoogleVertexAiService extends SenderService {
8487
InputType.INTERNAL_SEARCH
8588
);
8689

90+
// Response handler for unified chat completion requests. Declared static so the constant-style name matches its
// modifiers and a single instance is shared rather than one being created per service instance; the action creator
// in this plugin already holds the same handler in a static field.
private static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
91+
"Google VertexAI chat completion"
92+
);
8793
@Override
8894
public Set<TaskType> supportedStreamingTasks() {
8995
return EnumSet.of(TaskType.CHAT_COMPLETION);
@@ -240,7 +246,13 @@ protected void doUnifiedCompletionInfer(
240246
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
241247
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());
242248

243-
var manager = GoogleVertexAiCompletionRequestManager.of(updatedChatCompletionModel, getServiceComponents().threadPool());
249+
var manager = new GenericRequestManager<>(
250+
getServiceComponents().threadPool(),
251+
updatedChatCompletionModel,
252+
COMPLETION_HANDLER,
253+
(unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
254+
UnifiedChatInput.class
255+
);
244256

245257
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
246258
var action = new SenderExecutableAction(getSender(), manager, errorMessage);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
15+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
1216
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
17+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1318
import org.elasticsearch.xpack.inference.services.ServiceComponents;
14-
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiCompletionRequestManager;
1519
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
1620
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
1722
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
1823
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
24+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
1925
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
2026

2127
import java.util.Map;
@@ -30,6 +36,11 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor
3036

3137
private final ServiceComponents serviceComponents;
3238

39+
// Response handler used for chat completion requests built by this creator.
// Package-private (not private) so the action tests in this package can statically import it.
static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
40+
"Google VertexAI chat completion"
41+
);
42+
// Role passed to UnifiedChatInput when wrapping plain chat completion inputs; also statically imported by the tests.
static final String USER_ROLE = "user";
43+
3344
public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
3445
this.sender = Objects.requireNonNull(sender);
3546
this.serviceComponents = Objects.requireNonNull(serviceComponents);
@@ -56,8 +67,16 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje
5667

5768
@Override
5869
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
59-
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion");
60-
var requestManager = GoogleVertexAiCompletionRequestManager.of(model, serviceComponents.threadPool());
61-
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
70+
71+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
72+
var manager = new GenericRequestManager<>(
73+
serviceComponents.threadPool(),
74+
model,
75+
COMPLETION_HANDLER,
76+
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
77+
ChatCompletionInput.class
78+
);
79+
80+
return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
6281
}
6382
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2121
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
2222
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
23+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
24+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
2325
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2426
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2527
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
26-
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiCompletionRequestManager;
2728
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests;
29+
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
2830
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2931
import org.junit.After;
3032
import org.junit.Before;
@@ -36,6 +38,8 @@
3638
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
3739
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
3840
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
41+
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_HANDLER;
42+
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE;
3943
import static org.hamcrest.Matchers.is;
4044
import static org.mockito.ArgumentMatchers.any;
4145
import static org.mockito.Mockito.doAnswer;
@@ -123,9 +127,15 @@ private ExecutableAction createAction(String location, String projectId, String
123127
new RateLimitSettings(100)
124128
);
125129

126-
var requestManager = new GoogleVertexAiCompletionRequestManager(model, threadPool);
130+
var manager = new GenericRequestManager<>(
131+
threadPool,
132+
model,
133+
COMPLETION_HANDLER,
134+
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
135+
ChatCompletionInput.class
136+
);
127137
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Google Vertex AI chat completion");
128-
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
138+
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
129139
}
130140

131141
}

0 commit comments

Comments
 (0)