|
28 | 28 | import org.elasticsearch.cluster.metadata.Metadata; |
29 | 29 | import org.elasticsearch.cluster.metadata.ProjectMetadata; |
30 | 30 | import org.elasticsearch.cluster.service.ClusterService; |
| 31 | +import org.elasticsearch.common.CheckedBiFunction; |
31 | 32 | import org.elasticsearch.common.Strings; |
32 | 33 | import org.elasticsearch.common.bytes.BytesReference; |
33 | 34 | import org.elasticsearch.common.settings.ClusterSettings; |
|
79 | 80 | import java.util.Set; |
80 | 81 | import java.util.concurrent.CountDownLatch; |
81 | 82 | import java.util.concurrent.TimeUnit; |
82 | | -import java.util.function.Function; |
83 | 83 |
|
84 | 84 | import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES; |
85 | 85 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; |
86 | 86 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; |
| 87 | +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; |
87 | 88 | import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; |
88 | 89 | import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; |
89 | 90 | import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; |
@@ -505,7 +506,6 @@ public void testIndexingPressure() throws Exception { |
505 | 506 | final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); |
506 | 507 | final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); |
507 | 508 | final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); |
508 | | - final Function<XContentBuilder, Long> bytesUsed = b -> BytesReference.bytes(b).ramBytesUsed(); |
509 | 509 | final ShardBulkInferenceActionFilter filter = createFilter( |
510 | 510 | threadPool, |
511 | 511 | Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel), |
@@ -558,14 +558,14 @@ public void testIndexingPressure() throws Exception { |
558 | 558 |
|
559 | 559 | IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); |
560 | 560 | assertThat(coordinatingIndexingPressure, notNullValue()); |
561 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0Source)); |
562 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1Source)); |
563 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc2Source)); |
564 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc3Source)); |
565 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc4Source)); |
566 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0UpdateSource)); |
| 561 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source)); |
| 562 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); |
| 563 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source)); |
| 564 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source)); |
| 565 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source)); |
| 566 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource)); |
567 | 567 | if (useLegacyFormat == false) { |
568 | | - verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1UpdateSource)); |
| 568 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource)); |
569 | 569 | } |
570 | 570 |
|
571 | 571 | verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0)); |
@@ -660,7 +660,7 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep |
660 | 660 |
|
661 | 661 | IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); |
662 | 662 | assertThat(coordinatingIndexingPressure, notNullValue()); |
663 | | - verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed()); |
| 663 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); |
664 | 664 | verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong()); |
665 | 665 |
|
666 | 666 | // Verify that the coordinating indexing pressure is maintained through downstream action filters |
@@ -699,7 +699,7 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep |
699 | 699 | public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception { |
700 | 700 | final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); |
701 | 701 | final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( |
702 | | - Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (BytesReference.bytes(doc1Source).ramBytesUsed() + 1) + "b").build() |
| 702 | + Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build() |
703 | 703 | ); |
704 | 704 |
|
705 | 705 | final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); |
@@ -740,7 +740,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except |
740 | 740 |
|
741 | 741 | IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); |
742 | 742 | assertThat(coordinatingIndexingPressure, notNullValue()); |
743 | | - verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed()); |
| 743 | + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); |
744 | 744 | verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0)); |
745 | 745 | verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong()); |
746 | 746 |
|
@@ -776,6 +776,117 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except |
776 | 776 | verify(coordinatingIndexingPressure).close(); |
777 | 777 | } |
778 | 778 |
|
    /**
     * Verifies partial-failure semantics of coordinating indexing pressure: the memory limit is sized so
     * that doc_1's raw source, doc_2's raw source, doc_1's full inference results, and only HALF of doc_2's
     * inference results fit. Inserting doc_2's inference results therefore trips the limit, and only that
     * bulk item is marked failed (429 / {@link EsRejectedExecutionException}) while the other three items
     * proceed untouched. Also checks the coordinating handle stays open through the downstream filter chain
     * and is closed exactly once after the listener completes.
     */
    @SuppressWarnings("unchecked")
    public void testIndexingPressurePartialFailure() throws Exception {
        // Use different length strings so that doc 1 and doc 2 sources are different sizes
        final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
        final XContentBuilder doc2Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bazzz");

        // Pre-register deterministic embeddings so the filter's inference lookup is a pure map hit.
        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
        final ChunkedInferenceEmbedding barEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bar"));
        final ChunkedInferenceEmbedding bazzzEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bazzz"));
        sparseModel.putResult("bar", barEmbedding);
        sparseModel.putResult("bazzz", bazzzEmbedding);

        // Estimates the serialized size of a doc's inference results by rendering the same SemanticTextField
        // the filter would produce. CheckedBiFunction is used because toXContent declares IOException.
        CheckedBiFunction<List<String>, ChunkedInference, Long, IOException> estimateInferenceResultsBytes = (inputs, inference) -> {
            SemanticTextField semanticTextField = semanticTextFieldFromChunkedInferenceResults(
                useLegacyFormat,
                "sparse_field",
                sparseModel,
                null,
                inputs,
                inference,
                XContentType.JSON
            );
            XContentBuilder builder = XContentFactory.jsonBuilder();
            semanticTextField.toXContent(builder, EMPTY_PARAMS);
            return bytesUsed(builder);
        };

        // Limit = both raw sources + all of doc_1's inference results + half of doc_2's, so the
        // increment for doc_2's inference results is the first operation that exceeds the budget.
        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
            Settings.builder()
                .put(
                    MAX_COORDINATING_BYTES.getKey(),
                    (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
                        + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
                )
                .build()
        );

        final ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
            indexingPressure,
            useLegacyFormat,
            true
        );

        CountDownLatch chainExecuted = new CountDownLatch(1);
        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
            try {
                // The filter consumes the inference field map before forwarding the request.
                assertNull(request.getInferenceFieldMap());
                assertThat(request.items().length, equalTo(4));

                // Items 0, 1, and 3 must NOT be pre-failed: only doc_2 trips the limit.
                assertNull(request.items()[0].getPrimaryResponse());
                assertNull(request.items()[1].getPrimaryResponse());
                assertNull(request.items()[3].getPrimaryResponse());

                // doc_2 is rejected with a 429 whose cause chain carries the rejection.
                BulkItemRequest doc2Request = request.items()[2];
                BulkItemResponse doc2Response = doc2Request.getPrimaryResponse();
                assertNotNull(doc2Response);
                assertTrue(doc2Response.isFailed());
                BulkItemResponse.Failure doc2Failure = doc2Response.getFailure();
                assertThat(
                    doc2Failure.getCause().getMessage(),
                    containsString("Insufficient memory available to insert inference results into document [doc_2]")
                );
                assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
                assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));

                // The failed item's source must be left exactly as submitted (no partial inference mutation).
                IndexRequest doc2IndexRequest = getIndexRequestOrNull(doc2Request.request());
                assertThat(doc2IndexRequest, notNullValue());
                assertThat(doc2IndexRequest.source(), equalTo(BytesReference.bytes(doc2Source)));

                // Expected increments: one per raw source (exact sizes), two inference-result increments
                // with positive byte counts (doc_1's succeeded; doc_2's attempt is what tripped the limit
                // — TODO confirm the rejected attempt is counted among the 4 total), and nothing else.
                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
                assertThat(coordinatingIndexingPressure, notNullValue());
                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
                verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
                verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());

                // Verify that the coordinating indexing pressure is maintained through downstream action filters
                verify(coordinatingIndexingPressure, never()).close();

                // Call the listener once the request is successfully processed, like is done in the production code path
                listener.onResponse(null);
            } finally {
                chainExecuted.countDown();
            }
        };
        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
        Task task = mock(Task.class);

        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
            "sparse_field",
            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
        );

        // Two non-inference items bracket the two inference items to prove unrelated docs are untouched.
        BulkItemRequest[] items = new BulkItemRequest[4];
        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
        items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source("non_inference_field", "baz"));

        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
        request.setInferenceFieldMap(inferenceFieldMap);
        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);

        // After the chain completes, the coordinating handle must be released exactly once.
        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
        assertThat(coordinatingIndexingPressure, notNullValue());
        verify(coordinatingIndexingPressure).close();
    }
| 889 | + |
779 | 890 | @SuppressWarnings("unchecked") |
780 | 891 | private static ShardBulkInferenceActionFilter createFilter( |
781 | 892 | ThreadPool threadPool, |
@@ -947,6 +1058,10 @@ private static BulkItemRequest[] randomBulkItemRequest( |
947 | 1058 | new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; |
948 | 1059 | } |
949 | 1060 |
|
| 1061 | + private static long bytesUsed(XContentBuilder builder) { |
| 1062 | + return BytesReference.bytes(builder).ramBytesUsed(); |
| 1063 | + } |
| 1064 | + |
950 | 1065 | @SuppressWarnings({ "unchecked" }) |
951 | 1066 | private static void assertInferenceResults( |
952 | 1067 | boolean useLegacyFormat, |
|
0 commit comments