Skip to content

Commit 1ffe41b

Browse files
authored
[ML] Send mid-stream errors to users (#114549) (#114746)
If Apache sends an error mid-stream, forward it to the user rather than to the now-ignored listener.
1 parent 633ea4d commit 1ffe41b

File tree

4 files changed

+112
-64
lines changed

4 files changed

+112
-64
lines changed

docs/changelog/114549.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114549
2+
summary: Send mid-stream errors to users
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -153,44 +153,31 @@ public void stream(HttpRequest request, HttpContext context, ActionListener<Flow
153153
// The caller must call start() first before attempting to send a request
154154
assert status.get() == Status.STARTED : "call start() before attempting to send a request";
155155

156-
// apache can sometimes send us the same error in the consumer and the callback
157-
// sometimes it sends us an error just on the callback
158-
// notifyOnce will dedupe for us
159-
var callOnceListener = ActionListener.notifyOnce(listener);
160-
161-
SocketAccess.doPrivileged(
162-
() -> client.execute(
163-
request.requestProducer(),
164-
new StreamingHttpResultPublisher(threadPool, settings, callOnceListener),
165-
context,
166-
new FutureCallback<>() {
167-
@Override
168-
public void completed(HttpResponse response) {
169-
// StreamingHttpResultPublisher will publish results to the Flow.Publisher returned in the ActionListener
170-
}
171-
172-
@Override
173-
public void failed(Exception ex) {
174-
throttlerManager.warn(
175-
logger,
176-
format("Request from inference entity id [%s] failed", request.inferenceEntityId()),
177-
ex
178-
);
179-
failUsingUtilityThread(ex, callOnceListener);
180-
}
181-
182-
@Override
183-
public void cancelled() {
184-
failUsingUtilityThread(
156+
var streamingProcessor = new StreamingHttpResultPublisher(threadPool, settings, listener);
157+
158+
SocketAccess.doPrivileged(() -> client.execute(request.requestProducer(), streamingProcessor, context, new FutureCallback<>() {
159+
@Override
160+
public void completed(HttpResponse response) {
161+
streamingProcessor.close();
162+
}
163+
164+
@Override
165+
public void failed(Exception ex) {
166+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> streamingProcessor.failed(ex));
167+
}
168+
169+
@Override
170+
public void cancelled() {
171+
threadPool.executor(UTILITY_THREAD_POOL_NAME)
172+
.execute(
173+
() -> streamingProcessor.failed(
185174
new CancellationException(
186175
format("Request from inference entity id [%s] was cancelled", request.inferenceEntityId())
187-
),
188-
callOnceListener
189-
);
190-
}
191-
}
192-
)
193-
);
176+
)
177+
)
178+
);
179+
}
180+
}));
194181
}
195182

196183
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
6565

6666
StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener<Flow.Publisher<HttpResult>> listener) {
6767
this.settings = Objects.requireNonNull(settings);
68-
this.listener = Objects.requireNonNull(listener);
68+
this.listener = ActionListener.notifyOnce(Objects.requireNonNull(listener));
6969

7070
this.taskRunner = new RequestBasedTaskRunner(new OffloadThread(), threadPool, UTILITY_THREAD_POOL_NAME);
7171
}
@@ -152,9 +152,13 @@ public void responseCompleted(HttpContext httpContext) {}
152152
@Override
153153
public void failed(Exception e) {
154154
if (this.isDone.compareAndSet(false, true)) {
155-
ex = e;
156-
queue.offer(() -> subscriber.onError(e));
157-
taskRunner.requestNextRun();
155+
if (listenerCalled.compareAndSet(false, true)) {
156+
listener.onFailure(e);
157+
} else {
158+
ex = e;
159+
queue.offer(() -> subscriber.onError(e));
160+
taskRunner.requestNextRun();
161+
}
158162
}
159163
}
160164

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.test.ESTestCase;
1818
import org.elasticsearch.threadpool.ThreadPool;
1919
import org.junit.Before;
20+
import org.mockito.ArgumentCaptor;
2021

2122
import java.io.IOException;
2223
import java.nio.ByteBuffer;
@@ -38,6 +39,7 @@
3839
import static org.hamcrest.Matchers.notNullValue;
3940
import static org.hamcrest.Matchers.nullValue;
4041
import static org.mockito.ArgumentMatchers.any;
42+
import static org.mockito.ArgumentMatchers.eq;
4143
import static org.mockito.Mockito.doAnswer;
4244
import static org.mockito.Mockito.mock;
4345
import static org.mockito.Mockito.spy;
@@ -59,7 +61,7 @@ public void setUp() throws Exception {
5961
super.setUp();
6062
threadPool = mock(ThreadPool.class);
6163
settings = mock(HttpSettings.class);
62-
listener = ActionListener.noop();
64+
listener = spy(ActionListener.noop());
6365

6466
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
6567
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(maxBytes));
@@ -235,33 +237,13 @@ public void testTotalBytesDecrement() throws IOException {
235237
}
236238

237239
/**
238-
* Given an error from Apache
239-
* When the subscriber requests the next set of data
240-
* Then the subscriber receives the error from Apache
240+
* When there is an error from Apache before the publisher invokes the listener
241+
* Then the publisher will forward the call to the listener's onFailure
241242
*/
242243
public void testErrorBeforeRequest() {
243-
var subscriber = subscribe();
244244
var exception = new NullPointerException("test");
245-
246245
publisher.failed(exception);
247-
assertThat("subscriber receives exception on next request", subscriber.throwable, nullValue());
248-
249-
subscriber.requestData();
250-
assertThat("subscriber receives exception", subscriber.throwable, is(exception));
251-
}
252-
253-
/**
254-
* Given the subscriber is waiting for data
255-
* When Apache sends an error
256-
* Then the subscriber immediately receives the error
257-
*/
258-
public void testErrorAfterRequest() {
259-
var subscriber = subscribe();
260-
var exception = new NullPointerException("test");
261-
262-
subscriber.requestData();
263-
publisher.failed(exception);
264-
assertThat("subscriber receives exception", subscriber.throwable, is(exception));
246+
verify(listener).onFailure(exception);
265247
}
266248

267249
/**
@@ -375,6 +357,76 @@ public void testCancelAfterRequest() {
375357
assertTrue("onComplete should be called", subscriber.completed);
376358
}
377359

360+
/**
361+
* When cancel is called
362+
* Then we only send onComplete once
363+
*/
364+
public void testCancelIsIdempotent() throws IOException {
365+
Flow.Subscriber<HttpResult> subscriber = mock();
366+
367+
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
368+
publisher.subscribe(subscriber);
369+
verify(subscriber).onSubscribe(subscription.capture());
370+
371+
publisher.responseReceived(mock());
372+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
373+
subscription.getValue().request(1);
374+
375+
subscription.getValue().request(1);
376+
publisher.cancel();
377+
verify(subscriber, times(1)).onComplete();
378+
subscription.getValue().request(1);
379+
publisher.cancel();
380+
verify(subscriber, times(1)).onComplete();
381+
}
382+
383+
/**
384+
* When close is called
385+
* Then we only send onComplete once
386+
*/
387+
public void testCloseIsIdempotent() throws IOException {
388+
Flow.Subscriber<HttpResult> subscriber = mock();
389+
390+
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
391+
publisher.subscribe(subscriber);
392+
verify(subscriber).onSubscribe(subscription.capture());
393+
394+
publisher.responseReceived(mock());
395+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
396+
subscription.getValue().request(1);
397+
398+
subscription.getValue().request(1);
399+
publisher.close();
400+
verify(subscriber, times(1)).onComplete();
401+
subscription.getValue().request(1);
402+
publisher.close();
403+
verify(subscriber, times(1)).onComplete();
404+
}
405+
406+
/**
407+
* When failed is called
408+
* Then we only send onError once
409+
*/
410+
public void testFailedIsIdempotent() throws IOException {
411+
var expectedException = new IllegalStateException("wow");
412+
Flow.Subscriber<HttpResult> subscriber = mock();
413+
414+
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
415+
publisher.subscribe(subscriber);
416+
verify(subscriber).onSubscribe(subscription.capture());
417+
418+
publisher.responseReceived(mock());
419+
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
420+
subscription.getValue().request(1);
421+
422+
subscription.getValue().request(1);
423+
publisher.failed(expectedException);
424+
verify(subscriber, times(1)).onError(eq(expectedException));
425+
subscription.getValue().request(1);
426+
publisher.failed(expectedException);
427+
verify(subscriber, times(1)).onError(eq(expectedException));
428+
}
429+
378430
/**
379431
* Given the queue is being processed
380432
* When Apache cancels the publisher

0 commit comments

Comments
 (0)