Skip to content

Commit 5909a7d

Browse files
Allowing model to be overridden but not working yet
1 parent 834676d commit 5909a7d

File tree

8 files changed

+81
-17
lines changed

8 files changed

+81
-17
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.core.TimeValue;
1313
import org.elasticsearch.inference.InferenceServiceResults;
1414
import org.elasticsearch.rest.RestStatus;
15-
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
1615
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1716
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
1817
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -34,13 +33,7 @@ public SingleInputSenderExecutableAction(
3433

3534
@Override
3635
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
37-
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
38-
listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR));
39-
return;
40-
}
41-
42-
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
43-
if (docsOnlyInput.getInputs().size() > 1) {
36+
if (inferenceInputs.inputSize() > 1) {
4437
listener.onFailure(
4538
new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST)
4639
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class DocumentsOnlyInput extends InferenceInputs {
1414

1515
public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) {
1616
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
17-
throw createUnsupportedTypeException(inferenceInputs);
17+
throw createUnsupportedTypeException(inferenceInputs, DocumentsOnlyInput.class);
1818
}
1919

2020
return (DocumentsOnlyInput) inferenceInputs;
@@ -40,4 +40,8 @@ public List<String> getInputs() {
4040
public boolean stream() {
4141
return stream;
4242
}
43+
44+
public int inputSize() {
45+
return input.size();
46+
}
4347
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
import org.elasticsearch.common.Strings;
1111

1212
public abstract class InferenceInputs {
13-
public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) {
14-
return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass()));
13+
public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class<?> clazz) {
14+
return new IllegalArgumentException(
15+
Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz)
16+
);
1517
}
1618

1719
public <T> T castTo(Class<T> clazz) {
18-
if (this.getClass().isInstance(clazz) == false) {
19-
throw createUnsupportedTypeException(this);
20+
if (clazz.isInstance(this) == false) {
21+
throw createUnsupportedTypeException(this, clazz);
2022
}
2123

2224
return clazz.cast(this);
2325
}
26+
27+
public abstract int inputSize();
2428
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class QueryAndDocsInputs extends InferenceInputs {
1414

1515
public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
1616
if (inferenceInputs instanceof QueryAndDocsInputs == false) {
17-
throw createUnsupportedTypeException(inferenceInputs);
17+
throw createUnsupportedTypeException(inferenceInputs, QueryAndDocsInputs.class);
1818
}
1919

2020
return (QueryAndDocsInputs) inferenceInputs;
@@ -47,4 +47,7 @@ public boolean stream() {
4747
return stream;
4848
}
4949

50+
public int inputSize() {
51+
return chunks.size();
52+
}
5053
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class UnifiedChatInput extends InferenceInputs {
1818

1919
public static UnifiedChatInput of(InferenceInputs inferenceInputs) {
2020
if (inferenceInputs instanceof UnifiedChatInput == false) {
21-
throw createUnsupportedTypeException(inferenceInputs);
21+
throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
2222
}
2323

2424
return (UnifiedChatInput) inferenceInputs;
@@ -63,4 +63,8 @@ public UnifiedCompletionRequest getRequest() {
6363
public boolean stream() {
6464
return stream;
6565
}
66+
67+
public int inputSize() {
68+
return request.messages().size();
69+
}
6670
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
108108
builder.endArray();
109109

110110
if (unifiedRequest.model() != null) {
111-
builder.field(MODEL_FIELD, unifiedRequest.model());
111+
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
112112
}
113113
if (unifiedRequest.maxCompletionTokens() != null) {
114114
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Collections;
2626
import java.util.HashMap;
2727
import java.util.Map;
28+
import java.util.Objects;
2829

2930
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
3031

@@ -41,7 +42,24 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map<
4142

4243
public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) {
4344
var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromUnifiedRequest(request);
44-
return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
45+
var originalModelServiceSettings = model.getServiceSettings();
46+
var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings(
47+
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
48+
originalModelServiceSettings.uri(),
49+
originalModelServiceSettings.organizationId(),
50+
originalModelServiceSettings.maxInputTokens(),
51+
originalModelServiceSettings.rateLimitSettings()
52+
);
53+
54+
var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
55+
return new OpenAiChatCompletionModel(
56+
overriddenServiceSettings.modelId(),
57+
model.getTaskType(),
58+
model.getConfigurations().getService(),
59+
overriddenServiceSettings,
60+
overriddenTaskSettings,
61+
model.getSecretSettings()
62+
);
4563
}
4664

4765
public OpenAiChatCompletionModel(
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.test.ESTestCase;
12+
import org.hamcrest.Matchers;
13+
14+
import java.util.List;
15+
16+
public class InferenceInputsTests extends ESTestCase {
17+
public void testCastToSucceeds() {
18+
InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false);
19+
assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class));
20+
21+
assertThat(UnifiedChatInput.of(List.of(), false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class));
22+
assertThat(
23+
new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class),
24+
Matchers.instanceOf(QueryAndDocsInputs.class)
25+
);
26+
}
27+
28+
public void testCastToFails() {
29+
InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false);
30+
var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class));
31+
assertThat(
32+
exception.getMessage(),
33+
Matchers.containsString(
34+
Strings.format("Unable to convert inference inputs type: [%s] to [%s]", DocumentsOnlyInput.class, QueryAndDocsInputs.class)
35+
)
36+
);
37+
}
38+
}

0 commit comments

Comments
 (0)