Skip to content

Commit ffd491d

Browse files
Refactor OpenShift AI model to use model ID directly instead of UnifiedCompletionRequest
1 parent 7f40459 commit ffd491d

File tree

3 files changed

+11
-83
lines changed

3 files changed

+11
-83
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ protected void doUnifiedCompletionInfer(
138138
}
139139

140140
OpenShiftAiChatCompletionModel chatCompletionModel = (OpenShiftAiChatCompletionModel) model;
141-
var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());
141+
var overriddenModel = OpenShiftAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest().model());
142142
var manager = new GenericRequestManager<>(
143143
getServiceComponents().threadPool(),
144144
overriddenModel,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import org.elasticsearch.inference.ModelSecrets;
1313
import org.elasticsearch.inference.SecretSettings;
1414
import org.elasticsearch.inference.TaskType;
15-
import org.elasticsearch.inference.UnifiedCompletionRequest;
1615
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1716
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
1817
import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiModel;
1918
import org.elasticsearch.xpack.inference.services.openshiftai.action.OpenShiftAiActionVisitor;
2019
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
2120

2221
import java.util.Map;
22+
import java.util.Objects;
2323

2424
/**
2525
* Represents an OpenShift AI chat completion model.
@@ -79,18 +79,18 @@ public OpenShiftAiChatCompletionModel(
7979
* If the request does not specify a model ID, or specifies the same model ID as the original model, the original model is returned.
8080
*
8181
* @param model the original OpenShiftAiChatCompletionModel
82-
* @param request the UnifiedCompletionRequest containing potential overrides
82+
* @param modelId the model ID specified in the request, which may override the original model's ID
8383
* @return a new OpenShiftAiChatCompletionModel with overridden settings, or the original model if no override is specified
8484
*/
85-
public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, UnifiedCompletionRequest request) {
86-
if (request.model() == null) {
87-
// If no model ID is specified in the request, return the original model
85+
public static OpenShiftAiChatCompletionModel of(OpenShiftAiChatCompletionModel model, String modelId) {
86+
if (modelId == null || Objects.equals(model.getServiceSettings().modelId(), modelId)) {
87+
// If no model ID is specified in the request, or if it matches the original model's ID, return the original model.
8888
return model;
8989
}
9090

9191
var originalModelServiceSettings = model.getServiceSettings();
9292
var overriddenServiceSettings = new OpenShiftAiChatCompletionServiceSettings(
93-
request.model(),
93+
modelId,
9494
originalModelServiceSettings.uri(),
9595
originalModelServiceSettings.rateLimitSettings()
9696
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModelTests.java

Lines changed: 4 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.inference.TaskType;
12-
import org.elasticsearch.inference.UnifiedCompletionRequest;
1312
import org.elasticsearch.test.ESTestCase;
1413
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
1514

16-
import java.util.List;
17-
1815
import static org.hamcrest.Matchers.is;
1916

2017
public class OpenShiftAiChatCompletionModelTests extends ESTestCase {
@@ -38,90 +35,21 @@ public static OpenShiftAiChatCompletionModel createModelWithTaskType(String url,
3835

3936
public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() {
4037
var model = createCompletionModel("url", "api_key", "model_name");
41-
var request = new UnifiedCompletionRequest(
42-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
43-
"model_name", // same model
44-
null,
45-
null,
46-
null,
47-
null,
48-
null,
49-
null
50-
);
51-
52-
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request);
38+
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "model_name");
5339

54-
assertThat(overriddenModel, is(model));
40+
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
5541
}
5642

5743
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
5844
var model = createCompletionModel("url", "api_key", "model_name");
59-
var request = new UnifiedCompletionRequest(
60-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
61-
"different_model", // overriding model
62-
null,
63-
null,
64-
null,
65-
null,
66-
null,
67-
null
68-
);
69-
70-
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request);
45+
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, "different_model");
7146

7247
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
7348
}
7449

75-
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
76-
var model = createCompletionModel("url", "api_key", null);
77-
var request = new UnifiedCompletionRequest(
78-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
79-
"different_model", // overriding model
80-
null,
81-
null,
82-
null,
83-
null,
84-
null,
85-
null
86-
);
87-
88-
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request);
89-
90-
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
91-
}
92-
93-
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
94-
var model = createCompletionModel("url", "api_key", null);
95-
var request = new UnifiedCompletionRequest(
96-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
97-
null, // not overriding model
98-
null,
99-
null,
100-
null,
101-
null,
102-
null,
103-
null
104-
);
105-
106-
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request);
107-
108-
assertNull(overriddenModel.getServiceSettings().modelId());
109-
}
110-
11150
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
11251
var model = createCompletionModel("url", "api_key", "model_name");
113-
var request = new UnifiedCompletionRequest(
114-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
115-
null, // not overriding model
116-
null,
117-
null,
118-
null,
119-
null,
120-
null,
121-
null
122-
);
123-
124-
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, request);
52+
var overriddenModel = OpenShiftAiChatCompletionModel.of(model, null);
12553

12654
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
12755
}

0 commit comments

Comments
 (0)