@@ -21,18 +21,17 @@ public interface ChunkedInference {
      * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
      *
      * @param xcontent provided by the SemanticTextField
-     * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding).
+     * @return an iterator of the serialized {@link Chunk} which includes the offset into the input text and bytes reference
+     * (output/embedding).
      */
-    Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;
+    Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException;
 
     /**
-     * A chunk of inference results containing matched text, the substring location
-     * in the original text and the bytes reference.
-     * @param matchedText
+     * A chunk of inference results containing the substring location in the original text and the bytes reference.
      * @param textOffset
      * @param bytesReference
      */
-    record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}
+    record Chunk(TextOffset textOffset, BytesReference bytesReference) {}
 
     record TextOffset(int start, int end) {}
 }
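
Since Chunk no longer carries the matched text, a consumer that still needs a chunk's text has to slice it out of the original input using TextOffset. A minimal caller-side sketch (the printChunks helper and its arguments are illustrative, not part of this change):

import java.io.IOException;
import java.util.Iterator;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;

public final class ChunkTextExample {
    // Hypothetical helper: recovers each chunk's text from its offsets, assuming
    // the caller still holds the original input string that was chunked.
    static void printChunks(ChunkedInference results, XContent xcontent, String input) throws IOException {
        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(xcontent);
        while (it.hasNext()) {
            ChunkedInference.Chunk chunk = it.next();
            ChunkedInference.TextOffset offset = chunk.textOffset();
            // [start, end) locates this chunk in the original input.
            String matchedText = input.substring(offset.start(), offset.end());
            System.out.println(matchedText + " -> " + chunk.bytesReference().length() + " bytes");
        }
    }
}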
@@ -29,7 +29,6 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
             List.of(
                 new SparseEmbeddingResults.Chunk(
                     sparseEmbeddingResults.embeddings().get(i).tokens(),
-                    inputs.get(i),
                     new TextOffset(0, inputs.get(i).length())
                 )
             )
@@ -41,7 +40,7 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
     }
 
     @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
+    public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
         var asChunk = new ArrayList<Chunk>();
         for (var chunk : chunks()) {
             asChunk.add(chunk.toChunk(xcontent));
@@ -7,17 +7,16 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
-import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.xcontent.XContent;
 
+import java.util.Collections;
 import java.util.Iterator;
-import java.util.stream.Stream;
 
 public record ChunkedInferenceError(Exception exception) implements ChunkedInference {
 
     @Override
-    public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
-        return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
+    public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
+        return Collections.emptyIterator();
     }
 }
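
One behavioral note on this file: an inference error previously surfaced as a single chunk whose text was the exception message; it now produces no chunks at all. Callers that care about failures should inspect the record itself rather than its (now empty) chunk stream. A hypothetical caller-side check:

// Sketch: `inference` is any ChunkedInference returned by a service; the
// exception() accessor comes from the record declaration above.
if (inference instanceof ChunkedInferenceError error) {
    throw new IllegalStateException("chunked inference failed", error.exception());
}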
@@ -24,13 +24,11 @@ public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends Em
         InferenceServiceResults {
 
     /**
-     * A resulting embedding together with its input text.
+     * A resulting embedding together with the offset into the input text.
      */
     interface Chunk {
         ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
 
-        String matchedText();
-
         ChunkedInference.TextOffset offset();
     }
 
@@ -39,9 +37,9 @@ interface Chunk {
      */
     interface Embedding<C extends Chunk> {
         /**
-         * Combines the resulting embedding with the input into a chunk.
+         * Combines the resulting embedding with the offset into the input text into a chunk.
         */
-        C toChunk(String text, ChunkedInference.TextOffset offset);
+        C toChunk(ChunkedInference.TextOffset offset);
     }
 
     /**
@@ -175,17 +175,15 @@ public String toString() {
         }
 
         @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(tokens, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(tokens, offset);
         }
     }
 
-    public record Chunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset)
-        implements
-            EmbeddingResults.Chunk {
+    public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
 
         public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, weightedTokens));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
         }
 
         private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
@@ -187,18 +187,18 @@ public int hashCode() {
         }
 
         @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(values, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(values, offset);
         }
     }
 
     /**
      * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
      */
-    public record Chunk(byte[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
+    public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
 
         public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
         }
 
         private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
@@ -221,15 +221,15 @@ public int hashCode() {
         }
 
         @Override
-        public Chunk toChunk(String text, ChunkedInference.TextOffset offset) {
-            return new Chunk(values, text, offset);
+        public Chunk toChunk(ChunkedInference.TextOffset offset) {
+            return new Chunk(values, offset);
         }
     }
 
-    public record Chunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
+    public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
 
         public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(matchedText, offset, toBytesReference(xcontent, embedding));
+            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
         }
 
     /**
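
All three embedding result types now share the same two-step shape: Embedding#toChunk(ChunkedInference.TextOffset) attaches an offset to the raw values, and Chunk#toChunk(XContent) serializes them for storage. A rough usage sketch for the float variant (the embedding values and the JSON content type are made up; the record constructor matches the declaration above):

// Step 1: pair the embedding values with the span of input they represent.
TextEmbeddingFloatResults.Chunk chunk = new TextEmbeddingFloatResults.Chunk(
    new float[] { 0.12f, -0.34f, 0.56f },
    new ChunkedInference.TextOffset(0, 11)
);
// Step 2: serialize into a BytesReference for the semantic text field
// (throws IOException).
ChunkedInference.Chunk serialized = chunk.toChunk(XContentType.JSON.xContent());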
@@ -183,7 +183,6 @@ private List<ChunkedInference> makeChunkedResults(List<String> input, int dimens
                 List.of(
                     new TextEmbeddingFloatResults.Chunk(
                         nonChunkedResults.embeddings().get(i).values(),
-                        input.get(i),
                         new ChunkedInference.TextOffset(0, input.get(i).length())
                     )
                 )
@@ -172,13 +172,7 @@ private List<ChunkedInference> makeChunkedResults(List<String> input) {
         }
         results.add(
             new ChunkedInferenceEmbedding(
-                List.of(
-                    new SparseEmbeddingResults.Chunk(
-                        tokens,
-                        input.get(i),
-                        new ChunkedInference.TextOffset(0, input.get(i).length())
-                    )
-                )
+                List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
             )
         );
     }
@@ -606,7 +606,7 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
 
     private static class EmptyChunkedInference implements ChunkedInference {
         @Override
-        public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
+        public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
             return Collections.emptyIterator();
         }
     }
@@ -197,10 +197,7 @@ private ChunkedInference mergeResultsWithInputs(int index) {
         AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
         for (int i = 0; i < request.size(); i++) {
             EmbeddingResults.Chunk chunk = result.get(i)
-                .toChunk(
-                    request.get(i).chunkText(),
-                    new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end())
-                );
+                .toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
             chunks.add(chunk);
         }
         return new ChunkedInferenceEmbedding(chunks);
@@ -275,7 +275,7 @@ public static List<Chunk> toSemanticTextFieldChunks(
         boolean useLegacyFormat
     ) throws IOException {
         List<Chunk> chunks = new ArrayList<>();
-        Iterator<ChunkedInference.Chunk> it = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
         while (it.hasNext()) {
             chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
         }
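
This is also why toSemanticTextFieldChunk now receives the input string and an offsetAdjustment alongside each offset-only chunk: the chunk's text can be sliced from input, and offsets that are relative to a single value of a multi-value field can be shifted into field coordinates. A hedged sketch of that arithmetic (my reading of offsetAdjustment, not code from this change):

// Assumed downstream use: shift chunk-relative offsets by offsetAdjustment.
int start = chunk.textOffset().start() + offsetAdjustment;
int end = chunk.textOffset().end() + offsetAdjustment;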
@@ -121,7 +121,6 @@ private static List<ChunkedInference> translateToChunkedResults(DocumentsOnlyInp
                 List.of(
                     new TextEmbeddingFloatResults.Chunk(
                         textEmbeddingResults.embeddings().get(i).values(),
-                        inputs.getInputs().get(i),
                         new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
                     )
                 )