3434import org .elasticsearch .xcontent .XContentBuilder ;
3535import org .elasticsearch .xpack .core .inference .results .StreamingChatCompletionResults ;
3636import org .elasticsearch .xpack .core .inference .results .StreamingUnifiedChatCompletionResults ;
37+ import org .elasticsearch .xpack .core .inference .results .TextEmbeddingFloatResults ;
3738
3839import java .io .IOException ;
40+ import java .util .ArrayList ;
3941import java .util .EnumSet ;
4042import java .util .HashMap ;
4143import java .util .Iterator ;
@@ -57,7 +59,11 @@ public static class TestInferenceService extends AbstractTestInferenceService {
5759 private static final String NAME = "streaming_completion_test_service" ;
5860 private static final Set <TaskType > supportedStreamingTasks = Set .of (TaskType .COMPLETION , TaskType .CHAT_COMPLETION );
5961
60- private static final EnumSet <TaskType > supportedTaskTypes = EnumSet .of (TaskType .COMPLETION , TaskType .CHAT_COMPLETION );
62+ private static final EnumSet <TaskType > supportedTaskTypes = EnumSet .of (
63+ TaskType .COMPLETION ,
64+ TaskType .CHAT_COMPLETION ,
65+ TaskType .SPARSE_EMBEDDING
66+ );
6167
6268 public TestInferenceService (InferenceServiceExtension .InferenceServiceFactoryContext context ) {}
6369
@@ -111,7 +117,19 @@ public void infer(
111117 ActionListener <InferenceServiceResults > listener
112118 ) {
113119 switch (model .getConfigurations ().getTaskType ()) {
114- case COMPLETION -> listener .onResponse (makeResults (input ));
120+ case COMPLETION -> listener .onResponse (makeChatCompletionResults (input ));
121+ case SPARSE_EMBEDDING -> {
122+ if (stream ) {
123+ listener .onFailure (
124+ new ElasticsearchStatusException (
125+ TaskType .unsupportedTaskTypeErrorMsg (model .getConfigurations ().getTaskType (), name ()),
126+ RestStatus .BAD_REQUEST
127+ )
128+ );
129+ } else {
130+ listener .onResponse (makeTextEmbeddingResults (input ));
131+ }
132+ }
115133 default -> listener .onFailure (
116134 new ElasticsearchStatusException (
117135 TaskType .unsupportedTaskTypeErrorMsg (model .getConfigurations ().getTaskType (), name ()),
@@ -139,7 +157,7 @@ public void unifiedCompletionInfer(
139157 }
140158 }
141159
142- private StreamingChatCompletionResults makeResults (List <String > input ) {
160+ private StreamingChatCompletionResults makeChatCompletionResults (List <String > input ) {
143161 var responseIter = input .stream ().map (s -> s .toUpperCase (Locale .ROOT )).iterator ();
144162 return new StreamingChatCompletionResults (subscriber -> {
145163 subscriber .onSubscribe (new Flow .Subscription () {
@@ -158,6 +176,18 @@ public void cancel() {}
158176 });
159177 }
160178
179+ private TextEmbeddingFloatResults makeTextEmbeddingResults (List <String > input ) {
180+ var embeddings = new ArrayList <TextEmbeddingFloatResults .Embedding >();
181+ for (int i = 0 ; i < input .size (); i ++) {
182+ var values = new float [5 ];
183+ for (int j = 0 ; j < 5 ; j ++) {
184+ values [j ] = random .nextFloat ();
185+ }
186+ embeddings .add (new TextEmbeddingFloatResults .Embedding (values ));
187+ }
188+ return new TextEmbeddingFloatResults (embeddings );
189+ }
190+
161191 private InferenceServiceResults .Result completionChunk (String delta ) {
162192 return new InferenceServiceResults .Result () {
163193 @ Override
0 commit comments