Skip to content

Commit f4aef73

Browse files
committed
Added partial failure test
1 parent 6a3a4fd commit f4aef73

File tree

1 file changed

+127
-12
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.cluster.metadata.Metadata;
2929
import org.elasticsearch.cluster.metadata.ProjectMetadata;
3030
import org.elasticsearch.cluster.service.ClusterService;
31+
import org.elasticsearch.common.CheckedBiFunction;
3132
import org.elasticsearch.common.Strings;
3233
import org.elasticsearch.common.bytes.BytesReference;
3334
import org.elasticsearch.common.settings.ClusterSettings;
@@ -79,11 +80,11 @@
7980
import java.util.Set;
8081
import java.util.concurrent.CountDownLatch;
8182
import java.util.concurrent.TimeUnit;
82-
import java.util.function.Function;
8383

8484
import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES;
8585
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
8686
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
87+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
8788
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
8889
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
8990
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
@@ -505,7 +506,6 @@ public void testIndexingPressure() throws Exception {
505506
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
506507
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
507508
final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
508-
final Function<XContentBuilder, Long> bytesUsed = b -> BytesReference.bytes(b).ramBytesUsed();
509509
final ShardBulkInferenceActionFilter filter = createFilter(
510510
threadPool,
511511
Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
@@ -558,14 +558,14 @@ public void testIndexingPressure() throws Exception {
558558

559559
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
560560
assertThat(coordinatingIndexingPressure, notNullValue());
561-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0Source));
562-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1Source));
563-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc2Source));
564-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc3Source));
565-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc4Source));
566-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0UpdateSource));
561+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source));
562+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
563+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
564+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source));
565+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source));
566+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource));
567567
if (useLegacyFormat == false) {
568-
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1UpdateSource));
568+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource));
569569
}
570570

571571
verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
@@ -660,7 +660,7 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep
660660

661661
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
662662
assertThat(coordinatingIndexingPressure, notNullValue());
663-
verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed());
663+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
664664
verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
665665

666666
// Verify that the coordinating indexing pressure is maintained through downstream action filters
@@ -699,7 +699,7 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep
699699
public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
700700
final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
701701
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
702-
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (BytesReference.bytes(doc1Source).ramBytesUsed() + 1) + "b").build()
702+
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
703703
);
704704

705705
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
@@ -740,7 +740,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except
740740

741741
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
742742
assertThat(coordinatingIndexingPressure, notNullValue());
743-
verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed());
743+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
744744
verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
745745
verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
746746

@@ -776,6 +776,117 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except
776776
verify(coordinatingIndexingPressure).close();
777777
}
778778

779+
@SuppressWarnings("unchecked")
780+
public void testIndexingPressurePartialFailure() throws Exception {
781+
// Use different length strings so that doc 1 and doc 2 sources are different sizes
782+
final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
783+
final XContentBuilder doc2Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bazzz");
784+
785+
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
786+
final ChunkedInferenceEmbedding barEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bar"));
787+
final ChunkedInferenceEmbedding bazzzEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bazzz"));
788+
sparseModel.putResult("bar", barEmbedding);
789+
sparseModel.putResult("bazzz", bazzzEmbedding);
790+
791+
CheckedBiFunction<List<String>, ChunkedInference, Long, IOException> estimateInferenceResultsBytes = (inputs, inference) -> {
792+
SemanticTextField semanticTextField = semanticTextFieldFromChunkedInferenceResults(
793+
useLegacyFormat,
794+
"sparse_field",
795+
sparseModel,
796+
null,
797+
inputs,
798+
inference,
799+
XContentType.JSON
800+
);
801+
XContentBuilder builder = XContentFactory.jsonBuilder();
802+
semanticTextField.toXContent(builder, EMPTY_PARAMS);
803+
return bytesUsed(builder);
804+
};
805+
806+
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
807+
Settings.builder()
808+
.put(
809+
MAX_COORDINATING_BYTES.getKey(),
810+
(bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
811+
+ (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
812+
)
813+
.build()
814+
);
815+
816+
final ShardBulkInferenceActionFilter filter = createFilter(
817+
threadPool,
818+
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
819+
indexingPressure,
820+
useLegacyFormat,
821+
true
822+
);
823+
824+
CountDownLatch chainExecuted = new CountDownLatch(1);
825+
ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
826+
try {
827+
assertNull(request.getInferenceFieldMap());
828+
assertThat(request.items().length, equalTo(4));
829+
830+
assertNull(request.items()[0].getPrimaryResponse());
831+
assertNull(request.items()[1].getPrimaryResponse());
832+
assertNull(request.items()[3].getPrimaryResponse());
833+
834+
BulkItemRequest doc2Request = request.items()[2];
835+
BulkItemResponse doc2Response = doc2Request.getPrimaryResponse();
836+
assertNotNull(doc2Response);
837+
assertTrue(doc2Response.isFailed());
838+
BulkItemResponse.Failure doc2Failure = doc2Response.getFailure();
839+
assertThat(
840+
doc2Failure.getCause().getMessage(),
841+
containsString("Insufficient memory available to insert inference results into document [doc_2]")
842+
);
843+
assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
844+
assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
845+
846+
IndexRequest doc2IndexRequest = getIndexRequestOrNull(doc2Request.request());
847+
assertThat(doc2IndexRequest, notNullValue());
848+
assertThat(doc2IndexRequest.source(), equalTo(BytesReference.bytes(doc2Source)));
849+
850+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
851+
assertThat(coordinatingIndexingPressure, notNullValue());
852+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
853+
verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
854+
verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
855+
verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());
856+
857+
// Verify that the coordinating indexing pressure is maintained through downstream action filters
858+
verify(coordinatingIndexingPressure, never()).close();
859+
860+
// Call the listener once the request is successfully processed, like is done in the production code path
861+
listener.onResponse(null);
862+
} finally {
863+
chainExecuted.countDown();
864+
}
865+
};
866+
ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
867+
Task task = mock(Task.class);
868+
869+
Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
870+
"sparse_field",
871+
new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
872+
);
873+
874+
BulkItemRequest[] items = new BulkItemRequest[4];
875+
items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
876+
items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
877+
items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
878+
items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source("non_inference_field", "baz"));
879+
880+
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
881+
request.setInferenceFieldMap(inferenceFieldMap);
882+
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
883+
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
884+
885+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
886+
assertThat(coordinatingIndexingPressure, notNullValue());
887+
verify(coordinatingIndexingPressure).close();
888+
}
889+
779890
@SuppressWarnings("unchecked")
780891
private static ShardBulkInferenceActionFilter createFilter(
781892
ThreadPool threadPool,
@@ -947,6 +1058,10 @@ private static BulkItemRequest[] randomBulkItemRequest(
9471058
new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
9481059
}
9491060

1061+
private static long bytesUsed(XContentBuilder builder) {
1062+
return BytesReference.bytes(builder).ramBytesUsed();
1063+
}
1064+
9501065
@SuppressWarnings({ "unchecked" })
9511066
private static void assertInferenceResults(
9521067
boolean useLegacyFormat,

0 commit comments

Comments (0)