Skip to content

Commit ebbcd12

Browse files
committed
remove Chunk generic
1 parent b31158e commit ebbcd12

File tree

9 files changed

+17
-27
lines changed

9 files changed

+17
-27
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
* A call to the inference service may contain multiple input texts, so this results may
2020
* contain multiple results.
2121
*/
22-
public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C, E>>
23-
extends
24-
InferenceServiceResults {
22+
public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {
2523

2624
/**
2725
* A resulting embedding together with the offset into the input text.
@@ -35,11 +33,11 @@ interface Chunk {
3533
/**
3634
* A resulting embedding for one of the input texts to the inference service.
3735
*/
38-
interface Embedding<C extends Chunk, E extends EmbeddingResults.Embedding<C, E>> {
36+
interface Embedding<E extends Embedding<E>> {
3937
/**
4038
* Combines the resulting embedding with the offset into the input text into a chunk.
4139
*/
42-
C toChunk(ChunkedInference.TextOffset offset);
40+
Chunk toChunk(ChunkedInference.TextOffset offset);
4341

4442
/**
4543
* Merges the existing embedding and provided embedding into a new embedding.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737

3838
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
3939

40-
public record SparseEmbeddingResults(List<Embedding> embeddings)
41-
implements
42-
EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> {
40+
public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> {
4341

4442
public static final String NAME = "sparse_embedding_results";
4543
public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
@@ -126,7 +124,7 @@ public record Embedding(List<WeightedToken> tokens, boolean isTruncated)
126124
implements
127125
Writeable,
128126
ToXContentObject,
129-
EmbeddingResults.Embedding<Chunk, Embedding> {
127+
EmbeddingResults.Embedding<Embedding> {
130128

131129
public static final String EMBEDDING = "embedding";
132130
public static final String IS_TRUNCATED = "is_truncated";

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
*/
4343
public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
4444
implements
45-
TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
45+
TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
4646
public static final String NAME = "text_embedding_service_bit_results";
4747
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
4848

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@
4848
* ]
4949
* }
5050
*/
51-
public record TextEmbeddingByteResults(List<Embedding> embeddings)
52-
implements
53-
TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
51+
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
5452
public static final String NAME = "text_embedding_service_byte_results";
5553
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
5654

@@ -124,7 +122,7 @@ public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMerged
124122
implements
125123
Writeable,
126124
ToXContentObject,
127-
EmbeddingResults.Embedding<Chunk, Embedding> {
125+
EmbeddingResults.Embedding<Embedding> {
128126

129127
public static final String EMBEDDING = "embedding";
130128

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@
5353
* ]
5454
* }
5555
*/
56-
public record TextEmbeddingFloatResults(List<Embedding> embeddings)
57-
implements
58-
TextEmbeddingResults<TextEmbeddingFloatResults.Chunk, TextEmbeddingFloatResults.Embedding> {
56+
public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
5957
public static final String NAME = "text_embedding_service_results";
6058
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
6159

@@ -161,7 +159,7 @@ public record Embedding(float[] values, int numberOfMergedEmbeddings)
161159
implements
162160
Writeable,
163161
ToXContentObject,
164-
EmbeddingResults.Embedding<Chunk, Embedding> {
162+
EmbeddingResults.Embedding<Embedding> {
165163
public static final String EMBEDDING = "embedding";
166164

167165
public Embedding(float[] values) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.results;
99

10-
public interface TextEmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C, E>>
11-
extends
12-
EmbeddingResults<C, E> {
10+
public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
1311

1412
/**
1513
* Returns the first text embedding entry in the result list's array size.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
* processing and map the results back to the original element
3737
* in the input list.
3838
*/
39-
public class EmbeddingRequestChunker<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C, E>> {
39+
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
4040

4141
// Visible for testing
4242
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<String> inputs) {
@@ -150,12 +150,12 @@ private class DebatchingListener implements ActionListener<InferenceServiceResul
150150

151151
@Override
152152
public void onResponse(InferenceServiceResults inferenceServiceResults) {
153-
if (inferenceServiceResults instanceof EmbeddingResults<?, ?> == false) {
153+
if (inferenceServiceResults instanceof EmbeddingResults<?> == false) {
154154
onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName()));
155155
return;
156156
}
157157
@SuppressWarnings("unchecked")
158-
EmbeddingResults<C, E> embeddingResults = (EmbeddingResults<C, E>) inferenceServiceResults;
158+
EmbeddingResults<E> embeddingResults = (EmbeddingResults<E>) inferenceServiceResults;
159159
if (embeddingResults.embeddings().size() != request.requests.size()) {
160160
onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size()));
161161
return;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio
741741
InputType.INGEST,
742742
InferenceAction.Request.DEFAULT_TIMEOUT,
743743
listener.delegateFailureAndWrap((delegate, r) -> {
744-
if (r instanceof TextEmbeddingResults<?, ?> embeddingResults) {
744+
if (r instanceof TextEmbeddingResults<?> embeddingResults) {
745745
try {
746746
delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
747747
} catch (Exception e) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void validate(InferenceService service, Model model, ActionListener<Model
3333
}
3434

3535
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
36-
if (results instanceof TextEmbeddingResults<?, ?> embeddingResults) {
36+
if (results instanceof TextEmbeddingResults<?> embeddingResults) {
3737
var serviceSettings = model.getServiceSettings();
3838
var dimensions = serviceSettings.dimensions();
3939
int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -67,7 +67,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
6767
}
6868
}
6969

70-
private int getEmbeddingSize(TextEmbeddingResults<?, ?> embeddingResults) {
70+
private int getEmbeddingSize(TextEmbeddingResults<?> embeddingResults) {
7171
int embeddingSize;
7272
try {
7373
embeddingSize = embeddingResults.getFirstEmbeddingSize();

0 commit comments

Comments
 (0)