diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
index db9a62da5d9ea..ebbddaeeb0d21 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java
@@ -7,17 +7,17 @@
 
 package org.elasticsearch.compute.operator.exchange;
 
-import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.compute.EsqlRefCountingListener;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.operator.FailureCollector;
 import org.elasticsearch.compute.operator.IsBlockedResult;
 import org.elasticsearch.core.Releasable;
+import org.elasticsearch.tasks.TaskCancelledException;
 
 import java.util.List;
 import java.util.Map;
@@ -38,10 +38,9 @@ public final class ExchangeSourceHandler {
 
     private final PendingInstances outstandingSinks;
     private final PendingInstances outstandingSources;
-    // Collect failures that occur while fetching pages from the remote sink with `failFast=true`.
-    // The exchange source will stop fetching and abort as soon as any failure is added to this failure collector.
-    // The final failure collected will be notified to callers via the {@code completionListener}.
-    private final FailureCollector failure = new FailureCollector();
+    // Track if this exchange source should abort. There is no need to track the actual failure since the actual failure
+    // should be notified via #addRemoteSink(RemoteSink, boolean, Runnable, int, ActionListener).
+    private volatile boolean aborted = false;
 
     private final AtomicInteger nextSinkId = new AtomicInteger();
     private final Map remoteSinks = ConcurrentCollections.newConcurrentMap();
@@ -52,7 +51,7 @@ public final class ExchangeSourceHandler {
      * @param maxBufferSize the maximum size of the exchange buffer. A larger buffer reduces ``pauses`` but uses more memory,
      *                      which could otherwise be allocated for other purposes.
      * @param fetchExecutor the executor used to fetch pages.
-     * @param completionListener a listener that will be notified when the exchange source handler fails or completes
+     * @param completionListener a listener that will be notified when the exchange source handler completes
      */
     public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionListener<Void> completionListener) {
         this.buffer = new ExchangeBuffer(maxBufferSize);
@@ -63,14 +62,7 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi
         this.outstandingSources = new PendingInstances(() -> finishEarly(true, ActionListener.running(closingSinks::finishInstance)));
         buffer.addCompletionListener(ActionListener.running(() -> {
             final ActionListener<Void> listener = ActionListener.assertAtLeastOnce(completionListener);
-            try (RefCountingRunnable refs = new RefCountingRunnable(() -> {
-                final Exception e = failure.getFailure();
-                if (e != null) {
-                    listener.onFailure(e);
-                } else {
-                    listener.onResponse(null);
-                }
-            })) {
+            try (RefCountingRunnable refs = new RefCountingRunnable(ActionRunnable.run(listener, this::checkFailure))) {
                 closingSinks.completion.addListener(refs.acquireListener());
                 for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) {
                     // Create an outstanding instance and then finish to complete the completionListener
@@ -83,6 +75,12 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi
         }));
     }
 
+    private void checkFailure() {
+        if (aborted) {
+            throw new TaskCancelledException("remote sinks failed");
+        }
+    }
+
     private class ExchangeSourceImpl implements ExchangeSource {
         private boolean finished;
 
@@ -90,13 +88,6 @@ private class ExchangeSourceImpl implements ExchangeSource {
             outstandingSources.trackNewInstance();
         }
 
-        private void checkFailure() {
-            Exception e = failure.getFailure();
-            if (e != null) {
-                throw ExceptionsHelper.convertToRuntime(e);
-            }
-        }
-
         @Override
         public Page pollPage() {
             checkFailure();
@@ -201,7 +192,7 @@ void fetchPage() {
             while (loopControl.isRunning()) {
                 loopControl.exiting();
                 // finish other sinks if one of them failed or source no longer need pages.
-                boolean toFinishSinks = buffer.noMoreInputs() || failure.hasFailure();
+                boolean toFinishSinks = buffer.noMoreInputs() || aborted;
                 remoteSink.fetchPageAsync(toFinishSinks, ActionListener.wrap(resp -> {
                     Page page = resp.takePage();
                     if (page != null) {
@@ -231,7 +222,7 @@ void fetchPage() {
 
         void onSinkFailed(Exception e) {
             if (failFast) {
-                failure.unwrapAndCollect(e);
+                aborted = true;
             }
             buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
             if (finished == false) {
@@ -260,12 +251,12 @@ void onSinkComplete() {
      * - If {@code false}, failures from this remote sink will not cause the exchange source to abort.
      * Callers must handle these failures notified via {@code listener}.
      * - If {@code true}, failures from this remote sink will cause the exchange source to abort.
-     * Callers can safely ignore failures notified via this listener, as they are collected and
-     * reported by the exchange source.
+     *
      * @param onPageFetched a callback that will be called when a page is fetched from the remote sink
      * @param instances the number of concurrent ``clients`` that this handler should use to fetch pages.
      *                  More clients reduce latency, but add overhead.
-     * @param listener a listener that will be notified when the sink fails or completes
+     * @param listener a listener that will be notified when the sink fails or completes. Callers must handle failures notified via
+     *                 this listener.
      * @see ExchangeSinkHandler#fetchPageAsync(boolean, ActionListener)
      */
     public void addRemoteSink(
@@ -284,7 +275,7 @@ public void addRemoteSink(
             @Override
             public void onFailure(Exception e) {
                 if (failFast) {
-                    failure.unwrapAndCollect(e);
+                    aborted = true;
                 }
                 buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
                 remoteSink.close(ActionListener.running(() -> sinkListener.onFailure(e)));
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
index 2edf156f92da1..2927bc5439af6 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java
@@ -7,6 +7,7 @@
 
 package org.elasticsearch.compute.operator.exchange;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
@@ -16,6 +17,7 @@
 import org.elasticsearch.cluster.node.VersionInformation;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
@@ -23,6 +25,7 @@
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.compute.EsqlRefCountingListener;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockWritables;
 import org.elasticsearch.compute.data.IntBlock;
@@ -37,6 +40,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancellationService;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.test.transport.StubbableTransport;
@@ -69,6 +73,7 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class ExchangeServiceTests extends ESTestCase {
 
@@ -623,14 +628,15 @@ public void sendResponse(TransportResponse transportResponse) {
         );
         ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomIntBetween(1, 128));
         Transport.Connection connection = node0.getConnection(node1.getLocalNode());
+        PlainActionFuture<Void> remoteSinkFuture = new PlainActionFuture<>();
         sourceHandler.addRemoteSink(
             exchange0.newRemoteSink(task, exchangeId, node0, connection),
             true,
             () -> {},
             randomIntBetween(1, 5),
-            ActionListener.noop()
+            remoteSinkFuture
         );
-        Exception err = expectThrows(
+        Exception driverException = expectThrows(
             Exception.class,
             () -> runConcurrentTest(
                 maxSeqNo,
@@ -639,7 +645,9 @@ public void sendResponse(TransportResponse transportResponse) {
                 () -> sinkHandler.createExchangeSink(() -> {})
             )
         );
-        Throwable cause = ExceptionsHelper.unwrap(err, IOException.class);
+        assertThat(driverException, instanceOf(TaskCancelledException.class));
+        var sinkException = expectThrows(Exception.class, remoteSinkFuture::actionGet);
+        Throwable cause = ExceptionsHelper.unwrap(sinkException, IOException.class);
         assertNotNull(cause);
         assertThat(cause.getMessage(), equalTo("page is too large"));
         PlainActionFuture<Void> sinkCompletionFuture = new PlainActionFuture<>();
@@ -649,6 +657,28 @@ public void sendResponse(TransportResponse transportResponse) {
         }
     }
 
+    public void testNoCyclicException() throws Exception {
+        PlainActionFuture<Void> future = new PlainActionFuture<>();
+        try (EsqlRefCountingListener refs = new EsqlRefCountingListener(future)) {
+            var exchangeSourceHandler = new ExchangeSourceHandler(between(10, 100), threadPool.generic(), refs.acquire());
+            int numSinks = between(5, 10);
+            for (int i = 0; i < numSinks; i++) {
+                RemoteSink remoteSink = (allSourcesFinished, listener) -> threadPool.schedule(
+                    () -> listener.onFailure(new IOException("simulated")),
+                    TimeValue.timeValueMillis(1),
+                    threadPool.generic()
+                );
+                exchangeSourceHandler.addRemoteSink(remoteSink, randomBoolean(), () -> {}, between(1, 3), refs.acquire());
+            }
+        }
+        Exception err = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS));
+        assertThat(ExceptionsHelper.unwrap(err, IOException.class).getMessage(), equalTo("simulated"));
+        try (BytesStreamOutput output = new BytesStreamOutput()) {
+            // ensure no cyclic exception
+            ElasticsearchException.writeException(err, output);
+        }
+    }
+
     private MockTransportService newTransportService() {
         List namedWriteables = new ArrayList<>(ClusterModule.getNamedWriteables());
         namedWriteables.addAll(BlockWritables.getNamedWriteables());
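Note (illustrative, not part of the patch): the sketch below shows the caller-side contract after this change, mirroring testNoCyclicException above. The completion listener passed to the ExchangeSourceHandler constructor now only signals completion, or a TaskCancelledException on abort, while each remote sink's actual failure is delivered through the listener handed to addRemoteSink. The helper name wireExchange, its parameters, and the literal arguments are placeholder assumptions; the types are those used in the files of this diff.

    // Illustrative sketch only: funnel per-sink failures and handler completion into one listener,
    // the same pattern testNoCyclicException uses.
    static void wireExchange(Executor fetchExecutor, RemoteSink remoteSink, ActionListener<Void> overallListener) {
        try (EsqlRefCountingListener refs = new EsqlRefCountingListener(overallListener)) {
            // The completion listener acquired here only learns about completion or an abort
            // (TaskCancelledException from checkFailure()); it no longer carries the sink's own exception.
            ExchangeSourceHandler sourceHandler = new ExchangeSourceHandler(100, fetchExecutor, refs.acquire());
            // failFast=true: a sink failure sets the handler's `aborted` flag and stops fetching, but the
            // exception itself is reported through this per-sink listener, which callers must handle.
            sourceHandler.addRemoteSink(remoteSink, true, () -> {}, 1, refs.acquire());
        }
    }

Because the sink failure is reported only once, through the per-sink listener, the aggregated listener no longer receives the same exception from two paths, which is what previously produced the cyclic (self-referencing) exception that testNoCyclicException guards against.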