Skip to content

Commit 76ddf99

Browse files
Refactor Llama and Mistral models to remove taskSettings parameter and simplify model instantiation
1 parent eb60dfa commit 76ddf99

File tree

8 files changed

+9
-55
lines changed

8 files changed

+9
-55
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
124124
* @param inferenceId the unique identifier for the inference entity
125125
* @param taskType the type of task this model is designed for
126126
* @param serviceSettings the settings for the inference service
127-
* @param taskSettings the settings specific to the task
128127
* @param chunkingSettings the settings for chunking, if applicable
129128
* @param secretSettings the secret settings for the model, such as API keys or tokens
130129
* @param failureMessage the message to use in case of failure
@@ -135,24 +134,14 @@ protected LlamaModel createModel(
135134
String inferenceId,
136135
TaskType taskType,
137136
Map<String, Object> serviceSettings,
138-
Map<String, Object> taskSettings,
139137
ChunkingSettings chunkingSettings,
140138
Map<String, Object> secretSettings,
141139
String failureMessage,
142140
ConfigurationParseContext context
143141
) {
144142
switch (taskType) {
145143
case TEXT_EMBEDDING:
146-
return new LlamaEmbeddingsModel(
147-
inferenceId,
148-
taskType,
149-
NAME,
150-
serviceSettings,
151-
taskSettings,
152-
chunkingSettings,
153-
secretSettings,
154-
context
155-
);
144+
return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
156145
case CHAT_COMPLETION, COMPLETION:
157146
return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
158147
default:
@@ -280,7 +269,6 @@ public void parseRequestConfig(
280269
modelId,
281270
taskType,
282271
serviceSettingsMap,
283-
taskSettingsMap,
284272
chunkingSettings,
285273
serviceSettingsMap,
286274
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
@@ -305,7 +293,7 @@ public Model parsePersistedConfigWithSecrets(
305293
Map<String, Object> secrets
306294
) {
307295
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
308-
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
296+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
309297
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
310298

311299
ChunkingSettings chunkingSettings = null;
@@ -317,7 +305,6 @@ public Model parsePersistedConfigWithSecrets(
317305
modelId,
318306
taskType,
319307
serviceSettingsMap,
320-
taskSettingsMap,
321308
chunkingSettings,
322309
secretSettingsMap,
323310
parsePersistedConfigErrorMsg(modelId, NAME)
@@ -328,7 +315,6 @@ private LlamaModel createModelFromPersistent(
328315
String inferenceEntityId,
329316
TaskType taskType,
330317
Map<String, Object> serviceSettings,
331-
Map<String, Object> taskSettings,
332318
ChunkingSettings chunkingSettings,
333319
Map<String, Object> secretSettings,
334320
String failureMessage
@@ -337,7 +323,6 @@ private LlamaModel createModelFromPersistent(
337323
inferenceEntityId,
338324
taskType,
339325
serviceSettings,
340-
taskSettings,
341326
chunkingSettings,
342327
secretSettings,
343328
failureMessage,
@@ -348,7 +333,7 @@ private LlamaModel createModelFromPersistent(
348333
@Override
349334
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
350335
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
351-
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
336+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
352337

353338
ChunkingSettings chunkingSettings = null;
354339
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -359,7 +344,6 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
359344
modelId,
360345
taskType,
361346
serviceSettingsMap,
362-
taskSettingsMap,
363347
chunkingSettings,
364348
null,
365349
parsePersistedConfigErrorMsg(modelId, NAME)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public LlamaChatCompletionModel(
6868
SecretSettings secrets
6969
) {
7070
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
71+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
7272
new ModelSecrets(secrets)
7373
);
7474
setPropertiesFromServiceSettings(serviceSettings);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.inference.ModelConfigurations;
1313
import org.elasticsearch.inference.ModelSecrets;
1414
import org.elasticsearch.inference.SecretSettings;
15-
import org.elasticsearch.inference.TaskSettings;
1615
import org.elasticsearch.inference.TaskType;
1716
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1817
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -42,7 +41,6 @@ public LlamaEmbeddingsModel(
4241
TaskType taskType,
4342
String service,
4443
Map<String, Object> serviceSettings,
45-
Map<String, Object> taskSettings,
4644
ChunkingSettings chunkingSettings,
4745
Map<String, Object> secrets,
4846
ConfigurationParseContext context
@@ -52,7 +50,6 @@ public LlamaEmbeddingsModel(
5250
taskType,
5351
service,
5452
LlamaEmbeddingsServiceSettings.fromMap(serviceSettings, context),
55-
EmptyTaskSettings.INSTANCE, // no task settings for Llama embeddings
5653
chunkingSettings,
5754
retrieveSecretSettings(secrets)
5855
);
@@ -86,7 +83,6 @@ private void setPropertiesFromServiceSettings(LlamaEmbeddingsServiceSettings ser
8683
* @param taskType the type of task this model is designed for
8784
* @param service the name of the inference service
8885
* @param serviceSettings the settings for the inference service, specific to embeddings
89-
* @param taskSettings the task settings for the model
9086
* @param chunkingSettings the chunking settings for processing input data
9187
* @param secrets the secret settings for the model, such as API keys or tokens
9288
*/
@@ -95,7 +91,6 @@ public LlamaEmbeddingsModel(
9591
TaskType taskType,
9692
String service,
9793
LlamaEmbeddingsServiceSettings serviceSettings,
98-
TaskSettings taskSettings,
9994
ChunkingSettings chunkingSettings,
10095
SecretSettings secrets
10196
) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ public void parseRequestConfig(
201201
modelId,
202202
taskType,
203203
serviceSettingsMap,
204-
taskSettingsMap,
205204
chunkingSettings,
206205
serviceSettingsMap,
207206
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
@@ -226,7 +225,7 @@ public MistralModel parsePersistedConfigWithSecrets(
226225
Map<String, Object> secrets
227226
) {
228227
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
229-
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
228+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
230229
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
231230

232231
ChunkingSettings chunkingSettings = null;
@@ -238,7 +237,6 @@ public MistralModel parsePersistedConfigWithSecrets(
238237
modelId,
239238
taskType,
240239
serviceSettingsMap,
241-
taskSettingsMap,
242240
chunkingSettings,
243241
secretSettingsMap,
244242
parsePersistedConfigErrorMsg(modelId, NAME)
@@ -248,7 +246,7 @@ public MistralModel parsePersistedConfigWithSecrets(
248246
@Override
249247
public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
250248
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
251-
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
249+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
252250

253251
ChunkingSettings chunkingSettings = null;
254252
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -259,7 +257,6 @@ public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<
259257
modelId,
260258
taskType,
261259
serviceSettingsMap,
262-
taskSettingsMap,
263260
chunkingSettings,
264261
null,
265262
parsePersistedConfigErrorMsg(modelId, NAME)
@@ -280,24 +277,14 @@ private static MistralModel createModel(
280277
String modelId,
281278
TaskType taskType,
282279
Map<String, Object> serviceSettings,
283-
Map<String, Object> taskSettings,
284280
ChunkingSettings chunkingSettings,
285281
@Nullable Map<String, Object> secretSettings,
286282
String failureMessage,
287283
ConfigurationParseContext context
288284
) {
289285
switch (taskType) {
290286
case TEXT_EMBEDDING:
291-
return new MistralEmbeddingsModel(
292-
modelId,
293-
taskType,
294-
NAME,
295-
serviceSettings,
296-
taskSettings,
297-
chunkingSettings,
298-
secretSettings,
299-
context
300-
);
287+
return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
301288
case CHAT_COMPLETION, COMPLETION:
302289
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
303290
default:
@@ -309,7 +296,6 @@ private MistralModel createModelFromPersistent(
309296
String inferenceEntityId,
310297
TaskType taskType,
311298
Map<String, Object> serviceSettings,
312-
Map<String, Object> taskSettings,
313299
ChunkingSettings chunkingSettings,
314300
Map<String, Object> secretSettings,
315301
String failureMessage
@@ -318,7 +304,6 @@ private MistralModel createModelFromPersistent(
318304
inferenceEntityId,
319305
taskType,
320306
serviceSettings,
321-
taskSettings,
322307
chunkingSettings,
323308
secretSettings,
324309
failureMessage,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public MistralChatCompletionModel(
9494
DefaultSecretSettings secrets
9595
) {
9696
super(
97-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
97+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
9898
new ModelSecrets(secrets)
9999
);
100100
setPropertiesFromServiceSettings(serviceSettings);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.inference.EmptyTaskSettings;
1313
import org.elasticsearch.inference.ModelConfigurations;
1414
import org.elasticsearch.inference.ModelSecrets;
15-
import org.elasticsearch.inference.TaskSettings;
1615
import org.elasticsearch.inference.TaskType;
1716
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1817
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -37,7 +36,6 @@ public MistralEmbeddingsModel(
3736
TaskType taskType,
3837
String service,
3938
Map<String, Object> serviceSettings,
40-
Map<String, Object> taskSettings,
4139
ChunkingSettings chunkingSettings,
4240
@Nullable Map<String, Object> secrets,
4341
ConfigurationParseContext context
@@ -47,7 +45,6 @@ public MistralEmbeddingsModel(
4745
taskType,
4846
service,
4947
MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context),
50-
EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings
5148
chunkingSettings,
5249
DefaultSecretSettings.fromMap(secrets)
5350
);
@@ -76,12 +73,11 @@ public MistralEmbeddingsModel(
7673
TaskType taskType,
7774
String service,
7875
MistralEmbeddingsServiceSettings serviceSettings,
79-
TaskSettings taskSettings,
8076
ChunkingSettings chunkingSettings,
8177
DefaultSecretSettings secrets
8278
) {
8379
super(
84-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
80+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings),
8581
new ModelSecrets(secrets)
8682
);
8783
setPropertiesFromServiceSettings(serviceSettings);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.inference.EmptySecretSettings;
12-
import org.elasticsearch.inference.EmptyTaskSettings;
1312
import org.elasticsearch.inference.TaskType;
1413
import org.elasticsearch.test.ESTestCase;
1514
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -21,7 +20,6 @@ public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String
2120
TaskType.TEXT_EMBEDDING,
2221
"llama",
2322
new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
24-
EmptyTaskSettings.INSTANCE,
2523
null,
2624
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
2725
);
@@ -33,7 +31,6 @@ public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, S
3331
TaskType.TEXT_EMBEDDING,
3432
"llama",
3533
new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null),
36-
EmptyTaskSettings.INSTANCE,
3734
null,
3835
EmptySecretSettings.INSTANCE
3936
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
1212
import org.elasticsearch.inference.ChunkingSettings;
13-
import org.elasticsearch.inference.EmptyTaskSettings;
1413
import org.elasticsearch.inference.SimilarityMeasure;
1514
import org.elasticsearch.inference.TaskType;
1615
import org.elasticsearch.test.ESTestCase;
@@ -37,7 +36,6 @@ public static MistralEmbeddingsModel createModel(
3736
TaskType.TEXT_EMBEDDING,
3837
"mistral",
3938
new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings),
40-
EmptyTaskSettings.INSTANCE,
4139
chunkingSettings,
4240
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
4341
);
@@ -57,7 +55,6 @@ public static MistralEmbeddingsModel createModel(
5755
TaskType.TEXT_EMBEDDING,
5856
"mistral",
5957
new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings),
60-
EmptyTaskSettings.INSTANCE,
6158
null,
6259
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
6360
);

0 commit comments

Comments
 (0)