Skip to content

Commit 727fd8e

Browse files
Add Llama model support for embeddings and chat completions
1 parent 85478cf commit 727fd8e

21 files changed

+1507
-2
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ static TransportVersion def(int id) {
324324
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00);
325325
public static final TransportVersion ESQL_PROFILE_INCLUDE_PLAN = def(9_111_0_00);
326326
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
327+
public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_113_0_00);
327328

328329
/*
329330
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@
103103
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
104104
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
105105
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
106+
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings;
107+
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
106108
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
107109
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
108110
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
@@ -172,6 +174,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
172174
addJinaAINamedWriteables(namedWriteables);
173175
addVoyageAINamedWriteables(namedWriteables);
174176
addCustomNamedWriteables(namedWriteables);
177+
addLlamaNamedWriteables(namedWriteables);
175178

176179
addUnifiedNamedWriteables(namedWriteables);
177180

@@ -271,8 +274,25 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
271274
MistralChatCompletionServiceSettings::new
272275
)
273276
);
277+
// no task settings for Mistral
278+
}
274279

275-
// note - no task settings for Mistral embeddings...
280+
private static void addLlamaNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
281+
namedWriteables.add(
282+
new NamedWriteableRegistry.Entry(
283+
ServiceSettings.class,
284+
LlamaEmbeddingsServiceSettings.NAME,
285+
LlamaEmbeddingsServiceSettings::new
286+
)
287+
);
288+
namedWriteables.add(
289+
new NamedWriteableRegistry.Entry(
290+
ServiceSettings.class,
291+
LlamaChatCompletionServiceSettings.NAME,
292+
LlamaChatCompletionServiceSettings::new
293+
)
294+
);
295+
// no task settings for Llama
276296
}
277297

278298
private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
133133
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
134134
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
135+
import org.elasticsearch.xpack.inference.services.llama.LlamaService;
135136
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
136137
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
137138
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
@@ -399,6 +400,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
399400
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
400401
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
401402
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
403+
context -> new LlamaService(httpFactory.get(), serviceComponents.get()),
402404
ElasticsearchInternalService::new,
403405
context -> new CustomService(httpFactory.get(), serviceComponents.get())
404406
);
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama;
9+
10+
import org.elasticsearch.inference.EmptySecretSettings;
11+
import org.elasticsearch.inference.ModelConfigurations;
12+
import org.elasticsearch.inference.ModelSecrets;
13+
import org.elasticsearch.inference.SecretSettings;
14+
import org.elasticsearch.inference.ServiceSettings;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
16+
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
17+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
18+
19+
import java.net.URI;
20+
import java.net.URISyntaxException;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
public abstract class LlamaModel extends RateLimitGroupingModel {
25+
protected String modelId;
26+
protected URI uri;
27+
protected RateLimitSettings rateLimitSettings;
28+
29+
protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) {
30+
super(configurations, secrets);
31+
}
32+
33+
protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
34+
super(model, serviceSettings);
35+
}
36+
37+
public String model() {
38+
return this.modelId;
39+
}
40+
41+
public URI uri() {
42+
return this.uri;
43+
}
44+
45+
@Override
46+
public RateLimitSettings rateLimitSettings() {
47+
return this.rateLimitSettings;
48+
}
49+
50+
@Override
51+
public int rateLimitGroupingHash() {
52+
return Objects.hash(modelId, uri, getSecretSettings());
53+
}
54+
55+
// Needed for testing only
56+
public void setURI(String newUri) {
57+
try {
58+
this.uri = new URI(newUri);
59+
} catch (URISyntaxException e) {
60+
// swallow any error
61+
}
62+
}
63+
64+
protected static SecretSettings retrieveSecretSettings(Map<String, Object> secrets) {
65+
return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets);
66+
}
67+
}

0 commit comments

Comments
 (0)