Commit 9956bed

Update TestSparseInferenceServiceExtension to not support text embeddings (#126618)
Parent: 9f2f4f2

2 files changed (+4, −19 lines)


x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 2 deletions

@@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-        assertThat(services.size(), equalTo(16));
+        assertThat(services.size(), equalTo(15));

         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -86,7 +86,6 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
             "jinaai",
             "mistral",
             "openai",
-            "test_service",
             "text_embedding_test_service",
             "voyageai",
             "watsonxai"

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 3 additions & 17 deletions

@@ -35,7 +35,6 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;

 import java.io.IOException;
@@ -64,7 +63,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSettings) {
     public static class TestInferenceService extends AbstractTestInferenceService {
         public static final String NAME = "test_service";

-        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING);
+        private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);

         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

@@ -115,8 +114,7 @@ public void infer(
             ActionListener<InferenceServiceResults> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
-                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input));
-                case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input));
+                case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
                 default -> listener.onFailure(
                     new ElasticsearchStatusException(
                         TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -157,7 +155,7 @@ public void chunkedInfer(
             }
         }

-        private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
+        private SparseEmbeddingResults makeResults(List<String> input) {
             var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
             for (int i = 0; i < input.size(); i++) {
                 var tokens = new ArrayList<WeightedToken>();
@@ -169,18 +167,6 @@ private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
             return new SparseEmbeddingResults(embeddings);
         }

-        private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
-            var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
-            for (int i = 0; i < input.size(); i++) {
-                var values = new float[5];
-                for (int j = 0; j < 5; j++) {
-                    values[j] = random.nextFloat();
-                }
-                embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
-            }
-            return new TextEmbeddingFloatResults(embeddings);
-        }
-
         private List<ChunkedInference> makeChunkedResults(List<ChunkInferenceInput> inputs) {
             List<ChunkedInference> results = new ArrayList<>();
             for (ChunkInferenceInput chunkInferenceInput : inputs) {
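
For context, below is a minimal standalone sketch (plain Java, no Elasticsearch dependencies) of how the mock service behaves after this change: only sparse-embedding requests produce results, and text-embedding requests fall into the unsupported-task-type branch. The nested record types, the exception type, and the token names/weights are illustrative stand-ins, not the actual x-pack classes.

    import java.util.ArrayList;
    import java.util.EnumSet;
    import java.util.List;
    import java.util.Random;

    // Standalone sketch of the post-change behaviour; record types are
    // simplified stand-ins for the x-pack result classes.
    public class SparseOnlyMockServiceSketch {

        enum TaskType { ANY, SPARSE_EMBEDDING, TEXT_EMBEDDING }

        record WeightedToken(String token, float weight) {}
        record SparseEmbedding(List<WeightedToken> tokens) {}

        // After this commit the mock service advertises only SPARSE_EMBEDDING,
        // which is why the GET-services test above now expects 15 providers.
        static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);

        private final Random random = new Random(42);

        List<SparseEmbedding> infer(TaskType taskType, List<String> input) {
            return switch (taskType) {
                case ANY, SPARSE_EMBEDDING -> makeResults(input);
                // TEXT_EMBEDDING no longer has its own branch and is rejected here.
                default -> throw new IllegalArgumentException(
                    "Task type " + taskType + " is not supported by this service"
                );
            };
        }

        // Mirrors makeResults(...) from the diff; the token contents are illustrative.
        private List<SparseEmbedding> makeResults(List<String> input) {
            var embeddings = new ArrayList<SparseEmbedding>();
            for (int i = 0; i < input.size(); i++) {
                var tokens = new ArrayList<WeightedToken>();
                tokens.add(new WeightedToken("feature_" + i, random.nextFloat()));
                embeddings.add(new SparseEmbedding(tokens));
            }
            return embeddings;
        }

        public static void main(String[] args) {
            var service = new SparseOnlyMockServiceSketch();
            System.out.println(service.infer(TaskType.SPARSE_EMBEDDING, List.of("hello", "world")));
            // service.infer(TaskType.TEXT_EMBEDDING, List.of("hello")); // now throws
        }
    }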
