Skip to content

Commit a143af1

Browse files
committed
hugging face
1 parent 356d546 commit a143af1

File tree

10 files changed

+82
-13
lines changed

10 files changed

+82
-13
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreator.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.external.action.huggingface;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1112
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1213
import org.elasticsearch.xpack.inference.external.http.sender.HuggingFaceRequestManager;
@@ -35,13 +36,14 @@ public HuggingFaceActionCreator(Sender sender, ServiceComponents serviceComponen
3536
}
3637

3738
@Override
38-
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
39+
public ExecutableAction create(HuggingFaceEmbeddingsModel model, InputType inputType) {
3940
var responseHandler = new HuggingFaceResponseHandler(
4041
"hugging face text embeddings",
4142
HuggingFaceEmbeddingsResponseEntity::fromResponse
4243
);
44+
var overriddenModel = HuggingFaceEmbeddingsModel.of(model, inputType);
4345
var requestCreator = HuggingFaceRequestManager.of(
44-
model,
46+
overriddenModel,
4547
responseHandler,
4648
serviceComponents.truncator(),
4749
serviceComponents.threadPool()

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionVisitor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
package org.elasticsearch.xpack.inference.external.action.huggingface;
99

10+
import org.elasticsearch.inference.InputType;
1011
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1112
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
1213
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
1314

1415
public interface HuggingFaceActionVisitor {
15-
ExecutableAction create(HuggingFaceEmbeddingsModel mode);
16+
ExecutableAction create(HuggingFaceEmbeddingsModel mode, InputType inputType);
1617

1718
ExecutableAction create(HuggingFaceElserModel mode);
1819
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public void doInfer(
155155
var huggingFaceModel = (HuggingFaceModel) model;
156156
var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
157157

158-
var action = huggingFaceModel.accept(actionCreator);
158+
var action = huggingFaceModel.accept(actionCreator, inputType);
159159
action.execute(inputs, timeout, listener);
160160
}
161161
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.InputType;
1213
import org.elasticsearch.inference.Model;
1314
import org.elasticsearch.inference.ModelConfigurations;
1415
import org.elasticsearch.inference.ModelSecrets;
@@ -44,6 +45,6 @@ public SecureString apiKey() {
4445

4546
public abstract Integer getTokenLimit();
4647

47-
public abstract ExecutableAction accept(HuggingFaceActionVisitor creator);
48+
public abstract ExecutableAction accept(HuggingFaceActionVisitor creator, InputType inputType);
4849

4950
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ protected void doChunkedInfer(
134134
).batchRequestsWithListeners(listener);
135135

136136
for (var request : batchedRequests) {
137-
var action = huggingFaceModel.accept(actionCreator);
137+
var action = huggingFaceModel.accept(actionCreator, inputType);
138138
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
139139
}
140140
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.huggingface.elser;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.InputType;
1112
import org.elasticsearch.inference.ModelConfigurations;
1213
import org.elasticsearch.inference.ModelSecrets;
1314
import org.elasticsearch.inference.TaskType;
@@ -63,7 +64,7 @@ public DefaultSecretSettings getSecretSettings() {
6364
}
6465

6566
@Override
66-
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
67+
public ExecutableAction accept(HuggingFaceActionVisitor creator, InputType inputType) {
6768
return creator.create(this);
6869
}
6970

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,34 @@
77

88
package org.elasticsearch.xpack.inference.services.huggingface.embeddings;
99

10+
import org.elasticsearch.common.ValidationException;
1011
import org.elasticsearch.core.Nullable;
1112
import org.elasticsearch.inference.ChunkingSettings;
13+
import org.elasticsearch.inference.InputType;
1214
import org.elasticsearch.inference.ModelConfigurations;
1315
import org.elasticsearch.inference.ModelSecrets;
1416
import org.elasticsearch.inference.TaskType;
1517
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1618
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
1719
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
1821
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
1922
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
2023
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
2124

2225
import java.util.Map;
2326

2427
public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
28+
public static HuggingFaceEmbeddingsModel of(HuggingFaceEmbeddingsModel model, InputType inputType) {
29+
ValidationException validationException = new ValidationException();
30+
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
31+
if (validationException.validationErrors().isEmpty() == false) {
32+
throw validationException;
33+
}
34+
35+
return model;
36+
}
37+
2538
public HuggingFaceEmbeddingsModel(
2639
String inferenceEntityId,
2740
TaskType taskType,
@@ -85,7 +98,7 @@ public Integer getTokenLimit() {
8598
}
8699

87100
@Override
88-
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
89-
return creator.create(this);
101+
public ExecutableAction accept(HuggingFaceActionVisitor creator, InputType inputType) {
102+
return creator.create(this, inputType);
90103
}
91104
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
205205

206206
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
207207
var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool));
208-
var action = actionCreator.create(model);
208+
var action = actionCreator.create(model, null);
209209

210210
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
211211
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
@@ -263,7 +263,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws
263263
sender,
264264
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
265265
);
266-
var action = actionCreator.create(model);
266+
var action = actionCreator.create(model, null);
267267

268268
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
269269
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
@@ -318,7 +318,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
318318

319319
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
320320
var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool));
321-
var action = actionCreator.create(model);
321+
var action = actionCreator.create(model, null);
322322

323323
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
324324
action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
@@ -376,7 +376,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
376376
// truncated to 1 token = 3 characters
377377
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1);
378378
var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool));
379-
var action = actionCreator.create(model);
379+
var action = actionCreator.create(model, null);
380380

381381
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
382382
action.execute(new DocumentsOnlyInput(List.of("123456")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.ElasticsearchStatusException;
1414
import org.elasticsearch.action.ActionListener;
1515
import org.elasticsearch.action.support.PlainActionFuture;
16+
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.bytes.BytesArray;
1718
import org.elasticsearch.common.bytes.BytesReference;
1819
import org.elasticsearch.common.settings.Settings;
@@ -580,6 +581,29 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException {
580581
}
581582
}
582583

584+
public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
585+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
586+
587+
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
588+
589+
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
590+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
591+
service.infer(
592+
model,
593+
null,
594+
List.of("abc"),
595+
false,
596+
new HashMap<>(),
597+
InputType.INGEST,
598+
InferenceAction.Request.DEFAULT_TIMEOUT,
599+
listener
600+
);
601+
602+
var thrownException = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
603+
assertThat(thrownException.getMessage(), is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed;"));
604+
}
605+
}
606+
583607
public void testInfer_SendsElserRequest() throws IOException {
584608
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
585609

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@
77

88
package org.elasticsearch.xpack.inference.services.huggingface.embeddings;
99

10+
import org.elasticsearch.common.ValidationException;
1011
import org.elasticsearch.common.settings.SecureString;
1112
import org.elasticsearch.core.Nullable;
13+
import org.elasticsearch.inference.InputType;
1214
import org.elasticsearch.inference.SimilarityMeasure;
1315
import org.elasticsearch.inference.TaskType;
1416
import org.elasticsearch.test.ESTestCase;
1517
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
1618
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
19+
import org.hamcrest.CoreMatchers;
20+
import org.hamcrest.MatcherAssert;
21+
import org.hamcrest.Matchers;
1722

1823
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
1924
import static org.hamcrest.Matchers.containsString;
@@ -74,4 +79,26 @@ public static HuggingFaceEmbeddingsModel createModel(
7479
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
7580
);
7681
}
82+
83+
public void testThrowsError_WhenInputTypeSpecified() {
84+
var model = createModel("url", "api_key");
85+
86+
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceEmbeddingsModel.of(model, InputType.SEARCH));
87+
assertThat(
88+
thrownException.getMessage(),
89+
CoreMatchers.is("Validation Failed: 1: Invalid value [search] received. [input_type] is not allowed;")
90+
);
91+
}
92+
93+
public void testAcceptsInternalInputType() {
94+
var model = createModel("url", "api_key");
95+
var overriddenModel = HuggingFaceEmbeddingsModel.of(model, InputType.INTERNAL_SEARCH);
96+
MatcherAssert.assertThat(overriddenModel, Matchers.is(model));
97+
}
98+
99+
public void testAcceptsNullInputType() {
100+
var model = createModel("url", "api_key");
101+
var overriddenModel = HuggingFaceEmbeddingsModel.of(model, null);
102+
MatcherAssert.assertThat(overriddenModel, Matchers.is(model));
103+
}
77104
}

0 commit comments

Comments
 (0)