Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ public interface ChunkedInference {
Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;

/**
* A chunk of inference results containing matched text, the substring location
* in the original text and the bytes reference.
* @param matchedText
* A chunk of inference results containing the substring location in the original text and the bytes reference.
* @param textOffset
* @param bytesReference
*/
record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}
record Chunk(TextOffset textOffset, BytesReference bytesReference) {}

record TextOffset(int start, int end) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
List.of(
new SparseEmbeddingResults.Chunk(
sparseEmbeddingResults.embeddings().get(i).tokens(),
inputs.get(i),
new TextOffset(0, inputs.get(i).length())
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;

import java.util.Collections;
import java.util.Iterator;
import java.util.stream.Stream;

public record ChunkedInferenceError(Exception exception) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
return Collections.emptyIterator();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
InferenceServiceResults {

/**
* A resulting embedding together with its input text.
* A resulting embedding together with the offset into the input text.
*/
interface Chunk {
ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;

String matchedText();

ChunkedInference.TextOffset offset();
}

Expand All @@ -39,9 +37,9 @@ interface Chunk {
*/
interface Embedding<C extends Chunk> {
/**
* Combines the resulting embedding with the input into a chunk.
* Combines the resulting embedding with the offset into the input text into a chunk.
*/
C toChunk(String text, ChunkedInference.TextOffset offset);
C toChunk(ChunkedInference.TextOffset offset);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,15 @@ public String toString() {
}

@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(tokens, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(tokens, offset);
}
}

public record Chunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset)
implements
EmbeddingResults.Chunk {
public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, weightedTokens));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
}

private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,18 @@ public int hashCode() {
}

@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(values, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(values, offset);
}
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
public record Chunk(byte[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
}

private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,15 @@ public int hashCode() {
}

@Override
public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
return new Chunk(values, text, offset);
public Chunk toChunk(ChunkedInference.TextOffset offset) {
return new Chunk(values, offset);
}
}

public record Chunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {

public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ private List<ChunkedInference> makeChunkedResults(List<String> input, int dimens
List.of(
new TextEmbeddingFloatResults.Chunk(
nonChunkedResults.embeddings().get(i).values(),
input.get(i),
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,7 @@ private List<ChunkedInference> makeChunkedResults(List<String> input) {
}
results.add(
new ChunkedInferenceEmbedding(
List.of(
new SparseEmbeddingResults.Chunk(
tokens,
input.get(i),
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)
List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ private ChunkedInference mergeResultsWithInputs(int index) {
AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
for (int i = 0; i < request.size(); i++) {
EmbeddingResults.Chunk chunk = result.get(i)
.toChunk(
request.get(i).chunkText(),
new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end())
);
.toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
chunks.add(chunk);
}
return new ChunkedInferenceEmbedding(chunks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ private static List<ChunkedInference> translateToChunkedResults(DocumentsOnlyInp
List.of(
new TextEmbeddingFloatResults.Chunk(
textEmbeddingResults.embeddings().get(i).values(),
inputs.getInputs().get(i),
new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
)
)
Expand Down
Loading