Skip to content

Commit 44cdcc4

Browse files
committed
more refactor
1 parent 0ada0a9 commit 44cdcc4

File tree

26 files changed

+336
-236
lines changed

26 files changed

+336
-236
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.inference.ChunkedInference;
11+
import org.elasticsearch.xcontent.XContent;
12+
13+
import java.io.IOException;
14+
import java.util.ArrayList;
15+
import java.util.Iterator;
16+
import java.util.List;
17+
18+
public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.EmbeddingChunk> chunks) implements ChunkedInference {
19+
20+
@Override
21+
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
22+
var asChunk = new ArrayList<Chunk>();
23+
for (var chunk : chunks()) {
24+
asChunk.add(chunk.toChunk(xcontent));
25+
}
26+
return asChunk.iterator();
27+
}
28+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,28 @@
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414

1515
import java.io.IOException;
16-
import java.util.ArrayList;
17-
import java.util.Iterator;
18-
import java.util.List;
19-
20-
public record ChunkedInferenceEmbeddingByte(List<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> chunks) implements ChunkedInference {
21-
22-
@Override
23-
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
24-
var asChunk = new ArrayList<Chunk>();
25-
for (var chunk : chunks) {
26-
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
27-
}
28-
return asChunk.iterator();
29-
}
3016

17+
public record ChunkedInferenceEmbeddingByte() {
3118
/**
3219
* A chunk of matched text together with its byte embedding. {@link #toChunk} serialises the {@code embedding} array, according to the provided {@link XContent}, into a {@link BytesReference}.
3320
*/
34-
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
35-
XContentBuilder builder = XContentBuilder.builder(xContent);
36-
builder.startArray();
37-
for (byte v : value) {
38-
builder.value(v);
21+
22+
public record ByteEmbeddingChunk(byte[] embedding, String matchedText, ChunkedInference.TextOffset offset)
23+
implements
24+
EmbeddingResults.EmbeddingChunk {
25+
26+
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
27+
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
3928
}
40-
builder.endArray();
41-
return BytesReference.bytes(builder);
42-
}
4329

44-
public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) implements EmbeddingResults.EmbeddingChunk {}
30+
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
31+
XContentBuilder builder = XContentBuilder.builder(xContent);
32+
builder.startArray();
33+
for (byte v : value) {
34+
builder.value(v);
35+
}
36+
builder.endArray();
37+
return BytesReference.bytes(builder);
38+
}
39+
}
4540
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,29 @@
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414

1515
import java.io.IOException;
16-
import java.util.ArrayList;
17-
import java.util.Iterator;
1816
import java.util.List;
1917

20-
public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements ChunkedInference {
18+
public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) {
2119

22-
@Override
23-
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
24-
var asChunk = new ArrayList<Chunk>();
25-
for (var chunk : chunks) {
26-
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
20+
public record FloatEmbeddingChunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset)
21+
implements
22+
EmbeddingResults.EmbeddingChunk {
23+
24+
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
25+
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
2726
}
28-
return asChunk.iterator();
29-
}
3027

31-
/**
32-
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
33-
*/
34-
private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
35-
XContentBuilder b = XContentBuilder.builder(xContent);
36-
b.startArray();
37-
for (float v : value) {
38-
b.value(v);
28+
/**
29+
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
30+
*/
31+
private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
32+
XContentBuilder b = XContentBuilder.builder(xContent);
33+
b.startArray();
34+
for (float v : value) {
35+
b.value(v);
36+
}
37+
b.endArray();
38+
return BytesReference.bytes(b);
3939
}
40-
b.endArray();
41-
return BytesReference.bytes(b);
4240
}
43-
44-
public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset)
45-
implements
46-
EmbeddingResults.EmbeddingChunk {}
4741
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,24 @@
1616

1717
import java.io.IOException;
1818
import java.util.ArrayList;
19-
import java.util.Iterator;
2019
import java.util.List;
2120

2221
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
2322

24-
public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements ChunkedInference {
23+
public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) {
2524

2625
public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
2726
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
2827

2928
var results = new ArrayList<ChunkedInference>(inputs.size());
3029
for (int i = 0; i < inputs.size(); i++) {
3130
results.add(
32-
new ChunkedInferenceEmbeddingSparse(
31+
new ChunkedInferenceEmbedding(
3332
List.of(
3433
new SparseEmbeddingChunk(
3534
sparseEmbeddingResults.embeddings().get(i).tokens(),
3635
inputs.get(i),
37-
new TextOffset(0, inputs.get(i).length())
36+
new ChunkedInference.TextOffset(0, inputs.get(i).length())
3837
)
3938
)
4039
)
@@ -44,26 +43,22 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
4443
return results;
4544
}
4645

47-
@Override
48-
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
49-
var asChunk = new ArrayList<Chunk>();
50-
for (var chunk : chunks) {
51-
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.weightedTokens())));
46+
public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset)
47+
implements
48+
EmbeddingResults.EmbeddingChunk {
49+
50+
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
51+
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, weightedTokens));
5252
}
53-
return asChunk.iterator();
54-
}
5553

56-
private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
57-
XContentBuilder b = XContentBuilder.builder(xContent);
58-
b.startObject();
59-
for (var weightedToken : tokens) {
60-
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
54+
private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
55+
XContentBuilder b = XContentBuilder.builder(xContent);
56+
b.startObject();
57+
for (var weightedToken : tokens) {
58+
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
59+
}
60+
b.endObject();
61+
return BytesReference.bytes(b);
6162
}
62-
b.endObject();
63-
return BytesReference.bytes(b);
6463
}
65-
66-
public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, TextOffset offset)
67-
implements
68-
EmbeddingResults.EmbeddingChunk {}
6964
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,22 @@
99

1010
import org.elasticsearch.inference.ChunkedInference;
1111
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.XContent;
1213

14+
import java.io.IOException;
1315
import java.util.List;
1416

1517
public interface EmbeddingResults<C extends EmbeddingResults.EmbeddingChunk, E extends EmbeddingResults.EmbeddingResult<C>>
1618
extends
1719
InferenceServiceResults {
1820

19-
interface EmbeddingChunk {}
21+
interface EmbeddingChunk {
22+
ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
23+
24+
String matchedText();
25+
26+
ChunkedInference.TextOffset offset();
27+
}
2028

2129
interface EmbeddingResult<C extends EmbeddingResults.EmbeddingChunk> {
2230
C toEmbeddingChunk(String text, ChunkedInference.TextOffset offset);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.rest.RestStatus;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
3738
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
3839
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
3940

@@ -179,7 +180,7 @@ private List<ChunkedInference> makeChunkedResults(List<String> input, int dimens
179180
var results = new ArrayList<ChunkedInference>();
180181
for (int i = 0; i < input.size(); i++) {
181182
results.add(
182-
new ChunkedInferenceEmbeddingFloat(
183+
new ChunkedInferenceEmbedding(
183184
List.of(
184185
new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
185186
nonChunkedResults.embeddings().get(i).values(),

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.rest.RestStatus;
3333
import org.elasticsearch.xcontent.ToXContentObject;
3434
import org.elasticsearch.xcontent.XContentBuilder;
35+
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
3536
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
3637
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3738
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
@@ -171,7 +172,7 @@ private List<ChunkedInference> makeChunkedResults(List<String> input) {
171172
tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
172173
}
173174
results.add(
174-
new ChunkedInferenceEmbeddingSparse(
175+
new ChunkedInferenceEmbedding(
175176
List.of(
176177
new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
177178
tokens,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
import org.elasticsearch.inference.ChunkingSettings;
1616
import org.elasticsearch.inference.InferenceServiceResults;
1717
import org.elasticsearch.rest.RestStatus;
18-
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte;
19-
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
20-
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
18+
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
2119
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
2220
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
2321
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
@@ -237,28 +235,6 @@ private ChunkedInference mergeResultsWithInputs(int index) {
237235
);
238236
embeddingChunks.add(chunk);
239237
}
240-
241-
switch (embeddingType) {
242-
case FLOAT:
243-
List<ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk> floatEmbeddingChunks = new ArrayList<>();
244-
embeddingChunks.forEach(chunk -> floatEmbeddingChunks.add((ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk) chunk));
245-
return new ChunkedInferenceEmbeddingFloat(floatEmbeddingChunks);
246-
case BYTE:
247-
List<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> byteEmbeddingChunks = new ArrayList<>();
248-
embeddingChunks.forEach(chunk -> byteEmbeddingChunks.add((ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk) chunk));
249-
return new ChunkedInferenceEmbeddingByte(byteEmbeddingChunks);
250-
case SPARSE:
251-
List<ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk> sparseEmbeddingChunks = new ArrayList<>();
252-
embeddingChunks.forEach(chunk -> sparseEmbeddingChunks.add((ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk) chunk));
253-
return new ChunkedInferenceEmbeddingSparse(sparseEmbeddingChunks);
254-
default:
255-
return new ChunkedInferenceError(
256-
new ElasticsearchStatusException(
257-
"Unexpected class [{}]",
258-
RestStatus.INTERNAL_SERVER_ERROR,
259-
embeddingChunks.getFirst().getClass().getName()
260-
)
261-
);
262-
}
238+
return new ChunkedInferenceEmbedding(embeddingChunks);
263239
}
264240
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.TaskType;
2626
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
2727
import org.elasticsearch.rest.RestStatus;
28+
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
2829
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
2930
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
3031
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
@@ -118,7 +119,7 @@ private static List<ChunkedInference> translateToChunkedResults(DocumentsOnlyInp
118119

119120
for (int i = 0; i < inputs.getInputs().size(); i++) {
120121
results.add(
121-
new ChunkedInferenceEmbeddingFloat(
122+
new ChunkedInferenceEmbedding(
122123
List.of(
123124
new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
124125
textEmbeddingResults.embeddings().get(i).values(),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import org.elasticsearch.threadpool.ThreadPool;
4848
import org.elasticsearch.xcontent.XContentType;
4949
import org.elasticsearch.xcontent.json.JsonXContent;
50-
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
50+
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
5151
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
5252
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
5353
import org.elasticsearch.xpack.inference.model.TestModel;
@@ -596,7 +596,7 @@ public static StaticModel createRandomInstance() {
596596
}
597597

598598
ChunkedInference getResults(String text) {
599-
return resultMap.getOrDefault(text, new ChunkedInferenceEmbeddingSparse(List.of()));
599+
return resultMap.getOrDefault(text, new ChunkedInferenceEmbedding(List.of()));
600600
}
601601

602602
void putResult(String text, ChunkedInference result) {

0 commit comments

Comments
 (0)