|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.inference.external.response.elastic; |
9 | 9 |
|
10 | | -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; |
11 | 10 | import org.elasticsearch.common.xcontent.XContentParserUtils; |
| 11 | +import org.elasticsearch.xcontent.ConstructingObjectParser; |
| 12 | +import org.elasticsearch.xcontent.ParseField; |
12 | 13 | import org.elasticsearch.xcontent.XContentFactory; |
13 | | -import org.elasticsearch.xcontent.XContentParser; |
14 | 14 | import org.elasticsearch.xcontent.XContentParserConfiguration; |
15 | 15 | import org.elasticsearch.xcontent.XContentType; |
16 | 16 | import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; |
17 | 17 | import org.elasticsearch.xpack.inference.external.http.HttpResult; |
18 | 18 | import org.elasticsearch.xpack.inference.external.request.Request; |
19 | 19 |
|
20 | 20 | import java.io.IOException; |
21 | | -import java.util.Collections; |
22 | 21 | import java.util.List; |
23 | 22 |
|
24 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; |
25 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; |
26 | | -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; |
27 | | -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; |
| 23 | +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; |
28 | 24 |
|
29 | 25 | public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { |
30 | 26 |
|
31 | | - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = |
32 | | - "Failed to find required field [%s] in Elastic Inference Service dense text embeddings response"; |
33 | | - |
34 | 27 | /** |
35 | 28 | * Parses the Elastic Inference Service Dense Text Embeddings response. |
36 | 29 | * |
@@ -64,43 +57,51 @@ public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { |
64 | 57 | * </code> |
65 | 58 | * </pre> |
66 | 59 | */ |
67 | | - |
68 | 60 | public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { |
69 | | - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); |
70 | | - |
71 | | - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { |
72 | | - moveToFirstToken(jsonParser); |
73 | | - |
74 | | - XContentParser.Token token = jsonParser.currentToken(); |
75 | | - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); |
| 61 | + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { |
| 62 | + return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); |
| 63 | + } |
| 64 | + } |
76 | 65 |
|
77 | | - positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); |
| 66 | + public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) { |
| 67 | + @SuppressWarnings("unchecked") |
| 68 | + public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>( |
| 69 | + EmbeddingFloatResult.class.getSimpleName(), |
| 70 | + true, |
| 71 | + args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0]) |
| 72 | + ); |
78 | 73 |
|
79 | | - List<TextEmbeddingFloatResults.Embedding> parsedEmbeddings = parseList( |
80 | | - jsonParser, |
81 | | - (parser, index) -> ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.parseTextEmbeddingObject(parser) |
| 74 | + static { |
| 75 | + // Custom field declaration to handle array of arrays format |
| 76 | + PARSER.declareField( |
| 77 | + constructorArg(), |
| 78 | + (parser, context) -> { |
| 79 | + return XContentParserUtils.parseList(parser, (p, index) -> { |
| 80 | + List<Float> embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue()); |
| 81 | + return EmbeddingFloatResultEntry.fromFloatArray(embedding); |
| 82 | + }); |
| 83 | + }, |
| 84 | + new ParseField("data"), |
| 85 | + org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY |
82 | 86 | ); |
83 | | - |
84 | | - if (parsedEmbeddings.isEmpty()) { |
85 | | - return new TextEmbeddingFloatResults(Collections.emptyList()); |
86 | | - } |
87 | | - |
88 | | - return new TextEmbeddingFloatResults(parsedEmbeddings); |
89 | 87 | } |
90 | | - } |
91 | 88 |
|
92 | | - private static TextEmbeddingFloatResults.Embedding parseTextEmbeddingObject(XContentParser parser) throws IOException { |
93 | | - List<Float> embeddingValueList = parseList( |
94 | | - parser, |
95 | | - ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::parseEmbeddingFloatValueList |
96 | | - ); |
97 | | - return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); |
| 89 | + public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { |
| 90 | + return new TextEmbeddingFloatResults( |
| 91 | + embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() |
| 92 | + ); |
| 93 | + } |
98 | 94 | } |
99 | 95 |
|
100 | | - private static float parseEmbeddingFloatValueList(XContentParser parser) throws IOException { |
101 | | - XContentParser.Token token = parser.currentToken(); |
102 | | - XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); |
103 | | - return parser.floatValue(); |
| 96 | + /** |
| 97 | + * Represents a single embedding entry in the response. |
| 98 | + * For the Elastic Inference Service, each entry is just an array of floats (no wrapper object). |
| 99 | + * This is a simpler wrapper that just holds the float array. |
| 100 | + */ |
| 101 | + public record EmbeddingFloatResultEntry(List<Float> embedding) { |
| 102 | + public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) { |
| 103 | + return new EmbeddingFloatResultEntry(floats); |
| 104 | + } |
104 | 105 | } |
105 | 106 |
|
106 | 107 | private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {} |
|
0 commit comments