Skip to content

Commit 73fe062

Browse files
committed
inference generics
1 parent a7d799c commit 73fe062

File tree

9 files changed

+132
-159
lines changed

9 files changed

+132
-159
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ private static BytesReference toBytesReference(XContent xContent, byte[] value)
4141
return BytesReference.bytes(builder);
4242
}
4343

44-
public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) {}
44+
/** Chunk record holding a {@code byte[]} embedding together with its {@code matchedText} and {@code offset}; tagged as an {@code EmbeddingResults.EmbeddingChunk}. */
public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) implements EmbeddingResults.EmbeddingChunk {}
4545
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,7 @@ private static BytesReference toBytesReference(XContent xContent, float[] value)
4141
return BytesReference.bytes(b);
4242
}
4343

44-
public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {}
44+
/** Chunk record holding a {@code float[]} embedding together with its {@code matchedText} and {@code offset}; tagged as an {@code EmbeddingResults.EmbeddingChunk}. */
public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset)
45+
implements
46+
EmbeddingResults.EmbeddingChunk {}
4547
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,7 @@ private static BytesReference toBytesReference(XContent xContent, List<WeightedT
6363
return BytesReference.bytes(b);
6464
}
6565

66-
public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, TextOffset offset) {}
66+
/** Chunk record holding a list of weighted tokens together with its {@code matchedText} and {@code offset}; tagged as an {@code EmbeddingResults.EmbeddingChunk}. */
public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, TextOffset offset)
67+
implements
68+
EmbeddingResults.EmbeddingChunk {}
6769
}
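x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java

Lines changed: 26 additions & 0 deletions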
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.inference.ChunkedInference;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
13+
import java.util.List;
14+
15+
/**
 * Inference service results that expose their embeddings as a typed list.
 *
 * @param <C> the chunk type that each embedding in this result converts to
 * @param <E> the embedding type held by this result; every {@code E} can produce a {@code C}
 */
public interface EmbeddingResults<C extends EmbeddingResults.EmbeddingChunk, E extends EmbeddingResults.EmbeddingResult<C>>
16+
extends
17+
InferenceServiceResults {
18+
19+
/**
 * Marker type for the chunk form of an embedding; it declares no methods.
 */
interface EmbeddingChunk {}
20+
21+
/**
 * A single embedding that can be converted into a chunk of type {@code C}.
 *
 * @param <C> the chunk type this embedding converts to
 */
interface EmbeddingResult<C extends EmbeddingResults.EmbeddingChunk> {
22+
/**
 * Builds the chunk form of this embedding for the supplied text and offset.
 *
 * @param text   the text to attach to the chunk
 * @param offset the offset to attach to the chunk
 * @return a chunk of type {@code C}
 */
C toEmbeddingChunk(String text, ChunkedInference.TextOffset offset);
23+
}
24+
25+
/**
 * Returns the embeddings contained in this result.
 *
 * @return the list of embeddings of type {@code E}
 */
List<E> embeddings();
26+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.inference.ChunkedInference;
1617
import org.elasticsearch.xcontent.ToXContentObject;
1718
import org.elasticsearch.xcontent.XContentBuilder;
1819

1920
import java.io.IOException;
2021
import java.util.Arrays;
2122
import java.util.List;
2223

23-
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
24+
public record InferenceByteEmbedding(byte[] values)
25+
implements
26+
Writeable,
27+
ToXContentObject,
28+
EmbeddingInt,
29+
EmbeddingResults.EmbeddingResult<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> {
2430
public static final String EMBEDDING = "embedding";
2531

2632
public InferenceByteEmbedding(StreamInput in) throws IOException {
@@ -92,4 +98,9 @@ public boolean equals(Object o) {
9298
public int hashCode() {
9399
return Arrays.hashCode(values);
94100
}
101+
102+
/**
 * Wraps this embedding's {@code values} with the given text and offset in a new
 * {@code ByteEmbeddingChunk}. The {@code values} array is passed through as-is, not copied.
 *
 * @param text   the text stored as the chunk's matched text
 * @param offset the offset stored on the chunk
 * @return the newly created chunk
 */
@Override
103+
public ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk toEmbeddingChunk(String text, ChunkedInference.TextOffset offset) {
104+
return new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk(values, text, offset);
105+
}
95106
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
1515
import org.elasticsearch.inference.InferenceResults;
16-
import org.elasticsearch.inference.InferenceServiceResults;
1716
import org.elasticsearch.xcontent.ToXContent;
1817
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
1918

@@ -42,7 +41,10 @@
4241
* ]
4342
* }
4443
*/
45-
public record InferenceTextEmbeddingByteResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
44+
public record InferenceTextEmbeddingByteResults(List<InferenceByteEmbedding> embeddings)
45+
implements
46+
EmbeddingResults<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk, InferenceByteEmbedding>,
47+
TextEmbedding {
4648
public static final String NAME = "text_embedding_service_byte_results";
4749
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
4850

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingFloatResults.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.io.stream.Writeable;
1717
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
18+
import org.elasticsearch.inference.ChunkedInference;
1819
import org.elasticsearch.inference.InferenceResults;
19-
import org.elasticsearch.inference.InferenceServiceResults;
2020
import org.elasticsearch.inference.TaskType;
2121
import org.elasticsearch.rest.RestStatus;
2222
import org.elasticsearch.xcontent.ToXContent;
@@ -53,7 +53,7 @@
5353
*/
5454
public record InferenceTextEmbeddingFloatResults(List<InferenceFloatEmbedding> embeddings)
5555
implements
56-
InferenceServiceResults,
56+
EmbeddingResults<ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk, InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>,
5757
TextEmbedding {
5858
public static final String NAME = "text_embedding_service_results";
5959
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
@@ -151,7 +151,12 @@ public int hashCode() {
151151
return Objects.hash(embeddings);
152152
}
153153

154-
public record InferenceFloatEmbedding(float[] values) implements Writeable, ToXContentObject, EmbeddingInt {
154+
public record InferenceFloatEmbedding(float[] values)
155+
implements
156+
Writeable,
157+
ToXContentObject,
158+
EmbeddingInt,
159+
EmbeddingResults.EmbeddingResult<ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk> {
155160
public static final String EMBEDDING = "embedding";
156161

157162
public InferenceFloatEmbedding(StreamInput in) throws IOException {
@@ -220,5 +225,10 @@ public boolean equals(Object o) {
220225
public int hashCode() {
221226
return Arrays.hashCode(values);
222227
}
228+
229+
/**
 * Wraps this embedding's {@code values} with the given text and offset in a new
 * {@code FloatEmbeddingChunk}. The {@code values} array is passed through as-is, not copied.
 *
 * @param text   the text stored as the chunk's matched text
 * @param offset the offset stored on the chunk
 * @return the newly created chunk
 */
@Override
230+
public ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk toEmbeddingChunk(String text, ChunkedInference.TextOffset offset) {
231+
return new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(values, text, offset);
232+
}
223233
}
224234
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.common.io.stream.Writeable;
1515
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
16+
import org.elasticsearch.inference.ChunkedInference;
1617
import org.elasticsearch.inference.InferenceResults;
17-
import org.elasticsearch.inference.InferenceServiceResults;
1818
import org.elasticsearch.inference.TaskType;
1919
import org.elasticsearch.rest.RestStatus;
2020
import org.elasticsearch.xcontent.ToXContent;
@@ -33,7 +33,9 @@
3333

3434
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
3535

36-
public record SparseEmbeddingResults(List<Embedding> embeddings) implements InferenceServiceResults {
36+
public record SparseEmbeddingResults(List<Embedding> embeddings)
37+
implements
38+
EmbeddingResults<ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk, SparseEmbeddingResults.Embedding> {
3739

3840
public static final String NAME = "sparse_embedding_results";
3941
public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
@@ -114,7 +116,11 @@ public List<? extends InferenceResults> transformToLegacyFormat() {
114116
.toList();
115117
}
116118

117-
public record Embedding(List<WeightedToken> tokens, boolean isTruncated) implements Writeable, ToXContentObject {
119+
public record Embedding(List<WeightedToken> tokens, boolean isTruncated)
120+
implements
121+
Writeable,
122+
ToXContentObject,
123+
EmbeddingResult<ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk> {
118124

119125
public static final String EMBEDDING = "embedding";
120126
public static final String IS_TRUNCATED = "is_truncated";
@@ -163,5 +169,10 @@ public Map<String, Object> asMap() {
163169
public String toString() {
164170
return Strings.toString(this);
165171
}
172+
173+
/**
 * Wraps this embedding's {@code tokens} with the given text and offset in a new
 * {@code SparseEmbeddingChunk}. The {@code isTruncated} flag is not carried over to the chunk.
 *
 * @param text   the text stored as the chunk's matched text
 * @param offset the offset stored on the chunk
 * @return the newly created chunk
 */
@Override
174+
public ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk toEmbeddingChunk(String text, ChunkedInference.TextOffset offset) {
175+
return new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(tokens, text, offset);
176+
}
166177
}
167178
}

0 commit comments

Comments
 (0)