Skip to content

Commit beb18a8

Browse files
Jan-Kazlouski-elasticelasticsearchmachine
andauthored
Add Llama support to Inference Plugin (#130092)
* Refactor Hugging Face service settings and completion request methods for consistency * Add Llama model support for embeddings and chat completions * Refactor Llama request classes to improve secret settings handling * Refactor DeltaParser in LlamaStreamingProcessor to improve argument handling * Enhance Llama streaming processing by adding support for nullable object arrays * [CI] Auto commit changes from spotless * Fix error messages in LlamaActionCreator * [CI] Auto commit changes from spotless * Add detailed Javadoc comments to Llama classes for improved documentation * Enhance LlamaChatCompletionResponseHandler to support mid-stream error handling and improve error response parsing * Add Javadoc comments to Llama classes for improved documentation and clarity * Fix checkstyle * Update LlamaEmbeddingsRequest to use mediaTypeWithoutParameters for content type header * Add unit tests for LlamaActionCreator and related models * Add unit tests for LlamaChatCompletionServiceSettings to validate configuration parsing and serialization * Add unit tests for LlamaEmbeddingsServiceSettings to validate configuration parsing and serialization * Add unit tests for LlamaEmbeddingsServiceSettings to validate various configuration scenarios * Add unit tests for LlamaChatCompletionResponseHandler to validate error response handling * Refactor Llama embedding and chat completion tests for consistency and clarity * Add unit tests for LlamaChatCompletionRequestEntity to validate message serialization * Add unit tests for LlamaEmbeddingsRequest to validate request creation and truncation behavior * Add unit tests for LlamaEmbeddingsRequestEntity to validate XContent serialization * Add unit tests for LlamaErrorResponse to validate error handling from HTTP responses * Add unit tests for LlamaChatCompletionServiceSettings to validate configuration parsing and serialization * Add tests for LlamaService request configuration validation and error handling * Fix error message formatting in LlamaServiceTests for better localization support * Refactor Llama model classes to implement accept method for action visitors * Hide Llama service from configuration API to enhance security and reduce exposure * Refactor Llama model classes to remove modelId and update embedding request handling * Refactor Llama request classes to use pattern matching for secret settings * Update embeddings handler to use HuggingFace response entity * Refactor Mistral model classes to remove modelId and update rate limit hashing * Refactor Mistral action classes to remove taskSettings parameter and streamline action creation * Refactor Llama and Mistral models to remove taskSettings parameter and simplify model instantiation * Refactor Llama service tests to use Model instead of CustomModel and update similarity measure to DOT_PRODUCT * Remove unused tests and imports from LlamaServiceTests * Add chunking settings support to Llama embeddings model tests * Add changelog * Add support for version checks in Llama settings and define new transport version * Refactor Llama model assertions and remove unused version support methods * Refactor Llama service constructors to include ClusterService and improve error message handling --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent f664cf5 commit beb18a8

File tree

54 files changed

+4517
-92
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+4517
-92
lines changed

docs/changelog/130092.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 130092
2+
summary: "Added Llama provider support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,27 @@ public <T> void declareField(BiConsumer<Value, T> consumer, ContextParser<Contex
220220
}
221221
}
222222

223+
/**
224+
* Declare a field that is an array of objects or null. Used to avoid calling the consumer when used with
225+
* {@link #optionalConstructorArg()} or {@link #constructorArg()}.
226+
* @param consumer Consumer that will be passed as is to the {@link #declareField(BiConsumer, ContextParser, ParseField, ValueType)}.
227+
* @param objectParser Parser that will parse the objects in the array, checking for nulls.
228+
* @param field Field to declare.
229+
*/
230+
@Override
231+
public <T> void declareObjectArrayOrNull(
232+
BiConsumer<Value, List<T>> consumer,
233+
ContextParser<Context, T> objectParser,
234+
ParseField field
235+
) {
236+
declareField(
237+
consumer,
238+
(p, c) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? null : parseArray(p, c, objectParser),
239+
field,
240+
ValueType.OBJECT_ARRAY_OR_NULL
241+
);
242+
}
243+
223244
@Override
224245
public <T> void declareNamedObject(
225246
BiConsumer<Value, T> consumer,

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ static TransportVersion def(int id) {
343343
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
344344
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
345345
public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00);
346+
public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00);
346347

347348
/*
348349
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para
121121
* - Key: {@link #MODEL_FIELD}, Value: modelId
122122
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
123123
*/
124-
public static Params withMaxCompletionTokensTokens(String modelId, Params params) {
124+
public static Params withMaxCompletionTokens(String modelId, Params params) {
125125
return new DelegatingMapParams(
126126
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)),
127127
params

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public void testParseAllFields() throws IOException {
119119

120120
assertThat(request, is(expected));
121121
assertThat(
122-
Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
122+
Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
123123
is(XContentHelper.stripWhitespace(requestJson))
124124
);
125125
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
107107
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
108108
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
109+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings;
110+
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
109111
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
110112
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
111113
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
@@ -175,6 +177,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
175177
addJinaAINamedWriteables(namedWriteables);
176178
addVoyageAINamedWriteables(namedWriteables);
177179
addCustomNamedWriteables(namedWriteables);
180+
addLlamaNamedWriteables(namedWriteables);
178181

179182
addUnifiedNamedWriteables(namedWriteables);
180183

@@ -274,8 +277,25 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
274277
MistralChatCompletionServiceSettings::new
275278
)
276279
);
280+
// no task settings for Mistral
281+
}
277282

278-
// note - no task settings for Mistral embeddings...
283+
private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
284+
namedWriteables.add(
285+
new NamedWriteableRegistry.Entry(
286+
ServiceSettings.class,
287+
LlamaEmbeddingsServiceSettings.NAME,
288+
LlamaEmbeddingsServiceSettings::new
289+
)
290+
);
291+
namedWriteables.add(
292+
new NamedWriteableRegistry.Entry(
293+
ServiceSettings.class,
294+
LlamaChatCompletionServiceSettings.NAME,
295+
LlamaChatCompletionServiceSettings::new
296+
)
297+
);
298+
// no task settings for Llama
279299
}
280300

281301
private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
134134
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
135135
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
136+
import org.elasticsearch.xpack.inference.services.llama.LlamaService;
136137
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
137138
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
138139
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
@@ -402,6 +403,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
402403
context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
403404
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
404405
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
406+
context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context),
405407
ElasticsearchInternalService::new,
406408
context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
407409
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Tuple;
1818
import org.elasticsearch.inference.InputType;
1919
import org.elasticsearch.inference.Model;
20+
import org.elasticsearch.inference.ModelConfigurations;
2021
import org.elasticsearch.inference.SimilarityMeasure;
2122
import org.elasticsearch.inference.TaskType;
2223
import org.elasticsearch.rest.RestStatus;
@@ -304,6 +305,12 @@ public static String invalidSettingError(String settingName, String scope) {
304305
return Strings.format("[%s] does not allow the setting [%s]", scope, settingName);
305306
}
306307

308+
public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
309+
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
310+
311+
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
312+
}
313+
307314
public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
308315
try {
309316
return createOptionalUri(url);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInpu
2828
@Override
2929
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3030
builder.startObject();
31-
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params));
31+
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params));
3232
builder.endObject();
3333

3434
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@
3131
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3232
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3333
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
34-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
3534
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
3635
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
37-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
3836
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
37+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
3938

4039
public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings {
4140
public static final String NAME = "hugging_face_service_settings";
@@ -70,12 +69,6 @@ public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, Config
7069
return new HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings);
7170
}
7271

73-
public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
74-
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
75-
76-
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
77-
}
78-
7972
private final URI uri;
8073
private final SimilarityMeasure similarity;
8174
private final Integer dimensions;

0 commit comments

Comments
 (0)