|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.core.inference.action;
|
9 | 9 |
|
10 |
| -import org.elasticsearch.ElasticsearchStatusException; |
11 | 10 | import org.elasticsearch.TransportVersion;
|
12 |
| -import org.elasticsearch.TransportVersions; |
13 | 11 | import org.elasticsearch.action.ActionRequestValidationException;
|
14 | 12 | import org.elasticsearch.action.ActionResponse;
|
15 | 13 | import org.elasticsearch.action.ActionType;
|
|
19 | 17 | import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
20 | 18 | import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
|
21 | 19 | import org.elasticsearch.core.TimeValue;
|
22 |
| -import org.elasticsearch.inference.InferenceResults; |
23 | 20 | import org.elasticsearch.inference.InferenceServiceResults;
|
24 | 21 | import org.elasticsearch.inference.InputType;
|
25 | 22 | import org.elasticsearch.inference.TaskType;
|
26 |
| -import org.elasticsearch.rest.RestStatus; |
27 | 23 | import org.elasticsearch.xcontent.ObjectParser;
|
28 | 24 | import org.elasticsearch.xcontent.ParseField;
|
29 | 25 | import org.elasticsearch.xcontent.ToXContent;
|
30 | 26 | import org.elasticsearch.xcontent.XContentParser;
|
31 | 27 | import org.elasticsearch.xpack.core.inference.InferenceContext;
|
32 |
| -import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; |
33 |
| -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
34 |
| -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; |
35 | 28 |
|
36 | 29 | import java.io.IOException;
|
37 |
| -import java.util.ArrayList; |
38 |
| -import java.util.EnumSet; |
39 | 30 | import java.util.Iterator;
|
40 | 31 | import java.util.List;
|
41 | 32 | import java.util.Map;
|
@@ -79,12 +70,6 @@ public static Builder builder(String inferenceEntityId, TaskType taskType) {
|
79 | 70 | PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
|
80 | 71 | }
|
81 | 72 |
|
82 |
| - private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH); |
83 |
| - private static final EnumSet<InputType> validEnumsBeforeClassificationClusteringAdded = EnumSet.range( |
84 |
| - InputType.INGEST, |
85 |
| - InputType.UNSPECIFIED |
86 |
| - ); |
87 |
| - |
88 | 73 | public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
|
89 | 74 | throws IOException {
|
90 | 75 | Request.Builder builder = PARSER.apply(parser, null);
|
@@ -164,25 +149,11 @@ public Request(StreamInput in) throws IOException {
|
164 | 149 | super(in);
|
165 | 150 | this.taskType = TaskType.fromStream(in);
|
166 | 151 | this.inferenceEntityId = in.readString();
|
167 |
| - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
168 |
| - this.input = in.readStringCollectionAsList(); |
169 |
| - } else { |
170 |
| - this.input = List.of(in.readString()); |
171 |
| - } |
| 152 | + this.input = in.readStringCollectionAsList(); |
172 | 153 | this.taskSettings = in.readGenericMap();
|
173 |
| - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { |
174 |
| - this.inputType = in.readEnum(InputType.class); |
175 |
| - } else { |
176 |
| - this.inputType = InputType.UNSPECIFIED; |
177 |
| - } |
178 |
| - |
179 |
| - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { |
180 |
| - this.query = in.readOptionalString(); |
181 |
| - this.inferenceTimeout = in.readTimeValue(); |
182 |
| - } else { |
183 |
| - this.query = null; |
184 |
| - this.inferenceTimeout = DEFAULT_TIMEOUT; |
185 |
| - } |
| 154 | + this.inputType = in.readEnum(InputType.class); |
| 155 | + this.query = in.readOptionalString(); |
| 156 | + this.inferenceTimeout = in.readTimeValue(); |
186 | 157 |
|
187 | 158 | if (in.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
|
188 | 159 | this.returnDocuments = in.readOptionalBoolean();
|
@@ -298,41 +269,18 @@ public void writeTo(StreamOutput out) throws IOException {
|
298 | 269 | super.writeTo(out);
|
299 | 270 | taskType.writeTo(out);
|
300 | 271 | out.writeString(inferenceEntityId);
|
301 |
| - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
302 |
| - out.writeStringCollection(input); |
303 |
| - } else { |
304 |
| - out.writeString(input.get(0)); |
305 |
| - } |
| 272 | + out.writeStringCollection(input); |
306 | 273 | out.writeGenericMap(taskSettings);
|
307 |
| - |
308 |
| - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { |
309 |
| - out.writeEnum(getInputTypeToWrite(inputType, out.getTransportVersion())); |
310 |
| - } |
311 |
| - |
312 |
| - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { |
313 |
| - out.writeOptionalString(query); |
314 |
| - out.writeTimeValue(inferenceTimeout); |
315 |
| - } |
| 274 | + out.writeEnum(inputType); |
| 275 | + out.writeOptionalString(query); |
| 276 | + out.writeTimeValue(inferenceTimeout); |
316 | 277 |
|
317 | 278 | if (out.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
|
318 | 279 | out.writeOptionalBoolean(returnDocuments);
|
319 | 280 | out.writeOptionalInt(topN);
|
320 | 281 | }
|
321 | 282 | }
|
322 | 283 |
|
323 |
| - // default for easier testing |
324 |
| - static InputType getInputTypeToWrite(InputType inputType, TransportVersion version) { |
325 |
| - if (version.before(TransportVersions.V_8_13_0)) { |
326 |
| - if (validEnumsBeforeUnspecifiedAdded.contains(inputType) == false) { |
327 |
| - return InputType.INGEST; |
328 |
| - } else if (validEnumsBeforeClassificationClusteringAdded.contains(inputType) == false) { |
329 |
| - return InputType.UNSPECIFIED; |
330 |
| - } |
331 |
| - } |
332 |
| - |
333 |
| - return inputType; |
334 |
| - } |
335 |
| - |
336 | 284 | @Override
|
337 | 285 | public boolean equals(Object o) {
|
338 | 286 | if (this == o) return true;
|
@@ -509,65 +457,12 @@ public Response(InferenceServiceResults results, Flow.Publisher<InferenceService
|
509 | 457 | }
|
510 | 458 |
|
511 | 459 | public Response(StreamInput in) throws IOException {
|
512 |
| - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
513 |
| - results = in.readNamedWriteable(InferenceServiceResults.class); |
514 |
| - } else { |
515 |
| - // It should only be InferenceResults aka TextEmbeddingResults from ml plugin for |
516 |
| - // hugging face elser and elser |
517 |
| - results = transformToServiceResults(List.of(in.readNamedWriteable(InferenceResults.class))); |
518 |
| - } |
| 460 | + this.results = in.readNamedWriteable(InferenceServiceResults.class); |
519 | 461 | // streaming isn't supported via Writeable yet
|
520 | 462 | this.isStreaming = false;
|
521 | 463 | this.publisher = null;
|
522 | 464 | }
|
523 | 465 |
|
524 |
| - @SuppressWarnings("deprecation") |
525 |
| - public static InferenceServiceResults transformToServiceResults(List<? extends InferenceResults> parsedResults) { |
526 |
| - if (parsedResults.isEmpty()) { |
527 |
| - throw new ElasticsearchStatusException( |
528 |
| - "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service", |
529 |
| - RestStatus.INTERNAL_SERVER_ERROR |
530 |
| - ); |
531 |
| - } |
532 |
| - |
533 |
| - if (parsedResults.get(0) instanceof LegacyTextEmbeddingResults openaiResults) { |
534 |
| - if (parsedResults.size() > 1) { |
535 |
| - throw new ElasticsearchStatusException( |
536 |
| - "Failed to transform results to response format, malformed text embedding result," |
537 |
| - + " please remove and re-add the service", |
538 |
| - RestStatus.INTERNAL_SERVER_ERROR |
539 |
| - ); |
540 |
| - } |
541 |
| - |
542 |
| - return openaiResults.transformToTextEmbeddingResults(); |
543 |
| - } else if (parsedResults.get(0) instanceof TextExpansionResults) { |
544 |
| - return transformToSparseEmbeddingResult(parsedResults); |
545 |
| - } else { |
546 |
| - throw new ElasticsearchStatusException( |
547 |
| - "Failed to transform results to response format, unknown embedding type received," |
548 |
| - + " please remove and re-add the service", |
549 |
| - RestStatus.INTERNAL_SERVER_ERROR |
550 |
| - ); |
551 |
| - } |
552 |
| - } |
553 |
| - |
554 |
| - private static SparseEmbeddingResults transformToSparseEmbeddingResult(List<? extends InferenceResults> parsedResults) { |
555 |
| - List<TextExpansionResults> textExpansionResults = new ArrayList<>(parsedResults.size()); |
556 |
| - |
557 |
| - for (InferenceResults result : parsedResults) { |
558 |
| - if (result instanceof TextExpansionResults textExpansion) { |
559 |
| - textExpansionResults.add(textExpansion); |
560 |
| - } else { |
561 |
| - throw new ElasticsearchStatusException( |
562 |
| - "Failed to transform results to response format, please remove and re-add the service", |
563 |
| - RestStatus.INTERNAL_SERVER_ERROR |
564 |
| - ); |
565 |
| - } |
566 |
| - } |
567 |
| - |
568 |
| - return SparseEmbeddingResults.of(textExpansionResults); |
569 |
| - } |
570 |
| - |
571 | 466 | public InferenceServiceResults getResults() {
|
572 | 467 | return results;
|
573 | 468 | }
|
|
0 commit comments