Skip to content

Commit 698a7b6

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Refactor inference API service tests base classes (#135461)
* Refactoring openai * Splitting up parameterized tests * Working tests * [CI] Auto commit changes from spotless * [CI] Update transport version definitions * Removing deprecated function * Moving string creation and refactoring customservice chunking * Removing usages of persistent function * [CI] Auto commit changes from spotless * Finishing comment --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 02a519a commit 698a7b6

File tree

43 files changed

+1163
-1483
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1163
-1483
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,27 @@ public interface EnumConstructor<E extends Enum<E>> {
10791079
E apply(String name) throws IllegalArgumentException;
10801080
}
10811081

1082-
public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) {
1082+
/**
1083+
* Create an exception for when the task type is not valid for the service.
1084+
*/
1085+
public static ElasticsearchStatusException createInvalidTaskTypeException(
1086+
String inferenceEntityId,
1087+
String serviceName,
1088+
TaskType taskType,
1089+
ConfigurationParseContext parseContext
1090+
) {
1091+
var message = parseContext == ConfigurationParseContext.PERSISTENT
1092+
? parsePersistedConfigErrorMsg(inferenceEntityId, serviceName, taskType)
1093+
: TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName);
1094+
return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
1095+
}
1096+
1097+
private static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName, TaskType taskType) {
10831098
return format(
1084-
"Failed to parse stored model [%s] for [%s] service, please delete and add the service again",
1099+
"Failed to parse stored model [%s] for [%s] service, error: [%s]. Please delete and add the service again",
10851100
inferenceEntityId,
1086-
serviceName
1101+
serviceName,
1102+
TaskType.unsupportedTaskTypeErrorMsg(taskType, serviceName)
10871103
);
10881104
}
10891105

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.ai21;
99

10-
import org.elasticsearch.ElasticsearchStatusException;
1110
import org.elasticsearch.TransportVersion;
1211
import org.elasticsearch.action.ActionListener;
1312
import org.elasticsearch.cluster.service.ClusterService;
@@ -27,7 +26,6 @@
2726
import org.elasticsearch.inference.SettingsConfiguration;
2827
import org.elasticsearch.inference.TaskType;
2928
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
30-
import org.elasticsearch.rest.RestStatus;
3129
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3230
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3331
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
@@ -55,7 +53,7 @@
5553
import java.util.Set;
5654

5755
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
58-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
56+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
5957
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6058
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6159
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -178,14 +176,7 @@ public void parseRequestConfig(
178176
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
179177
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
180178

181-
Ai21Model model = createModel(
182-
modelId,
183-
taskType,
184-
serviceSettingsMap,
185-
serviceSettingsMap,
186-
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
187-
ConfigurationParseContext.REQUEST
188-
);
179+
Ai21Model model = createModel(modelId, taskType, serviceSettingsMap, serviceSettingsMap, ConfigurationParseContext.REQUEST);
189180

190181
throwIfNotEmptyMap(config, NAME);
191182
throwIfNotEmptyMap(serviceSettingsMap, NAME);
@@ -208,21 +199,15 @@ public Ai21Model parsePersistedConfigWithSecrets(
208199
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
209200
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
210201

211-
return createModelFromPersistent(
212-
modelId,
213-
taskType,
214-
serviceSettingsMap,
215-
secretSettingsMap,
216-
parsePersistedConfigErrorMsg(modelId, NAME)
217-
);
202+
return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap);
218203
}
219204

220205
@Override
221206
public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
222207
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
223208
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
224209

225-
return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME));
210+
return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null);
226211
}
227212

228213
@Override
@@ -240,32 +225,23 @@ private static Ai21Model createModel(
240225
TaskType taskType,
241226
Map<String, Object> serviceSettings,
242227
@Nullable Map<String, Object> secretSettings,
243-
String failureMessage,
244228
ConfigurationParseContext context
245229
) {
246230
switch (taskType) {
247231
case CHAT_COMPLETION, COMPLETION:
248232
return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
249233
default:
250-
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
234+
throw createInvalidTaskTypeException(modelId, NAME, taskType, context);
251235
}
252236
}
253237

254238
private Ai21Model createModelFromPersistent(
255239
String inferenceEntityId,
256240
TaskType taskType,
257241
Map<String, Object> serviceSettings,
258-
Map<String, Object> secretSettings,
259-
String failureMessage
242+
Map<String, Object> secretSettings
260243
) {
261-
return createModel(
262-
inferenceEntityId,
263-
taskType,
264-
serviceSettings,
265-
secretSettings,
266-
failureMessage,
267-
ConfigurationParseContext.PERSISTENT
268-
);
244+
return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT);
269245
}
270246

271247
/**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
99

10-
import org.elasticsearch.ElasticsearchStatusException;
1110
import org.elasticsearch.TransportVersion;
1211
import org.elasticsearch.TransportVersions;
1312
import org.elasticsearch.action.ActionListener;
@@ -32,7 +31,6 @@
3231
import org.elasticsearch.inference.SimilarityMeasure;
3332
import org.elasticsearch.inference.TaskType;
3433
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
35-
import org.elasticsearch.rest.RestStatus;
3634
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3735
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3836
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -59,7 +57,7 @@
5957
import java.util.Map;
6058

6159
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
62-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
60+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
6361
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
6462
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6563
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -135,7 +133,6 @@ public void parseRequestConfig(
135133
taskSettingsMap,
136134
chunkingSettings,
137135
serviceSettingsMap,
138-
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
139136
ConfigurationParseContext.REQUEST
140137
);
141138

@@ -165,8 +162,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
165162
Map<String, Object> serviceSettings,
166163
Map<String, Object> taskSettings,
167164
ChunkingSettings chunkingSettings,
168-
@Nullable Map<String, Object> secretSettings,
169-
String failureMessage
165+
@Nullable Map<String, Object> secretSettings
170166
) {
171167
return createModel(
172168
inferenceEntityId,
@@ -175,7 +171,6 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
175171
taskSettings,
176172
chunkingSettings,
177173
secretSettings,
178-
failureMessage,
179174
ConfigurationParseContext.PERSISTENT
180175
);
181176
}
@@ -187,7 +182,6 @@ private static AlibabaCloudSearchModel createModel(
187182
Map<String, Object> taskSettings,
188183
ChunkingSettings chunkingSettings,
189184
@Nullable Map<String, Object> secretSettings,
190-
String failureMessage,
191185
ConfigurationParseContext context
192186
) {
193187
return switch (taskType) {
@@ -229,7 +223,7 @@ private static AlibabaCloudSearchModel createModel(
229223
secretSettings,
230224
context
231225
);
232-
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
226+
default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
233227
};
234228
}
235229

@@ -255,8 +249,7 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
255249
serviceSettingsMap,
256250
taskSettingsMap,
257251
chunkingSettings,
258-
secretSettingsMap,
259-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
252+
secretSettingsMap
260253
);
261254
}
262255

@@ -276,8 +269,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
276269
serviceSettingsMap,
277270
taskSettingsMap,
278271
chunkingSettings,
279-
null,
280-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
272+
null
281273
);
282274
}
283275

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
6262
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
63-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
63+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
6464
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
6565
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
6666
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -204,7 +204,6 @@ public void parseRequestConfig(
204204
taskSettingsMap,
205205
chunkingSettings,
206206
serviceSettingsMap,
207-
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
208207
ConfigurationParseContext.REQUEST
209208
);
210209

@@ -241,7 +240,6 @@ public Model parsePersistedConfigWithSecrets(
241240
taskSettingsMap,
242241
chunkingSettings,
243242
secretSettingsMap,
244-
parsePersistedConfigErrorMsg(modelId, NAME),
245243
ConfigurationParseContext.PERSISTENT
246244
);
247245
}
@@ -263,7 +261,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
263261
taskSettingsMap,
264262
chunkingSettings,
265263
null,
266-
parsePersistedConfigErrorMsg(modelId, NAME),
267264
ConfigurationParseContext.PERSISTENT
268265
);
269266
}
@@ -285,7 +282,6 @@ private static AmazonBedrockModel createModel(
285282
Map<String, Object> taskSettings,
286283
ChunkingSettings chunkingSettings,
287284
@Nullable Map<String, Object> secretSettings,
288-
String failureMessage,
289285
ConfigurationParseContext context
290286
) {
291287
switch (taskType) {
@@ -318,7 +314,7 @@ private static AmazonBedrockModel createModel(
318314
checkChatCompletionProviderForTopKParameter(model);
319315
return model;
320316
}
321-
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
317+
default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
322318
}
323319
}
324320

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.anthropic;
99

10-
import org.elasticsearch.ElasticsearchStatusException;
1110
import org.elasticsearch.TransportVersion;
1211
import org.elasticsearch.TransportVersions;
1312
import org.elasticsearch.action.ActionListener;
@@ -28,7 +27,6 @@
2827
import org.elasticsearch.inference.SettingsConfiguration;
2928
import org.elasticsearch.inference.TaskType;
3029
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
31-
import org.elasticsearch.rest.RestStatus;
3230
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3331
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3432
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -48,7 +46,7 @@
4846

4947
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
5048
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
51-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
49+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
5250
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5351
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
5452
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -94,7 +92,6 @@ public void parseRequestConfig(
9492
serviceSettingsMap,
9593
taskSettingsMap,
9694
serviceSettingsMap,
97-
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
9895
ConfigurationParseContext.REQUEST
9996
);
10097

@@ -113,16 +110,14 @@ private static AnthropicModel createModelFromPersistent(
113110
TaskType taskType,
114111
Map<String, Object> serviceSettings,
115112
Map<String, Object> taskSettings,
116-
@Nullable Map<String, Object> secretSettings,
117-
String failureMessage
113+
@Nullable Map<String, Object> secretSettings
118114
) {
119115
return createModel(
120116
inferenceEntityId,
121117
taskType,
122118
serviceSettings,
123119
taskSettings,
124120
secretSettings,
125-
failureMessage,
126121
ConfigurationParseContext.PERSISTENT
127122
);
128123
}
@@ -133,7 +128,6 @@ private static AnthropicModel createModel(
133128
Map<String, Object> serviceSettings,
134129
Map<String, Object> taskSettings,
135130
@Nullable Map<String, Object> secretSettings,
136-
String failureMessage,
137131
ConfigurationParseContext context
138132
) {
139133
return switch (taskType) {
@@ -146,7 +140,7 @@ private static AnthropicModel createModel(
146140
secretSettings,
147141
context
148142
);
149-
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
143+
default -> throw createInvalidTaskTypeException(inferenceEntityId, NAME, taskType, context);
150144
};
151145
}
152146

@@ -161,29 +155,15 @@ public AnthropicModel parsePersistedConfigWithSecrets(
161155
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
162156
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
163157

164-
return createModelFromPersistent(
165-
inferenceEntityId,
166-
taskType,
167-
serviceSettingsMap,
168-
taskSettingsMap,
169-
secretSettingsMap,
170-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
171-
);
158+
return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
172159
}
173160

174161
@Override
175162
public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
176163
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
177164
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
178165

179-
return createModelFromPersistent(
180-
inferenceEntityId,
181-
taskType,
182-
serviceSettingsMap,
183-
taskSettingsMap,
184-
null,
185-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
186-
);
166+
return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
187167
}
188168

189169
@Override

0 commit comments

Comments
 (0)