Commit f52a5a7

Semantic Text - Use offsets from chunked inference response (#118893)
1 parent 9533c7b · commit f52a5a7

File tree

3 files changed: +42 -21 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 11 additions & 10 deletions
@@ -425,15 +425,16 @@ public static Chunk toSemanticTextFieldChunk(
         ChunkedInference.Chunk chunk,
         boolean useInferenceMetadataFieldsFormat
     ) {
-        // TODO: Use offsets from ChunkedInferenceServiceResults
-        // TODO: When using legacy semantic text format, build chunk text from offsets
-        assert chunk.matchedText() != null; // TODO: Remove once offsets are available from chunk
-        int startOffset = useInferenceMetadataFieldsFormat ? input.indexOf(chunk.matchedText()) + offsetAdjustment : -1;
-        return new Chunk(
-            useInferenceMetadataFieldsFormat ? null : chunk.matchedText(),
-            useInferenceMetadataFieldsFormat ? startOffset : -1,
-            useInferenceMetadataFieldsFormat ? startOffset + chunk.matchedText().length() : -1,
-            chunk.bytesReference()
-        );
+        String text = null;
+        int startOffset = -1;
+        int endOffset = -1;
+        if (useInferenceMetadataFieldsFormat) {
+            startOffset = chunk.textOffset().start() + offsetAdjustment;
+            endOffset = chunk.textOffset().end() + offsetAdjustment;
+        } else {
+            text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+        }
+
+        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
     }
 }
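
The removed indexOf-based lookup could mis-place a chunk whenever the same chunk text occurs more than once in the input, whereas the offsets carried by the chunked inference response are unambiguous. A minimal standalone sketch of that failure mode (not Elasticsearch code; the input string and offset values are made up for illustration):

public class OffsetLookupDemo {
    public static void main(String[] args) {
        // Input whose two chunks have identical text.
        String input = "the cat sat. the cat sat.";
        String chunkText = "the cat sat.";

        // Old approach: locate each chunk by searching for its text.
        // indexOf always returns the first occurrence, so both chunks
        // would be placed at offset 0.
        System.out.println("indexOf for chunk 1: " + input.indexOf(chunkText)); // 0
        System.out.println("indexOf for chunk 2: " + input.indexOf(chunkText)); // still 0

        // New approach: the chunked inference response carries each chunk's
        // real offsets (hard-coded here for illustration), so identical text
        // is no longer ambiguous.
        int[][] responseOffsets = { { 0, 12 }, { 13, 25 } };
        for (int[] offset : responseOffsets) {
            System.out.println(input.substring(offset[0], offset[1]) + " @ [" + offset[0] + ", " + offset[1] + ")");
        }
    }
}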

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ private static BulkItemRequest[] randomBulkItemRequest(
         } else {
             Map<String, List<String>> inputTextMap = Map.of(field, List.of(inputText));
             semanticTextField = randomSemanticText(indexVersion, field, model, List.of(inputText), requestContentType);
-            model.putResult(inputText, toChunkedResult(inputTextMap, semanticTextField));
+            model.putResult(inputText, toChunkedResult(indexVersion, inputTextMap, semanticTextField));
         }
 
         if (useInferenceMetadataFieldsFormat) {
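
The only functional change here is threading indexVersion through to toChunkedResult, since the index version now determines which chunk format the stubbed results use (see the createOffset helper in the next file). A rough sketch of that dependency, using the isEnabled check from that helper; the surrounding names are illustrative:

// Illustrative only: the index version picks the chunk format for the stubbed results.
boolean useNewFormat = InferenceMetadataFieldsMapper.isEnabled(indexVersion);
// true  -> chunks carry the real start/end offsets stored with the field
// false -> legacy chunks carry matched text; offsets are derived as [0, matchedText.length())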

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 30 additions & 10 deletions
@@ -286,7 +286,11 @@ public static Object randomSemanticTextInput() {
         }
     }
 
-    public static ChunkedInference toChunkedResult(Map<String, List<String>> matchedTextMap, SemanticTextField field) throws IOException {
+    public static ChunkedInference toChunkedResult(
+        IndexVersion indexVersion,
+        Map<String, List<String>> matchedTextMap,
+        SemanticTextField field
+    ) {
         switch (field.inference().modelSettings().taskType()) {
             case SPARSE_EMBEDDING -> {
                 List<ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk> chunks = new ArrayList<>();
@@ -297,14 +301,10 @@ public static ChunkedInference toChunkedResult(Map<String, List<String>> matched
 
                     ListIterator<String> matchedTextIt = entryFieldMatchedText.listIterator();
                     for (var chunk : entryChunks) {
+                        String matchedText = matchedTextIt.next();
+                        ChunkedInference.TextOffset offset = createOffset(indexVersion, chunk, matchedText);
                         var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType());
-                        chunks.add(
-                            new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
-                                tokens,
-                                matchedTextIt.next(),
-                                new ChunkedInference.TextOffset(chunk.startOffset(), chunk.endOffset())
-                            )
-                        );
+                        chunks.add(new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(tokens, matchedText, offset));
                     }
                 }
                 return new ChunkedInferenceEmbeddingSparse(chunks);
@@ -318,6 +318,8 @@ public static ChunkedInference toChunkedResult(Map<String, List<String>> matched
 
                     ListIterator<String> matchedTextIt = entryFieldMatchedText.listIterator();
                     for (var chunk : entryChunks) {
+                        String matchedText = matchedTextIt.next();
+                        ChunkedInference.TextOffset offset = createOffset(indexVersion, chunk, matchedText);
                         double[] values = parseDenseVector(
                             chunk.rawEmbeddings(),
                             field.inference().modelSettings().dimensions(),
@@ -326,8 +328,8 @@ public static ChunkedInference toChunkedResult(Map<String, List<String>> matched
                         chunks.add(
                             new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
                                 FloatConversionUtils.floatArrayOf(values),
-                                matchedTextIt.next(),
-                                new ChunkedInference.TextOffset(chunk.startOffset(), chunk.endOffset())
+                                matchedText,
+                                offset
                             )
                         );
                     }
@@ -353,6 +355,24 @@ private static List<String> validateAndGetMatchedTextForField(
         return fieldMatchedText;
     }
 
+    /**
+     * Create a {@link ChunkedInference.TextOffset} instance with valid offset values. When using the legacy semantic text format, the
+     * offset values are not written to {@link SemanticTextField.Chunk}, so we cannot read them from there. Instead, use the knowledge that
+     * the matched text corresponds to one complete input value (i.e. one input value -> one chunk) to calculate the offset values.
+     *
+     * @param indexVersion The index version
+     * @param chunk The chunk to get/calculate offset values for
+     * @param matchedText The matched text to calculate offset values for
+     * @return A {@link ChunkedInference.TextOffset} instance with valid offset values
+     */
+    private static ChunkedInference.TextOffset createOffset(IndexVersion indexVersion, SemanticTextField.Chunk chunk, String matchedText) {
+        final boolean useInferenceMetadataFields = InferenceMetadataFieldsMapper.isEnabled(indexVersion);
+        final int startOffset = useInferenceMetadataFields ? chunk.startOffset() : 0;
+        final int endOffset = useInferenceMetadataFields ? chunk.endOffset() : matchedText.length();
+
+        return new ChunkedInference.TextOffset(startOffset, endOffset);
+    }
+
     private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) {
         try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) {
             parser.nextToken();
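
A quick sanity check of the helper's two branches, as a standalone sketch with simplified stand-in types (not the test class itself): in the new format the stored offsets pass straight through; in the legacy format the chunk is assumed to span one whole input value, so the offsets reduce to [0, matchedText.length()).

public class CreateOffsetDemo {
    // Simplified stand-ins for ChunkedInference.TextOffset and SemanticTextField.Chunk.
    record TextOffset(int start, int end) {}
    record Chunk(int startOffset, int endOffset) {}

    static TextOffset createOffset(boolean useInferenceMetadataFields, Chunk chunk, String matchedText) {
        // Mirrors the helper above: pass stored offsets through for the new format,
        // or derive [0, matchedText.length()) for the legacy one-input-one-chunk case.
        int start = useInferenceMetadataFields ? chunk.startOffset() : 0;
        int end = useInferenceMetadataFields ? chunk.endOffset() : matchedText.length();
        return new TextOffset(start, end);
    }

    public static void main(String[] args) {
        Chunk stored = new Chunk(13, 25);
        System.out.println(createOffset(true, stored, "the cat sat."));  // TextOffset[start=13, end=25]
        System.out.println(createOffset(false, stored, "the cat sat.")); // TextOffset[start=0, end=12]
    }
}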
