Skip to content

Commit 767b2cc

Browse files
prwhelan and elasticmachine
authored and committed
[ML] Refactor Apache streaming (elastic#113726)
- Apache calls `consumeContent` once per batch, so it should be safe to reuse the byte buffer by resetting it after each invocation. This saves a bit on memory management.
- Simplified the initial response handling: we now invoke the listener once we get the first set of bytes, for either a successful response or a failure response, and we now just verify the HttpResponse status line on the first response. This helps with some pausing that occurred when we received responses with just headers from some providers, and the response is now smoother.
- Response conversion now happens completely within the Flow processor rather than being split between the processor and the response handler (so future code can all be in one place).

Co-authored-by: Elastic Machine <[email protected]>
1 parent 5b83781 commit 767b2cc

File tree

10 files changed

+89
-149
lines changed

10 files changed

+89
-149
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResponse>, Flow.Publisher<HttpResult> {
4343
private final HttpSettings settings;
4444
private final ActionListener<Flow.Publisher<HttpResult>> listener;
45+
private final AtomicBoolean listenerCalled = new AtomicBoolean(false);
4546

4647
// used to manage the HTTP response
4748
private volatile HttpResponse response;
@@ -57,6 +58,7 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
5758
private final Deque<Runnable> queue = new ConcurrentLinkedDeque<>();
5859

5960
// used to control the flow of data from the Apache client, if we're producing more bytes than we can consume then we'll pause
61+
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);
6062
private final AtomicLong bytesInQueue = new AtomicLong(0);
6163
private final Object ioLock = new Object();
6264
private volatile IOControl savedIoControl;
@@ -69,16 +71,8 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
6971
}
7072

7173
@Override
72-
public void responseReceived(HttpResponse httpResponse) throws IOException {
74+
public void responseReceived(HttpResponse httpResponse) {
7375
this.response = httpResponse;
74-
if (response.getEntity() == null || response.getEntity().getContentLength() <= 0) {
75-
// on success, we may receive an empty content payload to initiate the stream
76-
this.queue.offer(() -> subscriber.onNext(new HttpResult(response, new byte[0])));
77-
} else {
78-
var firstResponse = HttpResult.create(settings.getMaxResponseSize(), response);
79-
this.queue.offer(() -> subscriber.onNext(firstResponse));
80-
}
81-
this.listener.onResponse(this);
8276
}
8377

8478
@Override
@@ -100,32 +94,39 @@ public void consumeContent(ContentDecoder contentDecoder, IOControl ioControl) t
10094
return;
10195
}
10296

103-
var buffer = new SimpleInputBuffer(4096);
104-
var consumed = buffer.consumeContent(contentDecoder);
105-
var allBytes = new byte[consumed];
106-
buffer.read(allBytes);
107-
108-
// we can have empty bytes, don't bother sending them
109-
if (allBytes.length > 0) {
110-
queue.offer(() -> {
111-
subscriber.onNext(new HttpResult(response, allBytes));
112-
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - allBytes.length));
113-
if (savedIoControl != null) {
114-
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
115-
if (currentBytesInQueue <= maxBytes) {
116-
resumeProducer();
97+
try {
98+
var consumed = inputBuffer.consumeContent(contentDecoder);
99+
var allBytes = new byte[consumed];
100+
inputBuffer.read(allBytes);
101+
102+
// we can have empty bytes, don't bother sending them
103+
if (allBytes.length > 0) {
104+
queue.offer(() -> {
105+
subscriber.onNext(new HttpResult(response, allBytes));
106+
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - allBytes.length));
107+
if (savedIoControl != null) {
108+
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
109+
if (currentBytesInQueue <= maxBytes) {
110+
resumeProducer();
111+
}
117112
}
118-
}
119-
});
120-
}
113+
});
114+
}
121115

122-
// always check if totalByteSize > the configured setting in case the settings change
123-
if (bytesInQueue.accumulateAndGet(allBytes.length, Long::sum) >= settings.getMaxResponseSize().getBytes()) {
124-
pauseProducer(ioControl);
125-
}
116+
// always check if totalByteSize > the configured setting in case the settings change
117+
if (bytesInQueue.accumulateAndGet(allBytes.length, Long::sum) >= settings.getMaxResponseSize().getBytes()) {
118+
pauseProducer(ioControl);
119+
}
120+
121+
// always run in case we're waking up from a pause and need to start a new thread
122+
taskRunner.requestNextRun();
126123

127-
// always run in case we're waking up from a pause and need to start a new thread
128-
taskRunner.requestNextRun();
124+
if (listenerCalled.compareAndSet(false, true)) {
125+
listener.onResponse(this);
126+
}
127+
} finally {
128+
inputBuffer.reset();
129+
}
129130
}
130131

131132
private void pauseProducer(IOControl ioControl) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,11 @@ default boolean canHandleStreamingResponses() {
6464
* HttpResults to the InferenceServiceResults.
6565
*
6666
* @param request The original request sent to the server
67-
* @param result The first result that initiated the stream. If the result is HTTP 200, this result will not contain content bytes
68-
* @param flow The remaining stream of results from the server. If the result is HTTP 200, these results will contain content bytes
67+
* @param flow The remaining stream of results from the server. If the result is HTTP 200, these results will contain content bytes
6968
* @return an inference results with {@link InferenceServiceResults#publisher()} set and {@link InferenceServiceResults#isStreaming()}
7069
* set to true.
7170
*/
72-
default InferenceServiceResults parseResult(Request request, HttpResult result, Flow.Publisher<HttpResult> flow) {
71+
default InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
7372
assert canHandleStreamingResponses() == false : "This must be implemented when canHandleStreamingResponses() == true";
7473
throw new UnsupportedOperationException("This must be implemented when canHandleStreamingResponses() == true");
7574
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
116116
try {
117117
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
118118
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
119-
r.subscribe(new StreamingResponseHandler(throttlerManager, logger, request, responseHandler, l));
119+
var streamingResponseHandler = new StreamingResponseHandler(throttlerManager, logger, request, responseHandler);
120+
r.subscribe(streamingResponseHandler);
121+
l.onResponse(responseHandler.parseResult(request, streamingResponseHandler));
120122
}));
121123
} else {
122124
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandler.java

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.ExceptionsHelper;
13-
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.inference.InferenceServiceResults;
1513
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1614
import org.elasticsearch.xpack.inference.external.request.Request;
1715
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
@@ -27,26 +25,18 @@ class StreamingResponseHandler implements Flow.Processor<HttpResult, HttpResult>
2725
private final Logger throttlerLogger;
2826
private final Request request;
2927
private final ResponseHandler responseHandler;
30-
private final ActionListener<InferenceServiceResults> listener;
3128

3229
private final AtomicBoolean upstreamIsClosed = new AtomicBoolean(false);
3330
private final AtomicBoolean processedFirstItem = new AtomicBoolean(false);
3431

3532
private volatile Flow.Subscription upstream;
3633
private volatile Flow.Subscriber<? super HttpResult> downstream;
3734

38-
StreamingResponseHandler(
39-
ThrottlerManager throttlerManager,
40-
Logger throttlerLogger,
41-
Request request,
42-
ResponseHandler responseHandler,
43-
ActionListener<InferenceServiceResults> listener
44-
) {
35+
StreamingResponseHandler(ThrottlerManager throttlerManager, Logger throttlerLogger, Request request, ResponseHandler responseHandler) {
4536
this.throttlerManager = throttlerManager;
4637
this.throttlerLogger = throttlerLogger;
4738
this.request = request;
4839
this.responseHandler = responseHandler;
49-
this.listener = listener;
5040
}
5141

5242
@Override
@@ -90,27 +80,21 @@ public void cancel() {
9080
@Override
9181
public void onSubscribe(Flow.Subscription subscription) {
9282
upstream = subscription;
93-
// start the first request, which will call onNext and validate the first HttpResult
94-
upstream.request(1);
9583
}
9684

9785
@Override
9886
public void onNext(HttpResult item) {
9987
if (processedFirstItem.compareAndSet(false, true)) {
10088
try {
10189
responseHandler.validateResponse(throttlerManager, throttlerLogger, request, item);
102-
var inferenceServiceResults = responseHandler.parseResult(request, item, this);
103-
assert downstream != null : "the responseHandler must invoke the subscribe method";
104-
listener.onResponse(inferenceServiceResults);
10590
} catch (Exception e) {
10691
logException(throttlerLogger, request, item, responseHandler.getRequestType(), e);
107-
listener.onFailure(e);
10892
upstream.cancel();
10993
onError(e);
94+
return;
11095
}
111-
} else {
112-
downstream.onNext(item);
11396
}
97+
downstream.onNext(item);
11498
}
11599

116100
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public boolean canHandleStreamingResponses() {
3535
}
3636

3737
@Override
38-
public InferenceServiceResults parseResult(Request request, HttpResult result, Flow.Publisher<HttpResult> flow) {
38+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
3939
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
4040
var openAiProcessor = new OpenAiStreamingProcessor();
4141

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.ElasticsearchException;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1516
import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest;
1617
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler;
1718
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener;
@@ -29,7 +30,7 @@ public AmazonBedrockChatCompletionResponseListener(
2930
@Override
3031
public void onResponse(ConverseResult result) {
3132
((AmazonBedrockChatCompletionResponseHandler) responseHandler).acceptChatCompletionResponseObject(result);
32-
inferenceResultsListener.onResponse(responseHandler.parseResult(request, null));
33+
inferenceResultsListener.onResponse(responseHandler.parseResult(request, (HttpResult) null));
3334
}
3435

3536
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1415
import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest;
1516
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler;
1617
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener;
@@ -28,7 +29,7 @@ public AmazonBedrockEmbeddingsResponseListener(
2829
@Override
2930
public void onResponse(InvokeModelResult result) {
3031
((AmazonBedrockEmbeddingsResponseHandler) responseHandler).acceptEmbeddingsResult(result);
31-
inferenceResultsListener.onResponse(responseHandler.parseResult(request, null));
32+
inferenceResultsListener.onResponse(responseHandler.parseResult(request, (HttpResult) null));
3233
}
3334

3435
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.http;
99

10-
import org.apache.http.HttpEntity;
1110
import org.apache.http.HttpResponse;
1211
import org.apache.http.nio.ContentDecoder;
1312
import org.apache.http.nio.IOControl;
@@ -19,7 +18,6 @@
1918
import org.elasticsearch.threadpool.ThreadPool;
2019
import org.junit.Before;
2120

22-
import java.io.ByteArrayInputStream;
2321
import java.io.IOException;
2422
import java.nio.ByteBuffer;
2523
import java.nio.charset.StandardCharsets;
@@ -80,25 +78,7 @@ public void testFirstResponseCallsListener() throws IOException {
8078
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
8179

8280
publisher.responseReceived(mock(HttpResponse.class));
83-
84-
assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L));
85-
}
86-
87-
/**
88-
* When we receive an http response with an entity with no content
89-
* Then we call the listener
90-
* And we queue the initial payload
91-
*/
92-
public void testEmptyFirstResponseCallsListener() throws IOException {
93-
var latch = new CountDownLatch(1);
94-
var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown());
95-
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
96-
97-
var response = mock(HttpResponse.class);
98-
var entity = mock(HttpEntity.class);
99-
when(entity.getContentLength()).thenReturn(-1L);
100-
when(response.getEntity()).thenReturn(entity);
101-
publisher.responseReceived(response);
81+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
10282

10383
assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L));
10484
}
@@ -114,12 +94,8 @@ public void testNonEmptyFirstResponseCallsListener() throws IOException {
11494
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
11595

11696
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000));
117-
var response = mock(HttpResponse.class);
118-
var entity = mock(HttpEntity.class);
119-
when(entity.getContentLength()).thenReturn(5L);
120-
when(entity.getContent()).thenReturn(new ByteArrayInputStream(message));
121-
when(response.getEntity()).thenReturn(entity);
122-
publisher.responseReceived(response);
97+
publisher.responseReceived(mock(HttpResponse.class));
98+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
12399

124100
assertThat("Listener's onResponse should be called when we receive a response", latch.getCount(), equalTo(0L));
125101
}
@@ -146,6 +122,7 @@ public void testNonEmptyFirstResponseCallsListener() throws IOException {
146122
public void testSubscriberAndPublisherExchange() throws IOException {
147123
var subscriber = new TestSubscriber();
148124
publisher.responseReceived(mock(HttpResponse.class));
125+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
149126

150127
// subscribe
151128
publisher.subscribe(subscriber);
@@ -155,7 +132,6 @@ public void testSubscriberAndPublisherExchange() throws IOException {
155132
// request the initial http response
156133
subscriber.requestData();
157134
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
158-
assertTrue("HttpResponse has an empty body (because there is no HttpEntity)", subscriber.httpResult.isBodyEmpty());
159135
subscriber.httpResult = null; // reset test
160136

161137
// subscriber requests data, publisher has not sent data yet
@@ -175,14 +151,14 @@ public void testNon200Response() throws IOException {
175151
var subscriber = new TestSubscriber();
176152
// Apache sends a single response and closes the consumer
177153
publisher.responseReceived(mock(HttpResponse.class));
154+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
178155
publisher.close();
179156

180157
// subscriber requests data
181158
publisher.subscribe(subscriber);
182159
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
183160
subscriber.requestData();
184161
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
185-
assertTrue("HttpResponse has an empty body (because there is no HttpEntity)", subscriber.httpResult.isBodyEmpty());
186162
subscriber.requestData();
187163
assertTrue("Publisher has been closed", publisher.isDone());
188164
assertTrue("Subscriber has been completed", subscriber.completed);
@@ -233,6 +209,7 @@ public void testTotalBytesDecrement() throws IOException {
233209

234210
var subscriber = new TestSubscriber();
235211
publisher.responseReceived(mock(HttpResponse.class));
212+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
236213
publisher.subscribe(subscriber);
237214
subscriber.requestData();
238215
subscriber.httpResult = null;
@@ -476,18 +453,25 @@ public void testDoubleSubscribeFails() {
476453
* When a new request is processed
477454
* Then it should reuse that ML Utility thread
478455
*/
479-
public void testReuseMlThread() throws IOException, ExecutionException, InterruptedException, TimeoutException {
456+
public void testReuseMlThread() throws ExecutionException, InterruptedException, TimeoutException {
480457
try {
481458
threadPool = spy(createThreadPool(inferenceUtilityPool()));
482459
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
483460
var subscriber = new TestSubscriber();
484461
publisher.responseReceived(mock(HttpResponse.class));
485462
publisher.subscribe(subscriber);
486463

487-
CompletableFuture.runAsync(subscriber::requestData, threadPool.executor(UTILITY_THREAD_POOL_NAME)).get(5, TimeUnit.SECONDS);
464+
CompletableFuture.runAsync(() -> {
465+
try {
466+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
467+
subscriber.requestData();
468+
} catch (IOException e) {
469+
throw new RuntimeException(e);
470+
}
471+
}, threadPool.executor(UTILITY_THREAD_POOL_NAME)).get(5, TimeUnit.SECONDS);
488472
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
489473
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
490-
assertTrue("HttpResponse has an empty body (because there is no HttpEntity)", subscriber.httpResult.isBodyEmpty());
474+
assertFalse("Expected HttpResult to have data", subscriber.httpResult.isBodyEmpty());
491475
} finally {
492476
terminate(threadPool);
493477
}
@@ -514,6 +498,7 @@ public void testCancelBreaksInfiniteLoop() throws Exception {
514498

515499
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
516500
publisher.responseReceived(mock(HttpResponse.class));
501+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
517502
// create an infinitely running Subscriber
518503
var subscriber = new Flow.Subscriber<HttpResult>() {
519504
Flow.Subscription subscription;
@@ -616,6 +601,7 @@ private TestSubscriber subscribe() {
616601

617602
private TestSubscriber runBefore(Runnable runDuringOnNext) throws IOException {
618603
publisher.responseReceived(mock(HttpResponse.class));
604+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
619605
TestSubscriber subscriber = new TestSubscriber() {
620606
public void onNext(HttpResult item) {
621607
runDuringOnNext.run();
@@ -628,6 +614,7 @@ public void onNext(HttpResult item) {
628614

629615
private TestSubscriber runAfter(Runnable runDuringOnNext) throws IOException {
630616
publisher.responseReceived(mock(HttpResponse.class));
617+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
631618
TestSubscriber subscriber = new TestSubscriber() {
632619
public void onNext(HttpResult item) {
633620
runDuringOnNext.run();

0 commit comments

Comments
 (0)