Skip to content

Commit 85478cf

Browse files
Refactor Hugging Face service settings and completion request methods for consistency
1 parent 21180df commit 85478cf

File tree

11 files changed, with 25 additions (+25) and 27 deletions (-27).

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para
121121
* - Key: {@link #MODEL_FIELD}, Value: modelId
122122
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
123123
*/
124-
public static Params withMaxCompletionTokensTokens(String modelId, Params params) {
124+
public static Params withMaxCompletionTokens(String modelId, Params params) {
125125
return new DelegatingMapParams(
126126
Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)),
127127
params

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public void testParseAllFields() throws IOException {
119119

120120
assertThat(request, is(expected));
121121
assertThat(
122-
Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
122+
Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
123123
is(XContentHelper.stripWhitespace(requestJson))
124124
);
125125
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Tuple;
1818
import org.elasticsearch.inference.InputType;
1919
import org.elasticsearch.inference.Model;
20+
import org.elasticsearch.inference.ModelConfigurations;
2021
import org.elasticsearch.inference.SimilarityMeasure;
2122
import org.elasticsearch.inference.TaskType;
2223
import org.elasticsearch.rest.RestStatus;
@@ -304,6 +305,12 @@ public static String invalidSettingError(String settingName, String scope) {
304305
return Strings.format("[%s] does not allow the setting [%s]", scope, settingName);
305306
}
306307

308+
public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
309+
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
310+
311+
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
312+
}
313+
307314
public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
308315
try {
309316
return createOptionalUri(url);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInpu
2828
@Override
2929
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3030
builder.startObject();
31-
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params));
31+
unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params));
3232
builder.endObject();
3333

3434
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@
3131
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3232
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3333
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
34-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
3534
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
3635
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
37-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
3836
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
37+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
3938

4039
public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings {
4140
public static final String NAME = "hugging_face_service_settings";
@@ -70,12 +69,6 @@ public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, Config
7069
return new HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings);
7170
}
7271

73-
public static URI extractUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
74-
String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
75-
76-
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
77-
}
78-
7972
private final URI uri;
8073
private final SimilarityMeasure similarity;
8174
private final Integer dimensions;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
3232
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
3333
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
34-
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
34+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
3535

3636
/**
3737
* Settings for the Hugging Face chat completion service.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3030
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
31-
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
31+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
3232

3333
public class HuggingFaceElserServiceSettings extends FilteredXContentObject
3434
implements

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import java.util.Objects;
2828

2929
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
30-
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
30+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri;
3131

3232
public class HuggingFaceRerankServiceSettings extends FilteredXContentObject
3333
implements

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import java.util.HashMap;
5555
import java.util.List;
5656
import java.util.Map;
57+
import java.util.Objects;
5758
import java.util.Set;
5859

5960
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
@@ -98,16 +99,12 @@ protected void doInfer(
9899
) {
99100
var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
100101

101-
switch (model) {
102-
case MistralEmbeddingsModel mistralEmbeddingsModel:
103-
mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
104-
break;
105-
case MistralChatCompletionModel mistralChatCompletionModel:
106-
mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
107-
break;
108-
default:
109-
listener.onFailure(createInvalidModelException(model));
110-
break;
102+
if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
103+
mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
104+
} else if (model instanceof MistralChatCompletionModel mistralChatCompletionModel) {
105+
mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
106+
} else {
107+
listener.onFailure(createInvalidModelException(model));
111108
}
112109
}
113110

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,13 @@ public boolean equals(Object o) {
178178
return Objects.equals(model, that.model)
179179
&& Objects.equals(dimensions, that.dimensions)
180180
&& Objects.equals(maxInputTokens, that.maxInputTokens)
181-
&& Objects.equals(similarity, that.similarity);
181+
&& Objects.equals(similarity, that.similarity)
182+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
182183
}
183184

184185
@Override
185186
public int hashCode() {
186-
return Objects.hash(model, dimensions, maxInputTokens, similarity);
187+
return Objects.hash(model, dimensions, maxInputTokens, similarity, rateLimitSettings);
187188
}
188189

189190
}

0 commit comments

Comments
 (0)