diff --git a/docs/changelog/124313.yaml b/docs/changelog/124313.yaml new file mode 100644 index 0000000000000..fc4d4d9d815e4 --- /dev/null +++ b/docs/changelog/124313.yaml @@ -0,0 +1,5 @@ +pr: 124313 +summary: Optimize memory usage in `ShardBulkInferenceActionFilter` +area: Search +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index f1bfddf0ae19b..ef408ee87c6b8 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.SourceFieldMapper; @@ -44,6 +45,7 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -85,7 +87,12 @@ public void setup() throws Exception { @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { - return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes()); + return Settings.builder() + .put(otherSettings) + .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial") + .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)) + .build(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 8e653ef327189..2824bab68ea0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -142,6 +142,7 @@ import java.util.function.Supplier; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG; public class InferencePlugin extends Plugin @@ -445,6 +446,7 @@ public List> getSettings() { settings.addAll(Truncator.getSettingsDefinitions()); settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions()); settings.add(SKIP_VALIDATE_AND_START); + settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); return settings; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index e1ecfdc97be64..c2ad057d0a256 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -24,7 +24,11 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -42,6 +46,10 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferenceException; @@ -62,6 +70,8 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -71,10 +81,23 @@ * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the * results during indexing on the shard. * - * TODO: batchSize should be configurable via a cluster setting */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { - protected static final int DEFAULT_BATCH_SIZE = 512; + private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1); + + /** + * Defines the cumulative size limit of input data before triggering a batch inference call. + * This setting controls how much data can be accumulated before an inference request is sent in batch. 
+ */ + public static Setting INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting( + "indices.inference.batch_size", + DEFAULT_BATCH_SIZE, + ByteSizeValue.ONE, + ByteSizeValue.ofMb(100), + Setting.Property.NodeScope, + Setting.Property.OperatorDynamic + ); + private static final Object EXPLICIT_NULL = new Object(); private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference(); @@ -82,29 +105,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; private final XPackLicenseState licenseState; - private final int batchSize; + private volatile long batchSizeInBytes; public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, XPackLicenseState licenseState - ) { - this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE); - } - - public ShardBulkInferenceActionFilter( - ClusterService clusterService, - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - XPackLicenseState licenseState, - int batchSize ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.licenseState = licenseState; - this.batchSize = batchSize; + this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); + clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); + } + + private void setBatchSize(ByteSizeValue newBatchSize) { + batchSizeInBytes = newBatchSize.getBytes(); } @Override @@ -146,14 +164,21 @@ private record InferenceProvider(InferenceService service, Model model) {} /** * A field inference request on a single input. - * @param index The index of the request in the original bulk request. + * @param bulkItemIndex The index of the item in the original bulk request. * @param field The target field. * @param sourceField The source field. * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. */ - private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {} + private record FieldInferenceRequest( + int bulkItemIndex, + String field, + String sourceField, + String input, + int inputOrder, + int offsetAdjustment + ) {} /** * The field inference response. 
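The `indices.inference.batch_size` setting above caps the cumulative input size accumulated before a chunked inference call is dispatched; the accumulation loop in `executeNext` (next hunk) keeps admitting bulk items while the running total stays below the limit, so with a positive limit every batch admits at least one item and may overshoot the limit by the size of the last item it admitted. Below is a minimal, self-contained sketch of that policy, measuring items by input character length; the names `BatchWindowSketch` and `nextBatchEnd` are illustrative and are not part of this change.

import java.util.List;

/** Illustrative only: mirrors the accumulation policy of executeNext(), not the actual filter code. */
final class BatchWindowSketch {
    /**
     * Returns the exclusive end index of the next batch starting at {@code offset}.
     * Items are admitted until the cumulative length reaches {@code batchSizeInBytes},
     * so a batch always contains at least one item (for a positive limit) and may
     * overshoot the limit by the size of the last item it admitted.
     */
    static int nextBatchEnd(List<String> inputs, int offset, long batchSizeInBytes) {
        long total = 0;
        int i = offset;
        while (i < inputs.size() && total < batchSizeInBytes) {
            total += inputs.get(i).length();
            i++;
        }
        return i;
    }

    public static void main(String[] args) {
        List<String> inputs = List.of("aaaa", "bbbb", "cc");
        int end = nextBatchEnd(inputs, 0, 5);
        // First batch spans [0, 2): after admitting "aaaa" the total (4) is still under the limit (5),
        // so "bbbb" is also admitted before the loop stops.
        System.out.println("first batch ends at " + end);
    }
}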
@@ -216,29 +241,54 @@ private AsyncBulkShardInferenceAction( @Override public void run() { - Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + executeNext(0); + } + + private void executeNext(int itemOffset) { + if (itemOffset >= bulkShardRequest.items().length) { + onCompletion.run(); + return; + } + + var items = bulkShardRequest.items(); + Map> fieldRequestsMap = new HashMap<>(); + long totalInputLength = 0; + int itemIndex = itemOffset; + while (itemIndex < items.length && totalInputLength < batchSizeInBytes) { + var item = items[itemIndex]; + totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap); + itemIndex += 1; + } + int nextItemOffset = itemIndex; Runnable onInferenceCompletion = () -> { try { - for (var inferenceResponse : inferenceResults.asList()) { - var request = bulkShardRequest.items()[inferenceResponse.id]; + for (int i = itemOffset; i < nextItemOffset; i++) { + var result = inferenceResults.get(i); + if (result == null) { + continue; + } + var item = items[i]; try { - applyInferenceResponses(request, inferenceResponse); + applyInferenceResponses(item, result); } catch (Exception exc) { - request.abort(bulkShardRequest.index(), exc); + item.abort(bulkShardRequest.index(), exc); } + // we don't need to keep the inference results around + inferenceResults.set(i, null); } } finally { - onCompletion.run(); + executeNext(nextItemOffset); } }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { - for (var entry : inferenceRequests.entrySet()) { - executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + for (var entry : fieldRequestsMap.entrySet()) { + executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); } } } - private void executeShardBulkInferenceAsync( + private void executeChunkedInferenceAsync( final String inferenceId, @Nullable InferenceProvider inferenceProvider, final List requests, @@ -260,11 +310,11 @@ public void onResponse(UnparsedModel unparsedModel) { unparsedModel.secrets() ) ); - executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { for (FieldInferenceRequest request : requests) { - inferenceResults.get(request.index).failures.add( + inferenceResults.get(request.bulkItemIndex).failures.add( new ResourceNotFoundException( "Inference service [{}] not found for field [{}]", unparsedModel.service(), @@ -295,7 +345,7 @@ public void onFailure(Exception exc) { request.field ); } - inferenceResults.get(request.index).failures.add(failure); + inferenceResults.get(request.bulkItemIndex).failures.add(failure); } } } @@ -303,18 +353,15 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - int currentBatchSize = Math.min(requests.size(), batchSize); - final List currentBatch = requests.subList(0, currentBatchSize); - final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { - try { + try (onFinish) { var requestsIterator = requests.iterator(); for (ChunkedInference result : results) { var 
request = requestsIterator.next(); - var acc = inferenceResults.get(request.index); + var acc = inferenceResults.get(request.bulkItemIndex); if (result instanceof ChunkedInferenceError error) { acc.addFailure( new InferenceException( @@ -329,7 +376,7 @@ public void onResponse(List results) { new FieldInferenceResponse( request.field(), request.sourceField(), - request.input(), + useLegacyFormat ? request.input() : null, request.inputOrder(), request.offsetAdjustment(), inferenceProvider.model, @@ -338,17 +385,15 @@ public void onResponse(List results) { ); } } - } finally { - onFinish(); } } @Override public void onFailure(Exception exc) { - try { + try (onFinish) { for (FieldInferenceRequest request : requests) { addInferenceResponseFailure( - request.index, + request.bulkItemIndex, new InferenceException( "Exception when running inference id [{}] on field [{}]", exc, @@ -357,16 +402,6 @@ public void onFailure(Exception exc) { ) ); } - } finally { - onFinish(); - } - } - - private void onFinish() { - if (nextBatch.isEmpty()) { - onFinish.close(); - } else { - executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); } } }; @@ -374,6 +409,132 @@ private void onFinish() { .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener); } + /** + * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap} + * for the specified {@code item}. + * + * @param item The bulk request item to process. + * @param itemIndex The position of the item within the original bulk request. + * @param requestsMap A map storing inference requests, where each key is an inference ID, + * and the value is a list of associated {@link FieldInferenceRequest} objects. + * @return The total content length of all newly added requests, or {@code 0} if no requests were added. + */ + private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map> requestsMap) { + boolean isUpdateRequest = false; + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + isUpdateRequest = true; + if (updateRequest.script() != null) { + addInferenceResponseFailure( + itemIndex, + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + return 0; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request + return 0; + } + + final Map docMap = indexRequest.sourceAsMap(); + long inputLength = 0; + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + + if (useLegacyFormat) { + var originalFieldValue = XContentMapValues.extractValue(field, docMap); + if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) { + // Inference has already been computed, or there is no inference required. + continue; + } + } else { + var inferenceMetadataFieldsValue = XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + "." 
+ field, + docMap, + EXPLICIT_NULL + ); + if (inferenceMetadataFieldsValue != null) { + // Inference has already been computed + continue; + } + } + + int order = 0; + for (var sourceField : entry.getSourceFields()) { + var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); + if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) { + /** + * It's an update request, and the source field is explicitly set to null, + * so we need to propagate this information to the inference fields metadata + * to overwrite any inference previously computed on the field. + * This ensures that the field is treated as intentionally cleared, + * preventing any unintended carryover of prior inference results. + */ + var slot = ensureResponseAccumulatorSlot(itemIndex); + slot.addOrUpdateResponse( + new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) + ); + continue; + } + if (valueObj == null || valueObj == EXPLICIT_NULL) { + if (isUpdateRequest && useLegacyFormat) { + addInferenceResponseFailure( + itemIndex, + new ElasticsearchStatusException( + "Field [{}] must be specified on an update request to calculate inference for field [{}]", + RestStatus.BAD_REQUEST, + sourceField, + field + ) + ); + break; + } + continue; + } + var slot = ensureResponseAccumulatorSlot(itemIndex); + final List values; + try { + values = SemanticTextUtils.nodeStringValues(field, valueObj); + } catch (Exception exc) { + addInferenceResponseFailure(itemIndex, exc); + break; + } + + if (INFERENCE_API_FEATURE.check(licenseState) == false) { + addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); + break; + } + + List requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + int offsetAdjustment = 0; + for (String v : values) { + inputLength += v.length(); + if (v.isBlank()) { + slot.addOrUpdateResponse( + new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) + ); + } else { + requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); + } + + // When using the inference metadata fields format, all the input values are concatenated so that the + // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment + // to apply to account for this. + offsetAdjustment += v.length() + 1; // Add one for separator char length + } + } + } + return inputLength; + } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { @@ -402,7 +563,6 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - var newDocMap = indexRequest.sourceAsMap(); Map inferenceFieldsMap = new HashMap<>(); for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); @@ -424,28 +584,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>()); - lst.addAll( - SemanticTextField.toSemanticTextFieldChunks( - resp.input, - resp.offsetAdjustment, - resp.chunkedResults, - indexRequest.getContentType(), - useLegacyFormat - ) - ); + var chunks = useLegacyFormat + ? 
toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType()) + : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType()); + lst.addAll(chunks); } - List inputs = responses.stream() - .filter(r -> r.sourceField().equals(fieldName)) - .map(r -> r.input) - .collect(Collectors.toList()); + List inputs = useLegacyFormat + ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList()) + : null; // The model can be null if we are only processing update requests that clear inference results. This is ok because we will // merge in the field's existing model settings on the data node. var result = new SemanticTextField( useLegacyFormat, fieldName, - useLegacyFormat ? inputs : null, + inputs, new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? new MinimalServiceSettings(model) : null, @@ -453,149 +607,52 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ), indexRequest.getContentType() ); - - if (useLegacyFormat) { - SemanticTextUtils.insertValue(fieldName, newDocMap, result); - } else { - inferenceFieldsMap.put(fieldName, result); - } - } - if (useLegacyFormat == false) { - newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap); + inferenceFieldsMap.put(fieldName, result); } - indexRequest.source(newDocMap, indexRequest.getContentType()); - } - /** - * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. - * If results are already populated for fields in the original index request, the inference request for this specific - * field is skipped, and the existing results remain unchanged. - * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, - * where an error will be thrown if they mismatch or if the content is malformed. - *
- * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? - */ - private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { - Map> fieldRequestsMap = new LinkedHashMap<>(); - for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) { - var item = bulkShardRequest.items()[itemIndex]; - if (item.getPrimaryResponse() != null) { - // item was already aborted/processed by a filter in the chain upstream (e.g. security) - continue; + if (useLegacyFormat) { + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : inferenceFieldsMap.entrySet()) { + SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue()); } - boolean isUpdateRequest = false; - final IndexRequest indexRequest; - if (item.request() instanceof IndexRequest ir) { - indexRequest = ir; - } else if (item.request() instanceof UpdateRequest updateRequest) { - isUpdateRequest = true; - if (updateRequest.script() != null) { - addInferenceResponseFailure( - itemIndex, - new ElasticsearchStatusException( - "Cannot apply update with a script on indices that contain [{}] field(s)", - RestStatus.BAD_REQUEST, - SemanticTextFieldMapper.CONTENT_TYPE - ) - ); - continue; - } - indexRequest = updateRequest.doc(); - } else { - // ignore delete request - continue; + indexRequest.source(newDocMap, indexRequest.getContentType()); + } else { + try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { + appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); + indexRequest.source(builder); } + } + } + } - final Map docMap = indexRequest.sourceAsMap(); - for (var entry : fieldInferenceMap.values()) { - String field = entry.getName(); - String inferenceId = entry.getInferenceId(); - - if (useLegacyFormat) { - var originalFieldValue = XContentMapValues.extractValue(field, docMap); - if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) { - // Inference has already been computed, or there is no inference required. - continue; - } - } else { - var inferenceMetadataFieldsValue = XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + "." + field, - docMap, - EXPLICIT_NULL - ); - if (inferenceMetadataFieldsValue != null) { - // Inference has already been computed - continue; - } - } - - int order = 0; - for (var sourceField : entry.getSourceFields()) { - var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); - if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) { - /** - * It's an update request, and the source field is explicitly set to null, - * so we need to propagate this information to the inference fields metadata - * to overwrite any inference previously computed on the field. - * This ensures that the field is treated as intentionally cleared, - * preventing any unintended carryover of prior inference results. 
- */ - var slot = ensureResponseAccumulatorSlot(itemIndex); - slot.addOrUpdateResponse( - new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) - ); - continue; - } - if (valueObj == null || valueObj == EXPLICIT_NULL) { - if (isUpdateRequest && useLegacyFormat) { - addInferenceResponseFailure( - itemIndex, - new ElasticsearchStatusException( - "Field [{}] must be specified on an update request to calculate inference for field [{}]", - RestStatus.BAD_REQUEST, - sourceField, - field - ) - ); - break; - } - continue; - } - var slot = ensureResponseAccumulatorSlot(itemIndex); - final List values; - try { - values = SemanticTextUtils.nodeStringValues(field, valueObj); - } catch (Exception exc) { - addInferenceResponseFailure(itemIndex, exc); - break; - } - - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); - break; - } - - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - int offsetAdjustment = 0; - for (String v : values) { - if (v.isBlank()) { - slot.addOrUpdateResponse( - new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) - ); - } else { - fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); - } - - // When using the inference metadata fields format, all the input values are concatenated so that the - // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment - // to apply to account for this. - offsetAdjustment += v.length() + 1; // Add one for separator char length - } - } - } + /** + * Appends the original source and the new inference metadata field directly to the provided + * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}. + */ + private static void appendSourceAndInferenceMetadata( + XContentBuilder builder, + BytesReference source, + XContentType xContentType, + Map inferenceFieldsMap + ) throws IOException { + builder.startObject(); + + // append the original source + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) { + // skip start object + parser.nextToken(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + builder.copyCurrentStructure(parser); } - return fieldRequestsMap; } + + // add the inference metadata field + builder.field(InferenceMetadataFieldsMapper.NAME); + try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) { + builder.copyCurrentStructure(parser); + } + + builder.endObject(); } static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index dedc02e0a8c3f..be7588abbbc8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -267,37 +267,38 @@ private static List parseChunksArrayLegacy(XContentParser parser, ParserC /** * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. 
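The split between `toSemanticTextFieldChunks` and `toSemanticTextFieldChunksLegacy` below reflects the two storage formats: legacy chunks keep the chunk text itself and use -1 offsets, while the new inference metadata format stores only start/end offsets into the concatenation of all input values, shifted by `offsetAdjustment` (each value contributes its length plus one separator character, as computed in `addFieldInferenceRequests`). The following is a small sketch of how that adjustment maps value-relative offsets into the concatenated string; `OffsetAdjustmentSketch`, `Span`, and `adjust` are illustrative names, not part of this change.

import java.util.List;

/** Illustrative only: shows how per-value chunk offsets shift when inputs are concatenated with a one-char separator. */
final class OffsetAdjustmentSketch {
    record Span(int start, int end) {}

    /** Shifts a chunk span, expressed relative to a single value, into the coordinates of the concatenated string. */
    static Span adjust(Span chunkWithinValue, int offsetAdjustment) {
        return new Span(chunkWithinValue.start() + offsetAdjustment, chunkWithinValue.end() + offsetAdjustment);
    }

    public static void main(String[] args) {
        List<String> values = List.of("first value", "second value");
        String concatenated = String.join(" ", values); // one separator char between values

        int offsetAdjustment = 0;
        for (String value : values) {
            // A chunk covering the whole value, expressed in value-relative coordinates.
            Span adjusted = adjust(new Span(0, value.length()), offsetAdjustment);
            // The adjusted span points at the same text inside the concatenated string.
            assert concatenated.substring(adjusted.start(), adjusted.end()).equals(value);
            offsetAdjustment += value.length() + 1; // add one for the separator char, as in addFieldInferenceRequests()
        }
        System.out.println("offsets verified against: " + concatenated);
    }
}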
*/ - public static List toSemanticTextFieldChunks( - String input, - int offsetAdjustment, - ChunkedInference results, - XContentType contentType, - boolean useLegacyFormat - ) throws IOException { + public static List toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType) + throws IOException { List chunks = new ArrayList<>(); Iterator it = results.chunksAsByteReference(contentType.xContent()); while (it.hasNext()) { - chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat)); + chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next())); } return chunks; } - public static Chunk toSemanticTextFieldChunk( - String input, - int offsetAdjustment, - ChunkedInference.Chunk chunk, - boolean useLegacyFormat - ) { + /** + * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. + */ + public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) { String text = null; - int startOffset = -1; - int endOffset = -1; - if (useLegacyFormat) { - text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); - } else { - startOffset = chunk.textOffset().start() + offsetAdjustment; - endOffset = chunk.textOffset().end() + offsetAdjustment; + int startOffset = chunk.textOffset().start() + offsetAdjustment; + int endOffset = chunk.textOffset().end() + offsetAdjustment; + return new Chunk(text, startOffset, endOffset, chunk.bytesReference()); + } + + public static List toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType) + throws IOException { + List chunks = new ArrayList<>(); + Iterator it = results.chunksAsByteReference(contentType.xContent()); + while (it.hasNext()) { + chunks.add(toSemanticTextFieldChunkLegacy(input, it.next())); } + return chunks; + } - return new Chunk(text, startOffset, endOffset, chunk.bytesReference()); + public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) { + var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); + return new Chunk(text, -1, -1, chunk.bytesReference()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 29482b58b6898..aaa81d2260a84 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -27,7 +27,9 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.IndexVersion; @@ -65,12 +67,13 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static 
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; @@ -115,7 +118,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -141,7 +144,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -182,7 +185,6 @@ public void testInferenceNotFound() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -229,7 +231,6 @@ public void testItemFailures() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -300,7 +301,6 @@ public void testExplicitNull() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -371,7 +371,6 @@ public void testHandleEmptyInput() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -444,13 +443,7 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - inferenceModelMap, - randomIntBetween(10, 30), - useLegacyFormat, - true - ); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -484,7 +477,6 @@ public void testManyRandomDocs() throws Exception { private static ShardBulkInferenceActionFilter createFilter( ThreadPool threadPool, Map modelMap, - int batchSize, boolean useLegacyFormat, boolean isLicenseValidForInference ) { @@ -551,18 
+543,17 @@ private static ShardBulkInferenceActionFilter createFilter( createClusterService(useLegacyFormat), inferenceServiceRegistry, modelRegistry, - licenseState, - batchSize + licenseState ); } private static ClusterService createClusterService(boolean useLegacyFormat) { IndexMetadata indexMetadata = mock(IndexMetadata.class); - var settings = Settings.builder() + var indexSettings = Settings.builder() .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current()) .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) .build(); - when(indexMetadata.getSettings()).thenReturn(settings); + when(indexMetadata.getSettings()).thenReturn(indexSettings); Metadata metadata = mock(Metadata.class); when(metadata.index(any(String.class))).thenReturn(indexMetadata); @@ -570,7 +561,10 @@ private static ClusterService createClusterService(boolean useLegacyFormat) { ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build(); ClusterService clusterService = mock(ClusterService.class); when(clusterService.state()).thenReturn(clusterState); - + long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes()); + Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE))); return clusterService; } @@ -581,7 +575,8 @@ private static BulkItemRequest[] randomBulkItemRequest( ) throws IOException { Map docMap = new LinkedHashMap<>(); Map expectedDocMap = new LinkedHashMap<>(); - XContentType requestContentType = randomFrom(XContentType.values()); + // force JSON to avoid double/float conversions + XContentType requestContentType = XContentType.JSON; Map inferenceMetadataFields = new HashMap<>(); for (var entry : fieldInferenceMap.values()) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 2c2e5f5d6d72b..b9824d58bcd8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -41,6 +41,7 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -259,7 +260,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( while (inputsIt.hasNext() && chunkIt.hasNext()) { String input = inputsIt.next(); var chunk = chunkIt.next(); - chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat)); + chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk)); // When using the inference metadata fields format, all the input values are concatenated so that the // chunk text offsets are expressed in the context of a single string. 
Calculate the offset adjustment