Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

package org.elasticsearch.compute.operator.exchange;

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.compute.EsqlRefCountingListener;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.FailureCollector;
import org.elasticsearch.compute.operator.IsBlockedResult;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.tasks.TaskCancelledException;

import java.util.List;
import java.util.Map;
Expand All @@ -38,10 +38,9 @@ public final class ExchangeSourceHandler {

private final PendingInstances outstandingSinks;
private final PendingInstances outstandingSources;
// Collect failures that occur while fetching pages from the remote sink with `failFast=true`.
// The exchange source will stop fetching and abort as soon as any failure is added to this failure collector.
// The final failure collected will be notified to callers via the {@code completionListener}.
private final FailureCollector failure = new FailureCollector();
// Tracks whether this exchange source should abort. The actual failure does not need to be tracked here because
// it is reported to callers via the listener passed to #addRemoteSink(RemoteSink, boolean, Runnable, int, ActionListener).
private volatile boolean aborted = false;

private final AtomicInteger nextSinkId = new AtomicInteger();
private final Map<Integer, RemoteSink> remoteSinks = ConcurrentCollections.newConcurrentMap();
Expand All @@ -52,7 +51,7 @@ public final class ExchangeSourceHandler {
* @param maxBufferSize the maximum size of the exchange buffer. A larger buffer reduces ``pauses`` but uses more memory,
* which could otherwise be allocated for other purposes.
* @param fetchExecutor the executor used to fetch pages.
* @param completionListener a listener that will be notified when the exchange source handler fails or completes
* @param completionListener a listener that will be notified when the exchange source handler completes
*/
public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionListener<Void> completionListener) {
this.buffer = new ExchangeBuffer(maxBufferSize);
Expand All @@ -63,14 +62,7 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi
this.outstandingSources = new PendingInstances(() -> finishEarly(true, ActionListener.running(closingSinks::finishInstance)));
buffer.addCompletionListener(ActionListener.running(() -> {
final ActionListener<Void> listener = ActionListener.assertAtLeastOnce(completionListener);
try (RefCountingRunnable refs = new RefCountingRunnable(() -> {
final Exception e = failure.getFailure();
if (e != null) {
listener.onFailure(e);
} else {
listener.onResponse(null);
}
})) {
try (RefCountingRunnable refs = new RefCountingRunnable(ActionRunnable.run(listener, this::checkFailure))) {
closingSinks.completion.addListener(refs.acquireListener());
for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) {
// Create an outstanding instance and then finish to complete the completionListener
Expand All @@ -83,20 +75,19 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi
}));
}

/**
 * Throws a {@link TaskCancelledException} if this exchange source has been marked as aborted;
 * otherwise returns normally.
 */
private void checkFailure() {
    final boolean shouldAbort = aborted;
    if (shouldAbort == false) {
        return;
    }
    throw new TaskCancelledException("remote sinks failed");
}

private class ExchangeSourceImpl implements ExchangeSource {
private boolean finished;

ExchangeSourceImpl() {
outstandingSources.trackNewInstance();
}

/**
 * Rethrows the failure currently held by the failure collector, converted to a runtime exception.
 * Returns normally when no failure has been collected.
 */
private void checkFailure() {
    final Exception collected = failure.getFailure();
    if (collected == null) {
        return;
    }
    throw ExceptionsHelper.convertToRuntime(collected);
}

@Override
public Page pollPage() {
checkFailure();
Expand Down Expand Up @@ -201,7 +192,7 @@ void fetchPage() {
while (loopControl.isRunning()) {
loopControl.exiting();
// finish other sinks if one of them failed or source no longer need pages.
boolean toFinishSinks = buffer.noMoreInputs() || failure.hasFailure();
boolean toFinishSinks = buffer.noMoreInputs() || aborted;
remoteSink.fetchPageAsync(toFinishSinks, ActionListener.wrap(resp -> {
Page page = resp.takePage();
if (page != null) {
Expand Down Expand Up @@ -231,7 +222,7 @@ void fetchPage() {

void onSinkFailed(Exception e) {
if (failFast) {
failure.unwrapAndCollect(e);
aborted = true;
}
buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
if (finished == false) {
Expand Down Expand Up @@ -260,12 +251,12 @@ void onSinkComplete() {
* - If {@code false}, failures from this remote sink will not cause the exchange source to abort.
* Callers must handle these failures notified via {@code listener}.
* - If {@code true}, failures from this remote sink will cause the exchange source to abort.
* Callers can safely ignore failures notified via this listener, as they are collected and
* reported by the exchange source.
*
* @param onPageFetched a callback that will be called when a page is fetched from the remote sink
* @param instances the number of concurrent ``clients`` that this handler should use to fetch pages.
* More clients reduce latency, but add overhead.
* @param listener a listener that will be notified when the sink fails or completes
* @param listener a listener that will be notified when the sink fails or completes. Callers must handle failures notified via
* this listener.
* @see ExchangeSinkHandler#fetchPageAsync(boolean, ActionListener)
*/
public void addRemoteSink(
Expand All @@ -284,7 +275,7 @@ public void addRemoteSink(
@Override
public void onFailure(Exception e) {
if (failFast) {
failure.unwrapAndCollect(e);
aborted = true;
}
buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading
remoteSink.close(ActionListener.running(() -> sinkListener.onFailure(e)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.compute.operator.exchange;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
Expand All @@ -16,13 +17,15 @@
import org.elasticsearch.cluster.node.VersionInformation;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.EsqlRefCountingListener;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockWritables;
import org.elasticsearch.compute.data.IntBlock;
Expand All @@ -37,6 +40,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancellationService;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.test.transport.StubbableTransport;
Expand Down Expand Up @@ -69,6 +73,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;

public class ExchangeServiceTests extends ESTestCase {

Expand Down Expand Up @@ -623,14 +628,15 @@ public void sendResponse(TransportResponse transportResponse) {
);
ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomIntBetween(1, 128));
Transport.Connection connection = node0.getConnection(node1.getLocalNode());
PlainActionFuture<Void> remoteSinkFuture = new PlainActionFuture<>();
sourceHandler.addRemoteSink(
exchange0.newRemoteSink(task, exchangeId, node0, connection),
true,
() -> {},
randomIntBetween(1, 5),
ActionListener.noop()
remoteSinkFuture
);
Exception err = expectThrows(
Exception driverException = expectThrows(
Exception.class,
() -> runConcurrentTest(
maxSeqNo,
Expand All @@ -639,7 +645,9 @@ public void sendResponse(TransportResponse transportResponse) {
() -> sinkHandler.createExchangeSink(() -> {})
)
);
Throwable cause = ExceptionsHelper.unwrap(err, IOException.class);
assertThat(driverException, instanceOf(TaskCancelledException.class));
var sinkException = expectThrows(Exception.class, remoteSinkFuture::actionGet);
Throwable cause = ExceptionsHelper.unwrap(sinkException, IOException.class);
assertNotNull(cause);
assertThat(cause.getMessage(), equalTo("page is too large"));
PlainActionFuture<Void> sinkCompletionFuture = new PlainActionFuture<>();
Expand All @@ -649,6 +657,28 @@ public void sendResponse(TransportResponse transportResponse) {
}
}

/**
 * Verifies that the failure reported when every remote sink fails can be written to a stream.
 * <p>
 * A single ref-counting listener receives both the completion of the exchange source handler and the
 * completion of each remote sink. Each sink fails asynchronously with an {@code IOException("simulated")},
 * and the fail-fast flag is chosen at random per sink. The resulting exception must expose the simulated
 * {@link IOException} as a cause and must be serializable via {@link ElasticsearchException#writeException}.
 */
public void testNoCyclicException() throws Exception {
    PlainActionFuture<Void> future = new PlainActionFuture<>();
    try (EsqlRefCountingListener refs = new EsqlRefCountingListener(future)) {
        var exchangeSourceHandler = new ExchangeSourceHandler(between(10, 100), threadPool.generic(), refs.acquire());
        int numSinks = between(5, 10);
        for (int i = 0; i < numSinks; i++) {
            // Every fetch attempt fails shortly after being requested, on the generic thread pool.
            RemoteSink remoteSink = (allSourcesFinished, listener) -> threadPool.schedule(
                () -> listener.onFailure(new IOException("simulated")),
                TimeValue.timeValueMillis(1),
                threadPool.generic()
            );
            exchangeSourceHandler.addRemoteSink(remoteSink, randomBoolean(), () -> {}, between(1, 3), refs.acquire());
        }
    }
    Exception err = expectThrows(Exception.class, () -> future.actionGet(10, TimeUnit.SECONDS));
    // Check that the cause exists before dereferencing it, so a missing IOException is reported as an
    // assertion failure rather than a NullPointerException.
    Throwable cause = ExceptionsHelper.unwrap(err, IOException.class);
    assertNotNull(cause);
    assertThat(cause.getMessage(), equalTo("simulated"));
    try (BytesStreamOutput output = new BytesStreamOutput()) {
        // ensure no cyclic exception: serialization is expected to fail if the exception chain contains a cycle
        ElasticsearchException.writeException(err, output);
    }
}

private MockTransportService newTransportService() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(ClusterModule.getNamedWriteables());
namedWriteables.addAll(BlockWritables.getNamedWriteables());
Expand Down