Commit cb86fd4

Inference Metadata Fields - Chunk On Delimiter (#118694)

1 parent fa45c50 commit cb86fd4
3 files changed (+42, -44 lines)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 14 additions & 14 deletions
@@ -25,7 +25,6 @@
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
@@ -57,8 +56,6 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 
-import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;
-
 /**
  * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
  * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
@@ -140,15 +137,17 @@ private record InferenceProvider(InferenceService service, Model model) {}
      * @param sourceField The source field.
      * @param input The input to run inference on.
      * @param inputOrder The original order of the input.
+     * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
-    private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder) {}
+    private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
 
     /**
      * The field inference response.
     * @param field The target field.
     * @param sourceField The source field.
     * @param input The input that was used to run inference.
     * @param inputOrder The original order of the input.
+    * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
     * @param model The model used to run inference.
     * @param chunkedResults The actual results.
     */
@@ -157,6 +156,7 @@ private record FieldInferenceResponse(
         String sourceField,
         String input,
         int inputOrder,
+        int offsetAdjustment,
         Model model,
         ChunkedInference chunkedResults
     ) {}
@@ -317,6 +317,7 @@ public void onResponse(List<ChunkedInference> results) {
                 request.sourceField(),
                 request.input(),
                 request.inputOrder(),
+                request.offsetAdjustment(),
                 inferenceProvider.model,
                 result
             )
@@ -402,6 +403,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
         lst.addAll(
             SemanticTextField.toSemanticTextFieldChunks(
                 resp.input,
+                resp.offsetAdjustment,
                 resp.chunkedResults,
                 indexRequest.getContentType(),
                 addMetadataField
@@ -528,16 +530,14 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 }
 
                 List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
-                if (useInferenceMetadataFieldsFormat) {
-                    // When using the inference metadata fields format, all the input values are concatenated so that the chunk
-                    // offsets are expressed in the context of a single string
-                    String concatenatedValue = Strings.collectionToDelimitedString(values, String.valueOf(MULTIVAL_SEP_CHAR));
-                    fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, concatenatedValue, order++));
-                } else {
-                    // When using the legacy format, each input value is processed using its own inference request
-                    for (String v : values) {
-                        fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++));
-                    }
+                int offsetAdjustment = 0;
+                for (String v : values) {
+                    fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+
+                    // When using the inference metadata fields format, all the input values are concatenated so that the
+                    // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
+                    // to apply to account for this.
+                    offsetAdjustment += v.length() + 1; // Add one for separator char length
                 }
             }
         }
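
To trace the arithmetic, the new loop can be sketched in isolation. A minimal standalone example (not part of the commit; class and value names are illustrative), assuming the values are later joined by a one-character separator:

import java.util.List;

class OffsetAdjustmentSketch {
    public static void main(String[] args) {
        List<String> values = List.of("first value", "second value", "third value");
        int offsetAdjustment = 0;
        for (String v : values) {
            // Each per-value inference request records where its value will
            // begin once all values are concatenated into a single string.
            System.out.println("\"" + v + "\" starts at offset " + offsetAdjustment);
            offsetAdjustment += v.length() + 1; // +1 for the one-char separator
        }
        // Prints offsets 0, 12, and 25 for the three values above.
    }
}

Each request thus carries enough information to place its chunks in the concatenated string without the filter ever materializing that string.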

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 9 additions & 3 deletions
@@ -406,23 +406,29 @@ private static List<Chunk> parseChunksArrayLegacy(XContentParser parser, ParserC
      */
     public static List<Chunk> toSemanticTextFieldChunks(
         String input,
+        int offsetAdjustment,
         ChunkedInference results,
         XContentType contentType,
         boolean useInferenceMetadataFieldsFormat
     ) throws IOException {
         List<Chunk> chunks = new ArrayList<>();
         Iterator<ChunkedInference.Chunk> it = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
         while (it.hasNext()) {
-            chunks.add(toSemanticTextFieldChunk(input, it.next(), useInferenceMetadataFieldsFormat));
+            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useInferenceMetadataFieldsFormat));
         }
         return chunks;
     }
 
-    public static Chunk toSemanticTextFieldChunk(String input, ChunkedInference.Chunk chunk, boolean useInferenceMetadataFieldsFormat) {
+    public static Chunk toSemanticTextFieldChunk(
+        String input,
+        int offsetAdjustment,
+        ChunkedInference.Chunk chunk,
+        boolean useInferenceMetadataFieldsFormat
+    ) {
         // TODO: Use offsets from ChunkedInferenceServiceResults
         // TODO: When using legacy semantic text format, build chunk text from offsets
         assert chunk.matchedText() != null; // TODO: Remove once offsets are available from chunk
-        int startOffset = useInferenceMetadataFieldsFormat ? input.indexOf(chunk.matchedText()) : -1;
+        int startOffset = useInferenceMetadataFieldsFormat ? input.indexOf(chunk.matchedText()) + offsetAdjustment : -1;
         return new Chunk(
             useInferenceMetadataFieldsFormat ? null : chunk.matchedText(),
             useInferenceMetadataFieldsFormat ? startOffset : -1,
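
The behavioral change is confined to the start-offset computation. As a sketch (hypothetical helper, not in the commit): indexOf locates the chunk inside its own input value, and the adjustment shifts that position into the coordinate space of the concatenated string.

// Hypothetical helper mirroring the ternary above, for illustration only.
static int adjustedStartOffset(String input, String matchedText, int offsetAdjustment) {
    // Position of the chunk within its own input value, shifted by where
    // that value starts in the concatenated string.
    return input.indexOf(matchedText) + offsetAdjustment;
}

For example, a 7-character first value plus one separator gives the second value an adjustment of 8, so a chunk found at index 4 of the second value is reported at offset 12 overall.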

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 19 additions & 27 deletions
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.mapper;
 
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.IndexVersion;
@@ -37,10 +36,8 @@
 import java.util.Map;
 import java.util.function.Predicate;
 
-import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
-import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;

@@ -236,31 +233,26 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
     ) throws IOException {
         final boolean useInferenceMetadataFields = InferenceMetadataFieldsMapper.isEnabled(indexVersion);
 
-        final List<SemanticTextField.Chunk> chunks;
-        if (useInferenceMetadataFields) {
-            // When using the inference metadata fields format, all the input values are concatenated so that the chunk offsets are
-            // expressed in the context of a single string
-            chunks = toSemanticTextFieldChunks(
-                Strings.collectionToDelimitedString(inputs, String.valueOf(MULTIVAL_SEP_CHAR)),
-                results,
-                contentType,
-                useInferenceMetadataFields
-            );
-        } else {
-            // When using the legacy format, each input value is processed using its own inference request.
-            // In this test framework, we don't perform "real" chunking; each input generates one chunk. Thus, we can assume there is a
-            // one-to-one relationship between inputs and chunks. Iterate over the inputs and chunks to match each input with its
-            // corresponding chunk.
-            chunks = new ArrayList<>(inputs.size());
-            Iterator<String> inputsIt = inputs.iterator();
-            Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
-            while (inputsIt.hasNext() && chunkIt.hasNext()) {
-                chunks.add(toSemanticTextFieldChunk(inputsIt.next(), chunkIt.next(), useInferenceMetadataFields));
-            }
+        // In this test framework, we don't perform "real" chunking; each input generates one chunk. Thus, we can assume there is a
+        // one-to-one relationship between inputs and chunks. Iterate over the inputs and chunks to match each input with its
+        // corresponding chunk.
+        final List<SemanticTextField.Chunk> chunks = new ArrayList<>(inputs.size());
+        int offsetAdjustment = 0;
+        Iterator<String> inputsIt = inputs.iterator();
+        Iterator<ChunkedInference.Chunk> chunkIt = results.chunksAsMatchedTextAndByteReference(contentType.xContent());
+        while (inputsIt.hasNext() && chunkIt.hasNext()) {
+            String input = inputsIt.next();
+            var chunk = chunkIt.next();
+            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useInferenceMetadataFields));
+
+            // When using the inference metadata fields format, all the input values are concatenated so that the
+            // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
+            // to apply to account for this.
+            offsetAdjustment += input.length() + 1; // Add one for separator char length
+        }
 
-            if (inputsIt.hasNext() || chunkIt.hasNext()) {
-                throw new IllegalArgumentException("Input list size and chunk count do not match");
-            }
+        if (inputsIt.hasNext() || chunkIt.hasNext()) {
+            throw new IllegalArgumentException("Input list size and chunk count do not match");
         }
 
         return new SemanticTextField(
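
The test helper now mirrors the production path: inputs and chunks are walked in lockstep (this framework produces exactly one chunk per input) and a size mismatch fails fast. A self-contained sketch of that pairing pattern, with illustrative names (not part of the commit):

import java.util.Iterator;
import java.util.List;

class LockstepPairingSketch {
    // Pairs two lists element by element and rejects a size mismatch, the
    // same guard the test helper applies to inputs and chunks.
    static <A, B> void pairInLockstep(List<A> inputs, List<B> chunks) {
        Iterator<A> inputsIt = inputs.iterator();
        Iterator<B> chunkIt = chunks.iterator();
        while (inputsIt.hasNext() && chunkIt.hasNext()) {
            System.out.println(inputsIt.next() + " <-> " + chunkIt.next());
        }
        if (inputsIt.hasNext() || chunkIt.hasNext()) {
            throw new IllegalArgumentException("Input list size and chunk count do not match");
        }
    }
}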

0 commit comments