Skip to content

Commit d5c80bb

Browse files
committed
fix the tests
1 parent 14cc105 commit d5c80bb

20 files changed

+212
-292
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2727
private final Boolean returnDocuments;
2828
private final Integer topN;
2929

30-
public QueryAndDocsInputs(String query, List<String> chunks) {
31-
this(query, chunks, null, null, false);
32-
}
33-
3430
public QueryAndDocsInputs(
3531
String query,
3632
List<String> chunks,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java

Lines changed: 0 additions & 62 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java

Lines changed: 0 additions & 63 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,22 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.Model;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
1514
import org.elasticsearch.inference.ServiceSettings;
1615
import org.elasticsearch.inference.TaskSettings;
1716
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
17+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1818
import org.elasticsearch.xpack.inference.services.ServiceUtils;
1919
import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor;
2020
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
21+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2122

2223
import java.net.URI;
2324
import java.util.Map;
2425
import java.util.Objects;
2526

26-
public abstract class CohereModel extends Model {
27+
public abstract class CohereModel extends RateLimitGroupingModel {
2728
private final SecureString apiKey;
2829
private final CohereRateLimitServiceSettings rateLimitServiceSettings;
2930

@@ -63,5 +64,15 @@ public CohereRateLimitServiceSettings rateLimitServiceSettings() {
6364

6465
public abstract ExecutableAction accept(CohereActionVisitor creator, Map<String, Object> taskSettings);
6566

66-
public abstract URI uri();
67+
public RateLimitSettings rateLimitSettings() {
68+
return rateLimitServiceSettings.rateLimitSettings();
69+
}
70+
71+
public int rateLimitGroupingHash() {
72+
return apiKey().hashCode();
73+
}
74+
75+
public URI uri() {
76+
return rateLimitServiceSettings.uri();
77+
}
6778
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1111

12+
import java.net.URI;
13+
1214
public interface CohereRateLimitServiceSettings {
1315
RateLimitSettings rateLimitSettings();
1416

17+
URI uri();
1518
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java

Lines changed: 0 additions & 64 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,23 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
15+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
16+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
17+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
1318
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1419
import org.elasticsearch.xpack.inference.services.ServiceComponents;
15-
import org.elasticsearch.xpack.inference.services.cohere.CohereCompletionRequestManager;
16-
import org.elasticsearch.xpack.inference.services.cohere.CohereEmbeddingsRequestManager;
17-
import org.elasticsearch.xpack.inference.services.cohere.CohereRerankRequestManager;
20+
import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler;
1821
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
1922
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
23+
import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest;
24+
import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest;
25+
import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest;
2026
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
27+
import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity;
28+
import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity;
29+
import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity;
2130

2231
import java.util.Map;
2332
import java.util.Objects;
@@ -28,12 +37,30 @@
2837
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type.
2938
*/
3039
public class CohereActionCreator implements CohereActionVisitor {
40+
41+
private static final ResponseHandler EMBEDDINGS_HANDLER = new CohereResponseHandler(
42+
"cohere text embedding",
43+
CohereEmbeddingsResponseEntity::fromResponse,
44+
false
45+
);
46+
47+
private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler(
48+
"cohere rerank",
49+
(request, response) -> CohereRankedResponseEntity.fromResponse(response),
50+
false
51+
);
52+
53+
private static final ResponseHandler COMPLETION_HANDLER = new CohereResponseHandler(
54+
"cohere completion",
55+
CohereCompletionResponseEntity::fromResponse,
56+
true
57+
);
58+
3159
private static final String COMPLETION_ERROR_PREFIX = "Cohere completion";
3260
private final Sender sender;
3361
private final ServiceComponents serviceComponents;
3462

3563
public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
36-
// TODO Batching - accept a class that can handle batching
3764
this.sender = Objects.requireNonNull(sender);
3865
this.serviceComponents = Objects.requireNonNull(serviceComponents);
3966
}
@@ -42,24 +69,53 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
4269
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings) {
4370
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings);
4471
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings");
45-
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
46-
var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
72+
var requestCreator = new GenericRequestManager<>(
73+
serviceComponents.threadPool(),
74+
model,
75+
EMBEDDINGS_HANDLER,
76+
(inferenceInputs -> new CohereV1EmbeddingsRequest(
77+
inferenceInputs.getStringInputs(),
78+
inferenceInputs.getInputType(),
79+
overriddenModel
80+
)),
81+
EmbeddingsInput.class
82+
);
4783
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
4884
}
4985

5086
@Override
5187
public ExecutableAction create(CohereRerankModel model, Map<String, Object> taskSettings) {
5288
var overriddenModel = CohereRerankModel.of(model, taskSettings);
53-
var requestCreator = CohereRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
89+
var requestCreator = new GenericRequestManager<>(
90+
serviceComponents.threadPool(),
91+
overriddenModel,
92+
RERANK_HANDLER,
93+
(inferenceInputs -> new CohereV1RerankRequest(
94+
inferenceInputs.getQuery(),
95+
inferenceInputs.getChunks(),
96+
inferenceInputs.getReturnDocuments(),
97+
inferenceInputs.getTopN(),
98+
overriddenModel
99+
)),
100+
QueryAndDocsInputs.class
101+
);
102+
54103
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank");
55104
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
56105
}
57106

58107
@Override
59108
public ExecutableAction create(CohereCompletionModel model, Map<String, Object> taskSettings) {
60109
// no overridden model as task settings are always empty for cohere completion model
61-
var requestManager = CohereCompletionRequestManager.of(model, serviceComponents.threadPool());
110+
var requestCreator = new GenericRequestManager<>(
111+
serviceComponents.threadPool(),
112+
model,
113+
COMPLETION_HANDLER,
114+
(completionInput) -> new CohereV1CompletionRequest(completionInput.getInputs(), model, completionInput.stream()),
115+
ChatCompletionInput.class
116+
);
117+
62118
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
63-
return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
119+
return new SingleInputSenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX);
64120
}
65121
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor;
2121
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
2222

23-
import java.net.URI;
2423
import java.util.Map;
2524

2625
public class CohereCompletionModel extends CohereModel {
@@ -73,9 +72,4 @@ public DefaultSecretSettings getSecretSettings() {
7372
public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object> taskSettings) {
7473
return visitor.create(this, taskSettings);
7574
}
76-
77-
@Override
78-
public URI uri() {
79-
return getServiceSettings().uri();
80-
}
8175
}

0 commit comments

Comments
 (0)