Skip to content

Commit 45966dd

Browse files
Refactor Llama Integration to support OpenAI API for embeddings and support user field
1 parent a665cd4 commit 45966dd

21 files changed

+469
-245
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ static TransportVersion def(int id) {
353353
public static final TransportVersion ESQL_TOPN_TIMINGS = def(9_128_0_00);
354354
public static final TransportVersion NODE_WEIGHTS_ADDED_TO_NODE_BALANCE_STATS = def(9_129_0_00);
355355
public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00);
356+
public static final TransportVersion ML_INFERENCE_LLAMA_OPEN_AI_API_FIX = def(9_131_0_00);
356357

357358
/*
358359
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.inference.ModelSecrets;
1313
import org.elasticsearch.inference.SecretSettings;
1414
import org.elasticsearch.inference.ServiceSettings;
15+
import org.elasticsearch.inference.TaskSettings;
1516
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1617
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1718
import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
@@ -46,10 +47,21 @@ protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) {
4647
* @param model the model configurations
4748
* @param serviceSettings the settings for the inference service
4849
*/
49-
protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
50+
protected LlamaModel(LlamaModel model, ServiceSettings serviceSettings) {
5051
super(model, serviceSettings);
5152
}
5253

54+
/**
55+
* Constructor for creating a LlamaModel from an existing model with overridden task settings.
56+
* @param model the existing Llama model to copy configuration from
57+
* @param taskSettings the settings for the task
58+
*/
59+
protected LlamaModel(LlamaModel model, TaskSettings taskSettings) {
60+
super(model, taskSettings);
61+
this.uri = model.uri;
62+
this.rateLimitSettings = model.rateLimitSettings;
63+
}
64+
5365
public URI uri() {
5466
return this.uri;
5567
}
@@ -85,5 +97,5 @@ protected static SecretSettings retrieveSecretSettings(Map<String, Object> secre
8597
return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets);
8698
}
8799

88-
protected abstract ExecutableAction accept(LlamaActionVisitor creator);
100+
protected abstract ExecutableAction accept(LlamaActionVisitor creator, Map<String, Object> taskSettings);
89101
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ protected void doInfer(
118118
) {
119119
var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents());
120120
if (model instanceof LlamaModel llamaModel) {
121-
llamaModel.accept(actionCreator).execute(inputs, timeout, listener);
121+
llamaModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
122122
} else {
123123
listener.onFailure(createInvalidModelException(model));
124124
}
@@ -145,16 +145,26 @@ protected LlamaModel createModel(
145145
String inferenceId,
146146
TaskType taskType,
147147
Map<String, Object> serviceSettings,
148+
Map<String, Object> taskSettings,
148149
ChunkingSettings chunkingSettings,
149150
Map<String, Object> secretSettings,
150151
String failureMessage,
151152
ConfigurationParseContext context
152153
) {
153154
switch (taskType) {
154155
case TEXT_EMBEDDING:
155-
return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
156+
return new LlamaEmbeddingsModel(
157+
inferenceId,
158+
taskType,
159+
NAME,
160+
serviceSettings,
161+
taskSettings,
162+
chunkingSettings,
163+
secretSettings,
164+
context
165+
);
156166
case CHAT_COMPLETION, COMPLETION:
157-
return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
167+
return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
158168
default:
159169
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
160170
}
@@ -173,6 +183,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
173183
embeddingSize,
174184
similarityToUse,
175185
serviceSettings.maxInputTokens(),
186+
serviceSettings.dimensionsSetByUser(),
176187
serviceSettings.rateLimitSettings()
177188
);
178189

@@ -206,7 +217,7 @@ protected void doChunkedInfer(
206217
).batchRequestsWithListeners(listener);
207218

208219
for (var request : batchedRequests) {
209-
var action = llamaModel.accept(actionCreator);
220+
var action = llamaModel.accept(actionCreator, taskSettings);
210221
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
211222
}
212223
}
@@ -280,6 +291,7 @@ public void parseRequestConfig(
280291
modelId,
281292
taskType,
282293
serviceSettingsMap,
294+
taskSettingsMap,
283295
chunkingSettings,
284296
serviceSettingsMap,
285297
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
@@ -304,7 +316,7 @@ public Model parsePersistedConfigWithSecrets(
304316
Map<String, Object> secrets
305317
) {
306318
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
307-
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
319+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
308320
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
309321

310322
ChunkingSettings chunkingSettings = null;
@@ -316,6 +328,7 @@ public Model parsePersistedConfigWithSecrets(
316328
modelId,
317329
taskType,
318330
serviceSettingsMap,
331+
taskSettingsMap,
319332
chunkingSettings,
320333
secretSettingsMap,
321334
parsePersistedConfigErrorMsg(modelId, NAME)
@@ -326,6 +339,7 @@ private LlamaModel createModelFromPersistent(
326339
String inferenceEntityId,
327340
TaskType taskType,
328341
Map<String, Object> serviceSettings,
342+
Map<String, Object> taskSettings,
329343
ChunkingSettings chunkingSettings,
330344
Map<String, Object> secretSettings,
331345
String failureMessage
@@ -334,6 +348,7 @@ private LlamaModel createModelFromPersistent(
334348
inferenceEntityId,
335349
taskType,
336350
serviceSettings,
351+
taskSettings,
337352
chunkingSettings,
338353
secretSettings,
339354
failureMessage,
@@ -344,7 +359,7 @@ private LlamaModel createModelFromPersistent(
344359
@Override
345360
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
346361
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
347-
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
362+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
348363

349364
ChunkingSettings chunkingSettings = null;
350365
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -355,6 +370,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
355370
modelId,
356371
taskType,
357372
serviceSettingsMap,
373+
taskSettingsMap,
358374
chunkingSettings,
359375
null,
360376
parsePersistedConfigErrorMsg(modelId, NAME)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1919
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2020
import org.elasticsearch.xpack.inference.services.ServiceComponents;
21-
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
2221
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
2322
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler;
2423
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
2524
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler;
2625
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest;
2726
import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest;
2827
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
28+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
2929

30+
import java.util.Map;
3031
import java.util.Objects;
3132

3233
import static org.elasticsearch.core.Strings.format;
@@ -44,7 +45,7 @@ public class LlamaActionCreator implements LlamaActionVisitor {
4445

4546
private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler(
4647
"llama text embedding",
47-
HuggingFaceEmbeddingsResponseEntity::fromResponse
48+
OpenAiEmbeddingsResponseEntity::fromResponse
4849
);
4950
private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler(
5051
"llama completion",
@@ -66,34 +67,36 @@ public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) {
6667
}
6768

6869
@Override
69-
public ExecutableAction create(LlamaEmbeddingsModel model) {
70+
public ExecutableAction create(LlamaEmbeddingsModel model, Map<String, Object> taskSettings) {
71+
var overriddenModel = LlamaEmbeddingsModel.of(model, taskSettings);
7072
var manager = new GenericRequestManager<>(
7173
serviceComponents.threadPool(),
72-
model,
74+
overriddenModel,
7375
EMBEDDINGS_HANDLER,
7476
embeddingsInput -> new LlamaEmbeddingsRequest(
7577
serviceComponents.truncator(),
76-
truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()),
77-
model
78+
truncate(embeddingsInput.getStringInputs(), overriddenModel.getServiceSettings().maxInputTokens()),
79+
overriddenModel
7880
),
7981
EmbeddingsInput.class
8082
);
8183

82-
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
84+
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, overriddenModel.getInferenceEntityId());
8385
return new SenderExecutableAction(sender, manager, errorMessage);
8486
}
8587

8688
@Override
87-
public ExecutableAction create(LlamaChatCompletionModel model) {
89+
public ExecutableAction create(LlamaChatCompletionModel model, Map<String, Object> taskSettings) {
90+
var overriddenModel = LlamaChatCompletionModel.of(model, taskSettings);
8891
var manager = new GenericRequestManager<>(
8992
serviceComponents.threadPool(),
90-
model,
93+
overriddenModel,
9194
COMPLETION_HANDLER,
92-
inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
95+
inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel),
9396
ChatCompletionInput.class
9497
);
9598

96-
var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
99+
var errorMessage = buildErrorMessage(TaskType.COMPLETION, overriddenModel.getInferenceEntityId());
97100
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
98101
}
99102

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
1212
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
1313

14+
import java.util.Map;
15+
1416
/**
1517
* Visitor interface for creating executable actions for Llama inference models.
1618
* This interface defines methods to create actions for both embeddings and chat completion models.
@@ -20,15 +22,17 @@ public interface LlamaActionVisitor {
2022
* Creates an executable action for the given Llama embeddings model.
2123
*
2224
* @param model the Llama embeddings model
25+
* @param taskSettings the settings for the task, which may include parameters like user
2326
* @return an executable action for the embeddings model
2427
*/
25-
ExecutableAction create(LlamaEmbeddingsModel model);
28+
ExecutableAction create(LlamaEmbeddingsModel model, Map<String, Object> taskSettings);
2629

2730
/**
2831
* Creates an executable action for the given Llama chat completion model.
2932
*
3033
* @param model the Llama chat completion model
34+
* @param taskSettings the settings for the task, which may include parameters like user
3135
* @return an executable action for the chat completion model
3236
*/
33-
ExecutableAction create(LlamaChatCompletionModel model);
37+
ExecutableAction create(LlamaChatCompletionModel model, Map<String, Object> taskSettings);
3438
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.llama.completion;
99

10-
import org.elasticsearch.inference.EmptyTaskSettings;
1110
import org.elasticsearch.inference.ModelConfigurations;
1211
import org.elasticsearch.inference.ModelSecrets;
1312
import org.elasticsearch.inference.SecretSettings;
@@ -17,6 +16,8 @@
1716
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
1817
import org.elasticsearch.xpack.inference.services.llama.LlamaModel;
1918
import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
19+
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettings;
20+
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
2021

2122
import java.util.Map;
2223

@@ -40,6 +41,7 @@ public LlamaChatCompletionModel(
4041
TaskType taskType,
4142
String service,
4243
Map<String, Object> serviceSettings,
44+
Map<String, Object> taskSettings,
4345
Map<String, Object> secrets,
4446
ConfigurationParseContext context
4547
) {
@@ -48,6 +50,7 @@ public LlamaChatCompletionModel(
4850
taskType,
4951
service,
5052
LlamaChatCompletionServiceSettings.fromMap(serviceSettings, context),
53+
OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
5154
retrieveSecretSettings(secrets)
5255
);
5356
}
@@ -65,15 +68,26 @@ public LlamaChatCompletionModel(
6568
TaskType taskType,
6669
String service,
6770
LlamaChatCompletionServiceSettings serviceSettings,
71+
OpenAiChatCompletionTaskSettings taskSettings,
6872
SecretSettings secrets
6973
) {
70-
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE),
72-
new ModelSecrets(secrets)
73-
);
74+
super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
7475
setPropertiesFromServiceSettings(serviceSettings);
7576
}
7677

78+
public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, Map<String, Object> taskSettings) {
79+
if (taskSettings == null || taskSettings.isEmpty()) {
80+
return model;
81+
}
82+
83+
var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
84+
return new LlamaChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
85+
}
86+
87+
private LlamaChatCompletionModel(LlamaChatCompletionModel originalModel, OpenAiChatCompletionTaskSettings taskSettings) {
88+
super(originalModel, taskSettings);
89+
}
90+
7791
/**
7892
* Factory method to create a LlamaChatCompletionModel with overridden model settings based on the request.
7993
* If the request does not specify a model, the original model is returned.
@@ -100,6 +114,7 @@ public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, Unifie
100114
model.getTaskType(),
101115
model.getConfigurations().getService(),
102116
overriddenServiceSettings,
117+
model.getTaskSettings(),
103118
model.getSecretSettings()
104119
);
105120
}
@@ -126,7 +141,12 @@ public LlamaChatCompletionServiceSettings getServiceSettings() {
126141
* @return an ExecutableAction representing this model
127142
*/
128143
@Override
129-
public ExecutableAction accept(LlamaActionVisitor creator) {
130-
return creator.create(this);
144+
public ExecutableAction accept(LlamaActionVisitor creator, Map<String, Object> taskSettings) {
145+
return creator.create(this, taskSettings);
146+
}
147+
148+
@Override
149+
public OpenAiChatCompletionTaskSettings getTaskSettings() {
150+
return (OpenAiChatCompletionTaskSettings) super.getTaskSettings();
131151
}
132152
}

0 commit comments

Comments
 (0)