Skip to content

Commit 75cbf85

Browse files
Refactor Llama model classes to implement accept method for action visitors
1 parent 15c14d7 commit 75cbf85

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import org.elasticsearch.inference.ModelSecrets;
1313
import org.elasticsearch.inference.SecretSettings;
1414
import org.elasticsearch.inference.ServiceSettings;
15+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1516
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
17+
import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor;
1618
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
1719
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1820

@@ -87,4 +89,6 @@ public void setURI(String newUri) {
8789
protected static SecretSettings retrieveSecretSettings(Map<String, Object> secrets) {
8890
return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets);
8991
}
92+
93+
protected abstract ExecutableAction accept(LlamaActionVisitor creator);
9094
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,8 @@ protected void doInfer(
106106
ActionListener<InferenceServiceResults> listener
107107
) {
108108
var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents());
109-
110-
if (model instanceof LlamaEmbeddingsModel llamaEmbeddingsModel) {
111-
llamaEmbeddingsModel.accept(actionCreator).execute(inputs, timeout, listener);
112-
} else if (model instanceof LlamaChatCompletionModel llamaChatCompletionModel) {
113-
llamaChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
109+
if (model instanceof LlamaModel llamaModel) {
110+
llamaModel.accept(actionCreator).execute(inputs, timeout, listener);
114111
} else {
115112
listener.onFailure(createInvalidModelException(model));
116113
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ public LlamaChatCompletionServiceSettings getServiceSettings() {
126126
* @param creator the visitor that creates the executable action
127127
* @return an ExecutableAction representing this model
128128
*/
129+
@Override
129130
public ExecutableAction accept(LlamaActionVisitor creator) {
130131
return creator.create(this);
131132
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ public LlamaEmbeddingsServiceSettings getServiceSettings() {
118118
* @param creator the visitor that creates the executable action
119119
* @return an ExecutableAction representing the Llama embeddings model
120120
*/
121+
@Override
121122
public ExecutableAction accept(LlamaActionVisitor creator) {
122123
return creator.create(this);
123124
}

0 commit comments

Comments
 (0)