Skip to content

Commit b451f06

Browse files
author
Max Hniebergall
committed
Add tests
1 parent 0e87606 commit b451f06

File tree

46 files changed

+696
-70
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+696
-70
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
8181
""", taskType);
8282
}
8383

84+
static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int temperature) {
85+
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
86+
return Strings.format("""
87+
{
88+
%s
89+
"service_settings": {
90+
"api_key": "%s"
91+
},
92+
"task_settings": {
93+
"temperature": %d
94+
}
95+
}
96+
""", taskType, apiKey, temperature);
97+
}
98+
8499
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
85100
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
86101
return Strings.format("""
@@ -196,6 +211,11 @@ protected Map<String, Object> putModel(String modelId, String modelConfig, TaskT
196211
return putRequest(endpoint, modelConfig);
197212
}
198213

214+
protected Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
215+
String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID);
216+
return putRequest(endpoint, modelConfig);
217+
}
218+
199219
protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
200220
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
201221
String body = """

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import java.io.IOException;
1818
import java.util.List;
19+
import java.util.Map;
20+
import java.util.Objects;
1921
import java.util.Set;
2022
import java.util.function.Function;
2123
import java.util.stream.IntStream;
@@ -29,7 +31,7 @@
2931
public class InferenceCrudIT extends InferenceBaseRestTest {
3032

3133
@SuppressWarnings("unchecked")
32-
public void testGet() throws IOException {
34+
public void testCRUD() throws IOException {
3335
for (int i = 0; i < 5; i++) {
3436
putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
3537
}
@@ -53,11 +55,29 @@ public void testGet() throws IOException {
5355
for (var denseModel : getDenseModels) {
5456
assertEquals("text_embedding", denseModel.get("task_type"));
5557
}
56-
57-
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
58-
assertThat(singleModel, hasSize(1));
59-
assertEquals("se_model_1", singleModel.get(0).get("inference_id"));
60-
58+
String oldApiKey;
59+
{
60+
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
61+
assertThat(singleModel, hasSize(1));
62+
assertEquals("se_model_1", singleModel.get(0).get("inference_id"));
63+
oldApiKey = (String) singleModel.get(0).get("api_key");
64+
}
65+
var newApiKey = randomAlphaOfLength(10);
66+
int temperature = randomIntBetween(1, 10);
67+
Map<String, Object> updatedEndpoint = updateEndpoint(
68+
"se_model_1",
69+
updateConfig(TaskType.SPARSE_EMBEDDING, newApiKey, temperature),
70+
TaskType.SPARSE_EMBEDDING
71+
);
72+
Map<String, Objects> updatedTaskSettings = (Map<String, Objects>) updatedEndpoint.get("task_settings");
73+
assertEquals(temperature, updatedTaskSettings.get("temperature"));
74+
{
75+
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
76+
assertThat(singleModel, hasSize(1));
77+
assertEquals("se_model_1", singleModel.get(0).get("inference_id"));
78+
assertNotEquals(oldApiKey, newApiKey);
79+
assertEquals(updatedEndpoint, singleModel.get(0));
80+
}
6181
for (int i = 0; i < 5; i++) {
6282
deleteModel("se_model_" + i, TaskType.SPARSE_EMBEDDING);
6383
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public TransportVersion getMinimalSupportedVersion() {
161161

162162
@Override
163163
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
164-
return fromMap(newSettings);
164+
return fromMap(new HashMap<>(newSettings));
165165
}
166166
}
167167

@@ -214,7 +214,7 @@ public TransportVersion getMinimalSupportedVersion() {
214214

215215
@Override
216216
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
217-
return TestSecretSettings.fromMap(newSecrets);
217+
return TestSecretSettings.fromMap(new HashMap<>(newSecrets));
218218
}
219219
}
220220
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ public Map<String, Object> getParameters() {
137137

138138
@Override
139139
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
140-
AlibabaCloudSearchCompletionTaskSettings updatedSettings = AlibabaCloudSearchCompletionTaskSettings.fromMap(newSettings);
140+
AlibabaCloudSearchCompletionTaskSettings updatedSettings = AlibabaCloudSearchCompletionTaskSettings.fromMap(
141+
new HashMap<>(newSettings)
142+
);
141143
return of(this, updatedSettings);
142144
}
143145
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.util.EnumSet;
24+
import java.util.HashMap;
2425
import java.util.Map;
2526
import java.util.Objects;
2627

@@ -173,7 +174,7 @@ public static String invalidInputTypeMessage(InputType inputType) {
173174

174175
@Override
175176
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
176-
AlibabaCloudSearchEmbeddingsTaskSettings newSettingsOnly = fromMap(newSettings);
177+
AlibabaCloudSearchEmbeddingsTaskSettings newSettingsOnly = fromMap(new HashMap<>(newSettings));
177178
return of(this, newSettingsOnly, newSettingsOnly.inputType != null ? newSettingsOnly.inputType : this.getInputType());
178179
}
179180
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.util.EnumSet;
24+
import java.util.HashMap;
2425
import java.util.Map;
2526
import java.util.Objects;
2627

@@ -186,7 +187,7 @@ public static String invalidInputTypeMessage(InputType inputType) {
186187

187188
@Override
188189
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
189-
AlibabaCloudSearchSparseTaskSettings updatedSettings = fromMap(newSettings);
190+
AlibabaCloudSearchSparseTaskSettings updatedSettings = fromMap(new HashMap<>(newSettings));
190191
return of(this, updatedSettings, updatedSettings.getInputType() != null ? updatedSettings.getInputType() : this.inputType);
191192
}
192193
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.xcontent.XContentBuilder;
1919

2020
import java.io.IOException;
21+
import java.util.HashMap;
2122
import java.util.Map;
2223
import java.util.Objects;
2324

@@ -110,6 +111,6 @@ public int hashCode() {
110111

111112
@Override
112113
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
113-
return fromMap(newSecrets);
114+
return fromMap(new HashMap<>(newSecrets));
114115
}
115116
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xcontent.XContentBuilder;
1818

1919
import java.io.IOException;
20+
import java.util.HashMap;
2021
import java.util.Map;
2122
import java.util.Objects;
2223

@@ -191,7 +192,7 @@ public int hashCode() {
191192
@Override
192193
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
193194
AmazonBedrockChatCompletionRequestTaskSettings requestSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(
194-
newSettings
195+
new HashMap<>(newSettings)
195196
);
196197
return of(this, requestSettings);
197198
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2020

2121
import java.io.IOException;
22+
import java.util.HashMap;
2223
import java.util.Map;
2324
import java.util.Objects;
2425

@@ -61,7 +62,7 @@ private static AnthropicChatCompletionTaskSettings fromPersistedMap(Map<String,
6162

6263
@Override
6364
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
64-
return fromRequestMap(newSettings);
65+
return fromRequestMap(new HashMap<>(newSettings));
6566
}
6667

6768
private record CommonFields(int maxTokens, Double temperature, Double topP, Integer topK) {}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
2121

2222
import java.io.IOException;
23+
import java.util.HashMap;
2324
import java.util.Map;
2425
import java.util.Objects;
2526

@@ -173,6 +174,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
173174
return builder;
174175
}
175176

177+
@Override
178+
public String toString() {
179+
return "AzureAiStudioChatCompletionTaskSettings{"
180+
+ "temperature="
181+
+ temperature
182+
+ ", topP="
183+
+ topP
184+
+ ", doSample="
185+
+ doSample
186+
+ ", maxNewTokens="
187+
+ maxNewTokens
188+
+ '}';
189+
}
190+
176191
@Override
177192
public boolean equals(Object o) {
178193
if (this == o) return true;
@@ -192,7 +207,7 @@ public int hashCode() {
192207
@Override
193208
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
194209
AzureAiStudioChatCompletionRequestTaskSettings requestSettings = AzureAiStudioChatCompletionRequestTaskSettings.fromMap(
195-
newSettings
210+
new HashMap<>(newSettings)
196211
);
197212
return of(this, requestSettings);
198213
}

0 commit comments

Comments
 (0)