Skip to content

Commit 4413a76

Browse files
Fixing IT
1 parent eb21ca8 commit 4413a76

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xcontent.XContentBuilder;
3535
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
3636
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
37+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
3738
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
3839

3940
import java.io.IOException;
@@ -61,7 +62,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett
6162
public static class TestInferenceService extends AbstractTestInferenceService {
6263
public static final String NAME = "test_service";
6364

64-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);
65+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING);
6566

6667
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
6768

@@ -110,7 +111,8 @@ public void infer(
110111
ActionListener<InferenceServiceResults> listener
111112
) {
112113
switch (model.getConfigurations().getTaskType()) {
113-
case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input));
114+
case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input));
115+
case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input));
114116
default -> listener.onFailure(
115117
new ElasticsearchStatusException(
116118
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -151,7 +153,7 @@ public void chunkedInfer(
151153
}
152154
}
153155

154-
private SparseEmbeddingResults makeResults(List<String> input) {
156+
private SparseEmbeddingResults makeSparseEmbeddingResults(List<String> input) {
155157
var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
156158
for (int i = 0; i < input.size(); i++) {
157159
var tokens = new ArrayList<WeightedToken>();
@@ -163,6 +165,18 @@ private SparseEmbeddingResults makeResults(List<String> input) {
163165
return new SparseEmbeddingResults(embeddings);
164166
}
165167

168+
private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
169+
var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
170+
for (int i = 0; i < input.size(); i++) {
171+
var values = new float[5];
172+
for (int j = 0; j < 5; j++) {
173+
values[j] = random.nextFloat();
174+
}
175+
embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
176+
}
177+
return new TextEmbeddingFloatResults(embeddings);
178+
}
179+
166180
private List<ChunkedInference> makeChunkedResults(List<String> input) {
167181
List<ChunkedInference> results = new ArrayList<>();
168182
for (int i = 0; i < input.size(); i++) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
import org.elasticsearch.xcontent.XContentBuilder;
3535
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3636
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
37+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
3738

3839
import java.io.IOException;
40+
import java.util.ArrayList;
3941
import java.util.EnumSet;
4042
import java.util.HashMap;
4143
import java.util.Iterator;
@@ -57,7 +59,11 @@ public static class TestInferenceService extends AbstractTestInferenceService {
5759
private static final String NAME = "streaming_completion_test_service";
5860
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
5961

60-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
62+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
63+
TaskType.COMPLETION,
64+
TaskType.CHAT_COMPLETION,
65+
TaskType.SPARSE_EMBEDDING
66+
);
6167

6268
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
6369

@@ -111,7 +117,19 @@ public void infer(
111117
ActionListener<InferenceServiceResults> listener
112118
) {
113119
switch (model.getConfigurations().getTaskType()) {
114-
case COMPLETION -> listener.onResponse(makeResults(input));
120+
case COMPLETION -> listener.onResponse(makeChatCompletionResults(input));
121+
case SPARSE_EMBEDDING -> {
122+
if (stream) {
123+
listener.onFailure(
124+
new ElasticsearchStatusException(
125+
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
126+
RestStatus.BAD_REQUEST
127+
)
128+
);
129+
} else {
130+
listener.onResponse(makeTextEmbeddingResults(input));
131+
}
132+
}
115133
default -> listener.onFailure(
116134
new ElasticsearchStatusException(
117135
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@@ -139,7 +157,7 @@ public void unifiedCompletionInfer(
139157
}
140158
}
141159

142-
private StreamingChatCompletionResults makeResults(List<String> input) {
160+
private StreamingChatCompletionResults makeChatCompletionResults(List<String> input) {
143161
var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator();
144162
return new StreamingChatCompletionResults(subscriber -> {
145163
subscriber.onSubscribe(new Flow.Subscription() {
@@ -158,6 +176,18 @@ public void cancel() {}
158176
});
159177
}
160178

179+
private TextEmbeddingFloatResults makeTextEmbeddingResults(List<String> input) {
180+
var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
181+
for (int i = 0; i < input.size(); i++) {
182+
var values = new float[5];
183+
for (int j = 0; j < 5; j++) {
184+
values[j] = random.nextFloat();
185+
}
186+
embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
187+
}
188+
return new TextEmbeddingFloatResults(embeddings);
189+
}
190+
161191
private InferenceServiceResults.Result completionChunk(String delta) {
162192
return new InferenceServiceResults.Result() {
163193
@Override

0 commit comments

Comments
 (0)