|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.core.inference.action; |
9 | 9 |
|
10 | | -import org.elasticsearch.ElasticsearchStatusException; |
11 | 10 | import org.elasticsearch.TransportVersion; |
12 | | -import org.elasticsearch.TransportVersions; |
13 | 11 | import org.elasticsearch.action.ActionRequestValidationException; |
14 | 12 | import org.elasticsearch.action.ActionResponse; |
15 | 13 | import org.elasticsearch.action.ActionType; |
|
19 | 17 | import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; |
20 | 18 | import org.elasticsearch.common.xcontent.ChunkedToXContentObject; |
21 | 19 | import org.elasticsearch.core.TimeValue; |
22 | | -import org.elasticsearch.inference.InferenceResults; |
23 | 20 | import org.elasticsearch.inference.InferenceServiceResults; |
24 | 21 | import org.elasticsearch.inference.InputType; |
25 | 22 | import org.elasticsearch.inference.TaskType; |
26 | | -import org.elasticsearch.rest.RestStatus; |
27 | 23 | import org.elasticsearch.xcontent.ObjectParser; |
28 | 24 | import org.elasticsearch.xcontent.ParseField; |
29 | 25 | import org.elasticsearch.xcontent.ToXContent; |
30 | 26 | import org.elasticsearch.xcontent.XContentParser; |
31 | 27 | import org.elasticsearch.xpack.core.inference.InferenceContext; |
32 | | -import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; |
33 | | -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; |
34 | | -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; |
35 | 28 |
|
36 | 29 | import java.io.IOException; |
37 | | -import java.util.ArrayList; |
38 | | -import java.util.EnumSet; |
39 | 30 | import java.util.Iterator; |
40 | 31 | import java.util.List; |
41 | 32 | import java.util.Map; |
@@ -79,12 +70,6 @@ public static Builder builder(String inferenceEntityId, TaskType taskType) { |
79 | 70 | PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT); |
80 | 71 | } |
81 | 72 |
|
82 | | - private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH); |
83 | | - private static final EnumSet<InputType> validEnumsBeforeClassificationClusteringAdded = EnumSet.range( |
84 | | - InputType.INGEST, |
85 | | - InputType.UNSPECIFIED |
86 | | - ); |
87 | | - |
88 | 73 | public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser) |
89 | 74 | throws IOException { |
90 | 75 | Request.Builder builder = PARSER.apply(parser, null); |
@@ -164,25 +149,11 @@ public Request(StreamInput in) throws IOException { |
164 | 149 | super(in); |
165 | 150 | this.taskType = TaskType.fromStream(in); |
166 | 151 | this.inferenceEntityId = in.readString(); |
167 | | - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
168 | | - this.input = in.readStringCollectionAsList(); |
169 | | - } else { |
170 | | - this.input = List.of(in.readString()); |
171 | | - } |
| 152 | + this.input = in.readStringCollectionAsList(); |
172 | 153 | this.taskSettings = in.readGenericMap(); |
173 | | - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { |
174 | | - this.inputType = in.readEnum(InputType.class); |
175 | | - } else { |
176 | | - this.inputType = InputType.UNSPECIFIED; |
177 | | - } |
178 | | - |
179 | | - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { |
180 | | - this.query = in.readOptionalString(); |
181 | | - this.inferenceTimeout = in.readTimeValue(); |
182 | | - } else { |
183 | | - this.query = null; |
184 | | - this.inferenceTimeout = DEFAULT_TIMEOUT; |
185 | | - } |
| 154 | + this.inputType = in.readEnum(InputType.class); |
| 155 | + this.query = in.readOptionalString(); |
| 156 | + this.inferenceTimeout = in.readTimeValue(); |
186 | 157 |
|
187 | 158 | if (in.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) { |
188 | 159 | this.returnDocuments = in.readOptionalBoolean(); |
@@ -298,41 +269,18 @@ public void writeTo(StreamOutput out) throws IOException { |
298 | 269 | super.writeTo(out); |
299 | 270 | taskType.writeTo(out); |
300 | 271 | out.writeString(inferenceEntityId); |
301 | | - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
302 | | - out.writeStringCollection(input); |
303 | | - } else { |
304 | | - out.writeString(input.get(0)); |
305 | | - } |
| 272 | + out.writeStringCollection(input); |
306 | 273 | out.writeGenericMap(taskSettings); |
307 | | - |
308 | | - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { |
309 | | - out.writeEnum(getInputTypeToWrite(inputType, out.getTransportVersion())); |
310 | | - } |
311 | | - |
312 | | - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { |
313 | | - out.writeOptionalString(query); |
314 | | - out.writeTimeValue(inferenceTimeout); |
315 | | - } |
| 274 | + out.writeEnum(inputType); |
| 275 | + out.writeOptionalString(query); |
| 276 | + out.writeTimeValue(inferenceTimeout); |
316 | 277 |
|
317 | 278 | if (out.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) { |
318 | 279 | out.writeOptionalBoolean(returnDocuments); |
319 | 280 | out.writeOptionalInt(topN); |
320 | 281 | } |
321 | 282 | } |
322 | 283 |
|
323 | | - // default for easier testing |
324 | | - static InputType getInputTypeToWrite(InputType inputType, TransportVersion version) { |
325 | | - if (version.before(TransportVersions.V_8_13_0)) { |
326 | | - if (validEnumsBeforeUnspecifiedAdded.contains(inputType) == false) { |
327 | | - return InputType.INGEST; |
328 | | - } else if (validEnumsBeforeClassificationClusteringAdded.contains(inputType) == false) { |
329 | | - return InputType.UNSPECIFIED; |
330 | | - } |
331 | | - } |
332 | | - |
333 | | - return inputType; |
334 | | - } |
335 | | - |
336 | 284 | @Override |
337 | 285 | public boolean equals(Object o) { |
338 | 286 | if (this == o) return true; |
@@ -509,65 +457,12 @@ public Response(InferenceServiceResults results, Flow.Publisher<InferenceService |
509 | 457 | } |
510 | 458 |
|
511 | 459 | public Response(StreamInput in) throws IOException { |
512 | | - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { |
513 | | - results = in.readNamedWriteable(InferenceServiceResults.class); |
514 | | - } else { |
515 | | - // It should only be InferenceResults aka TextEmbeddingResults from ml plugin for |
516 | | - // hugging face elser and elser |
517 | | - results = transformToServiceResults(List.of(in.readNamedWriteable(InferenceResults.class))); |
518 | | - } |
| 460 | + this.results = in.readNamedWriteable(InferenceServiceResults.class); |
519 | 461 | // streaming isn't supported via Writeable yet |
520 | 462 | this.isStreaming = false; |
521 | 463 | this.publisher = null; |
522 | 464 | } |
523 | 465 |
|
524 | | - @SuppressWarnings("deprecation") |
525 | | - public static InferenceServiceResults transformToServiceResults(List<? extends InferenceResults> parsedResults) { |
526 | | - if (parsedResults.isEmpty()) { |
527 | | - throw new ElasticsearchStatusException( |
528 | | - "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service", |
529 | | - RestStatus.INTERNAL_SERVER_ERROR |
530 | | - ); |
531 | | - } |
532 | | - |
533 | | - if (parsedResults.get(0) instanceof LegacyTextEmbeddingResults openaiResults) { |
534 | | - if (parsedResults.size() > 1) { |
535 | | - throw new ElasticsearchStatusException( |
536 | | - "Failed to transform results to response format, malformed text embedding result," |
537 | | - + " please remove and re-add the service", |
538 | | - RestStatus.INTERNAL_SERVER_ERROR |
539 | | - ); |
540 | | - } |
541 | | - |
542 | | - return openaiResults.transformToTextEmbeddingResults(); |
543 | | - } else if (parsedResults.get(0) instanceof TextExpansionResults) { |
544 | | - return transformToSparseEmbeddingResult(parsedResults); |
545 | | - } else { |
546 | | - throw new ElasticsearchStatusException( |
547 | | - "Failed to transform results to response format, unknown embedding type received," |
548 | | - + " please remove and re-add the service", |
549 | | - RestStatus.INTERNAL_SERVER_ERROR |
550 | | - ); |
551 | | - } |
552 | | - } |
553 | | - |
554 | | - private static SparseEmbeddingResults transformToSparseEmbeddingResult(List<? extends InferenceResults> parsedResults) { |
555 | | - List<TextExpansionResults> textExpansionResults = new ArrayList<>(parsedResults.size()); |
556 | | - |
557 | | - for (InferenceResults result : parsedResults) { |
558 | | - if (result instanceof TextExpansionResults textExpansion) { |
559 | | - textExpansionResults.add(textExpansion); |
560 | | - } else { |
561 | | - throw new ElasticsearchStatusException( |
562 | | - "Failed to transform results to response format, please remove and re-add the service", |
563 | | - RestStatus.INTERNAL_SERVER_ERROR |
564 | | - ); |
565 | | - } |
566 | | - } |
567 | | - |
568 | | - return SparseEmbeddingResults.of(textExpansionResults); |
569 | | - } |
570 | | - |
571 | 466 | public InferenceServiceResults getResults() { |
572 | 467 | return results; |
573 | 468 | } |
|
0 commit comments