Skip to content

Commit ff3ef50

Browse files
Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests
1 parent 91fa92e commit ff3ef50

File tree

8 files changed

+595
-138
lines changed

8 files changed

+595
-138
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java

Lines changed: 0 additions & 88 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java

Lines changed: 56 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,66 @@
77

88
package org.elasticsearch.xpack.inference.services.huggingface;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
1014
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.common.Truncator;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
1118
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;
19+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
20+
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
21+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
22+
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
1223

13-
public abstract class HuggingFaceRequestManager extends BaseRequestManager {
14-
protected HuggingFaceRequestManager(HuggingFaceModel model, ThreadPool threadPool) {
24+
import java.util.List;
25+
import java.util.Objects;
26+
import java.util.function.Supplier;
27+
28+
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
29+
30+
public class HuggingFaceRequestManager extends BaseRequestManager {
31+
private static final Logger logger = LogManager.getLogger(HuggingFaceRequestManager.class);
32+
33+
public static HuggingFaceRequestManager of(
34+
HuggingFaceModel model,
35+
ResponseHandler responseHandler,
36+
Truncator truncator,
37+
ThreadPool threadPool
38+
) {
39+
return new HuggingFaceRequestManager(
40+
Objects.requireNonNull(model),
41+
Objects.requireNonNull(responseHandler),
42+
Objects.requireNonNull(truncator),
43+
Objects.requireNonNull(threadPool)
44+
);
45+
}
46+
47+
private final HuggingFaceModel model;
48+
private final ResponseHandler responseHandler;
49+
private final Truncator truncator;
50+
51+
private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) {
1552
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
53+
this.model = model;
54+
this.responseHandler = responseHandler;
55+
this.truncator = truncator;
56+
}
57+
58+
@Override
59+
public void execute(
60+
InferenceInputs inferenceInputs,
61+
RequestSender requestSender,
62+
Supplier<Boolean> hasRequestCompletedFunction,
63+
ActionListener<InferenceServiceResults> listener
64+
) {
65+
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
66+
var truncatedInput = truncate(docsInput, model.getTokenLimit());
67+
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
68+
69+
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
1670
}
1771

1872
record RateLimitGrouping(int accountHash) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1717
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1818
import org.elasticsearch.xpack.inference.services.ServiceComponents;
19-
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager;
19+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
2020
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
2121
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
2222
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
@@ -58,7 +58,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
5858
"hugging face text embeddings",
5959
HuggingFaceEmbeddingsResponseEntity::fromResponse
6060
);
61-
var requestCreator = HuggingFaceEmbeddingsRequestManager.of(
61+
var requestCreator = HuggingFaceRequestManager.of(
6262
model,
6363
responseHandler,
6464
serviceComponents.truncator(),
@@ -71,7 +71,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
7171
@Override
7272
public ExecutableAction create(HuggingFaceElserModel model) {
7373
var responseHandler = new HuggingFaceResponseHandler("hugging face elser", HuggingFaceElserResponseEntity::fromResponse);
74-
var requestCreator = HuggingFaceEmbeddingsRequestManager.of(
74+
var requestCreator = HuggingFaceRequestManager.of(
7575
model,
7676
responseHandler,
7777
serviceComponents.truncator(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,10 @@
2929

3030
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3131
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
32+
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
3233
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
3334
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
34-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
35+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
3536
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
3637

3738
/**
@@ -47,11 +48,9 @@ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentOb
4748
HuggingFaceRateLimitServiceSettings {
4849

4950
public static final String NAME = "hugging_face_completion_service_settings";
50-
public static final String URL = "url";
5151
// At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value here is only a guess
5252
// 3000 requests per minute
5353
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
54-
private static final int DEFAULT_TOKEN_LIMIT = 512;
5554

5655
/**
5756
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings.
@@ -62,7 +61,7 @@ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentOb
6261
public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
6362
ValidationException validationException = new ValidationException();
6463

65-
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
64+
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
6665

6766
var uri = extractUri(map, URL, validationException);
6867

@@ -93,7 +92,7 @@ public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Objec
9392
private final RateLimitSettings rateLimitSettings;
9493

9594
public HuggingFaceChatCompletionServiceSettings(
96-
String modelId,
95+
@Nullable String modelId,
9796
String url,
9897
@Nullable Integer maxInputTokens,
9998
@Nullable RateLimitSettings rateLimitSettings
@@ -102,14 +101,14 @@ public HuggingFaceChatCompletionServiceSettings(
102101
}
103102

104103
public HuggingFaceChatCompletionServiceSettings(
105-
String modelId,
104+
@Nullable String modelId,
106105
URI uri,
107106
@Nullable Integer maxInputTokens,
108107
@Nullable RateLimitSettings rateLimitSettings
109108
) {
110109
this.modelId = modelId;
111110
this.uri = uri;
112-
this.maxInputTokens = Objects.requireNonNullElse(maxInputTokens, DEFAULT_TOKEN_LIMIT);
111+
this.maxInputTokens = maxInputTokens;
113112
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
114113
}
115114

@@ -119,15 +118,14 @@ public HuggingFaceChatCompletionServiceSettings(
119118
* @throws IOException if an I/O error occurs
120119
*/
121120
public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException {
122-
this.modelId = in.readString();
121+
this.modelId = in.readOptionalString();
123122
this.uri = createUri(in.readString());
123+
this.maxInputTokens = in.readOptionalVInt();
124124

125125
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
126126
this.rateLimitSettings = new RateLimitSettings(in);
127-
this.maxInputTokens = in.readOptionalVInt();
128127
} else {
129128
this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
130-
this.maxInputTokens = DEFAULT_TOKEN_LIMIT;
131129
}
132130
}
133131

@@ -141,7 +139,7 @@ public URI uri() {
141139
return uri;
142140
}
143141

144-
public int maxInputTokens() {
142+
public Integer maxInputTokens() {
145143
return maxInputTokens;
146144
}
147145

@@ -161,10 +159,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
161159

162160
@Override
163161
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
164-
builder.field(MODEL_ID, modelId);
165-
162+
if (modelId != null) {
163+
builder.field(MODEL_ID, modelId);
164+
}
166165
builder.field(URL, uri.toString());
167-
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
166+
if (maxInputTokens != null) {
167+
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
168+
}
168169
rateLimitSettings.toXContent(builder, params);
169170

170171
return builder;
@@ -177,13 +178,13 @@ public String getWriteableName() {
177178

178179
@Override
179180
public TransportVersion getMinimalSupportedVersion() {
180-
return TransportVersions.V_8_12_0;
181+
return TransportVersions.V_8_14_0;
181182
}
182183

183184
@Override
184185
public void writeTo(StreamOutput out) throws IOException {
185-
out.writeString(modelId);
186-
out.writeOptionalString(uri != null ? uri.toString() : null);
186+
out.writeOptionalString(modelId);
187+
out.writeString(uri.toString());
187188
out.writeOptionalVInt(maxInputTokens);
188189

189190
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -821,7 +821,7 @@ public void testGetConfiguration() throws Exception {
821821
{
822822
"service": "hugging_face",
823823
"name": "Hugging Face",
824-
"task_types": ["text_embedding", "sparse_embedding"],
824+
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
825825
"configurations": {
826826
"api_key": {
827827
"description": "API Key for the provider you're connecting to.",
@@ -830,7 +830,7 @@ public void testGetConfiguration() throws Exception {
830830
"sensitive": true,
831831
"updatable": true,
832832
"type": "str",
833-
"supported_task_types": ["text_embedding", "sparse_embedding"]
833+
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
834834
},
835835
"rate_limit.requests_per_minute": {
836836
"description": "Minimize the number of rate limit errors.",
@@ -839,7 +839,7 @@ public void testGetConfiguration() throws Exception {
839839
"sensitive": false,
840840
"updatable": false,
841841
"type": "int",
842-
"supported_task_types": ["text_embedding", "sparse_embedding"]
842+
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
843843
},
844844
"url": {
845845
"default_value": "https://api.openai.com/v1/embeddings",
@@ -849,7 +849,7 @@ public void testGetConfiguration() throws Exception {
849849
"sensitive": false,
850850
"updatable": false,
851851
"type": "str",
852-
"supported_task_types": ["text_embedding", "sparse_embedding"]
852+
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
853853
}
854854
}
855855
}

0 commit comments

Comments
 (0)