Skip to content

Commit dbc1c56

Browse files
Enhance documentation for OpenShift AI models and add task settings handling in rerank model
1 parent 4fc2556 commit dbc1c56

File tree

5 files changed

+143
-2
lines changed

5 files changed

+143
-2
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/completion/OpenShiftAiChatCompletionModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
/**
2525
* Represents an OpenShift AI chat completion model.
26-
* This class extends the OpenShiftAiModel and provides specific configurations for chat completion tasks.
26+
* This class extends the {@link OpenShiftAiModel} and provides specific configurations for chat completion tasks.
2727
*/
2828
public class OpenShiftAiChatCompletionModel extends OpenShiftAiModel {
2929

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/embeddings/OpenShiftAiEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
/**
2525
* Represents an OpenShift AI embeddings model for inference.
26-
* This class extends the OpenShiftAiModel and provides specific configurations and settings for embeddings tasks.
26+
* This class extends the {@link OpenShiftAiModel} and provides specific configurations and settings for embeddings tasks.
2727
*/
2828
public class OpenShiftAiEmbeddingsModel extends OpenShiftAiModel {
2929

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModel.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,23 @@
1919

2020
import java.util.Map;
2121

22+
/**
23+
* Represents an OpenShift AI rerank model.
24+
* This class extends the {@link OpenShiftAiModel} and provides specific configurations for rerank tasks.
25+
*/
2226
public class OpenShiftAiRerankModel extends OpenShiftAiModel {
27+
28+
/**
29+
* Creates a new {@link OpenShiftAiRerankModel} with updated task settings if they differ from the existing ones.
30+
* @param model the existing OpenShift AI rerank model
31+
* @param taskSettings the new task settings to apply
32+
* @return a new {@link OpenShiftAiRerankModel} with updated task settings, or the original model if settings are unchanged
33+
*/
2334
public static OpenShiftAiRerankModel of(OpenShiftAiRerankModel model, Map<String, Object> taskSettings) {
2435
var requestTaskSettings = OpenShiftAiRerankTaskSettings.fromMap(taskSettings);
36+
if (requestTaskSettings.isEmpty() || requestTaskSettings.equals(model.getTaskSettings())) {
37+
return model;
38+
}
2539
return new OpenShiftAiRerankModel(model, OpenShiftAiRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
2640
}
2741

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankModelTests.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
import org.elasticsearch.test.ESTestCase;
1414
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
1515

16+
import java.util.HashMap;
17+
import java.util.Map;
18+
19+
import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.RETURN_DOCUMENTS;
20+
import static org.elasticsearch.xpack.inference.services.openshiftai.rerank.OpenShiftAiRerankTaskSettings.TOP_N;
21+
import static org.hamcrest.Matchers.is;
22+
1623
public class OpenShiftAiRerankModelTests extends ESTestCase {
1724

1825
public static OpenShiftAiRerankModel createModel(String url, String apiKey, @Nullable String modelId) {
@@ -35,4 +42,66 @@ public static OpenShiftAiRerankModel createModel(
3542
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
3643
);
3744
}
45+
46+
public void testOverrideWith_SameParams_KeepsSameModel() {
47+
testOverrideWith_KeepsSameModel(buildTaskSettingsMap(2, true));
48+
}
49+
50+
public void testOverrideWith_EmptyParams_KeepsSameModel() {
51+
testOverrideWith_KeepsSameModel(buildTaskSettingsMap(null, null));
52+
}
53+
54+
private static void testOverrideWith_KeepsSameModel(Map<String, Object> taskSettings) {
55+
var model = createModel("url", "api_key", "model_name", 2, true);
56+
var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings);
57+
58+
assertThat(overriddenModel.getTaskSettings().getTopN(), is(2));
59+
assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(true));
60+
}
61+
62+
public void testOverrideWith_DifferentParams_OverridesAllTaskSettings() {
63+
testOverrideWith_DifferentParams(buildTaskSettingsMap(4, false), 4, false);
64+
}
65+
66+
public void testOverrideWith_DifferentParams_OverridesOnlyReturnDocuments() {
67+
testOverrideWith_DifferentParams(buildTaskSettingsMap(null, false), 2, false);
68+
}
69+
70+
public void testOverrideWith_DifferentParams_OverridesOnlyTopN() {
71+
testOverrideWith_DifferentParams(buildTaskSettingsMap(4, null), 4, true);
72+
}
73+
74+
public void testOverrideWith_DifferentParams_OverridesNullValues() {
75+
var model = createModel("url", "api_key", "model_name", null, null);
76+
var overriddenModel = OpenShiftAiRerankModel.of(model, buildTaskSettingsMap(4, false));
77+
78+
assertThat(overriddenModel.getTaskSettings().getTopN(), is(4));
79+
assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(false));
80+
}
81+
82+
private static void testOverrideWith_DifferentParams(
83+
Map<String, Object> taskSettings,
84+
int expectedTopN,
85+
boolean expectedReturnDocuments
86+
) {
87+
var model = createModel("url", "api_key", "model_name", 2, true);
88+
var overriddenModel = OpenShiftAiRerankModel.of(model, taskSettings);
89+
90+
assertThat(overriddenModel.getTaskSettings().getTopN(), is(expectedTopN));
91+
assertThat(overriddenModel.getTaskSettings().getReturnDocuments(), is(expectedReturnDocuments));
92+
}
93+
94+
private static Map<String, Object> buildTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean returnDocuments) {
95+
final var map = new HashMap<String, Object>();
96+
97+
if (returnDocuments != null) {
98+
map.put(RETURN_DOCUMENTS, returnDocuments);
99+
}
100+
101+
if (topN != null) {
102+
map.put(TOP_N, topN);
103+
}
104+
105+
return map;
106+
}
38107
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/rerank/OpenShiftAiRerankTaskSettingsTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@
88
package org.elasticsearch.xpack.inference.services.openshiftai.rerank;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.Strings;
1112
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.XContentFactory;
17+
import org.elasticsearch.xcontent.XContentType;
1418
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1519

1620
import java.io.IOException;
1721
import java.util.HashMap;
1822
import java.util.Map;
1923

24+
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
2025
import static org.hamcrest.Matchers.containsString;
2126

2227
public class OpenShiftAiRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<OpenShiftAiRerankTaskSettings> {
@@ -97,6 +102,59 @@ public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings
97102
assertEquals(7, updatedSettings.getTopN().intValue());
98103
}
99104

105+
public void testToXContent_WritesAllValues() throws IOException {
106+
Integer topN = 2;
107+
Boolean doReturnDocuments = true;
108+
109+
testToXContent(topN, doReturnDocuments, """
110+
{
111+
"top_n":2,
112+
"return_documents":true
113+
}
114+
""");
115+
}
116+
117+
public void testToXContent_EmptyValues() throws IOException {
118+
Integer topN = null;
119+
Boolean doReturnDocuments = null;
120+
121+
testToXContent(topN, doReturnDocuments, """
122+
{}
123+
""");
124+
}
125+
126+
public void testToXContent_OnlyTopN() throws IOException {
127+
Integer topN = 2;
128+
Boolean doReturnDocuments = null;
129+
130+
testToXContent(topN, doReturnDocuments, """
131+
{
132+
"top_n":2
133+
}
134+
""");
135+
}
136+
137+
public void testToXContent_OnlyReturnDocuments() throws IOException {
138+
Integer topN = null;
139+
Boolean doReturnDocuments = true;
140+
141+
testToXContent(topN, doReturnDocuments, """
142+
{
143+
"return_documents":true
144+
}
145+
""");
146+
}
147+
148+
private static void testToXContent(Integer topN, Boolean doReturnDocuments, String expectedString) throws IOException {
149+
var taskSettings = new OpenShiftAiRerankTaskSettings(topN, doReturnDocuments);
150+
151+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
152+
taskSettings.toXContent(builder, null);
153+
String xContentResult = Strings.toString(builder);
154+
155+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(expectedString));
156+
}
157+
100158
@Override
101159
protected Writeable.Reader<OpenShiftAiRerankTaskSettings> instanceReader() {
102160
return OpenShiftAiRerankTaskSettings::new;

0 commit comments

Comments
 (0)