diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverScheduler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverScheduler.java index 05fe38007a929..9d82f73f3105f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverScheduler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverScheduler.java @@ -7,7 +7,9 @@ package org.elasticsearch.compute.operator; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import java.util.List; import java.util.concurrent.Executor; @@ -21,7 +23,7 @@ */ final class DriverScheduler { private final AtomicReference delayedTask = new AtomicReference<>(); - private final AtomicReference scheduledTask = new AtomicReference<>(); + private final AtomicReference scheduledTask = new AtomicReference<>(); private final AtomicBoolean completing = new AtomicBoolean(); void addOrRunDelayedTask(Runnable task) { @@ -35,22 +37,32 @@ void addOrRunDelayedTask(Runnable task) { } } - void scheduleOrRunTask(Executor executor, Runnable task) { - final Runnable existing = scheduledTask.getAndSet(task); + void scheduleOrRunTask(Executor executor, AbstractRunnable task) { + final AbstractRunnable existing = scheduledTask.getAndSet(task); assert existing == null : existing; final Executor executorToUse = completing.get() ? EsExecutors.DIRECT_EXECUTOR_SERVICE : executor; - executorToUse.execute(() -> { - final Runnable next = scheduledTask.getAndSet(null); - if (next != null) { - assert next == task; - next.run(); + executorToUse.execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + assert e instanceof EsRejectedExecutionException : new AssertionError(e); + if (scheduledTask.getAndUpdate(t -> t == task ? null : t) == task) { + task.onFailure(e); + } + } + + @Override + protected void doRun() { + AbstractRunnable toRun = scheduledTask.getAndSet(null); + if (toRun == task) { + task.run(); + } } }); } void runPendingTasks() { completing.set(true); - for (var taskHolder : List.of(delayedTask, scheduledTask)) { + for (var taskHolder : List.of(scheduledTask, delayedTask)) { final Runnable task = taskHolder.getAndSet(null); if (task != null) { task.run(); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverSchedulerTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverSchedulerTests.java new file mode 100644 index 0000000000000..ec6bf38e557a9 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverSchedulerTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.equalTo; + +public class DriverSchedulerTests extends ESTestCase { + + public void testClearPendingTaskOnRejection() { + DriverScheduler scheduler = new DriverScheduler(); + AtomicInteger counter = new AtomicInteger(); + var threadPool = new TestThreadPool( + "test", + new FixedExecutorBuilder(Settings.EMPTY, "test", 1, 2, "test", EsExecutors.TaskTrackingConfig.DEFAULT) + ); + CountDownLatch latch = new CountDownLatch(1); + Executor executor = threadPool.executor("test"); + try { + for (int i = 0; i < 10; i++) { + try { + executor.execute(() -> safeAwait(latch)); + } catch (EsRejectedExecutionException e) { + break; + } + } + scheduler.scheduleOrRunTask(executor, new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + counter.incrementAndGet(); + } + + @Override + protected void doRun() { + counter.incrementAndGet(); + } + }); + scheduler.runPendingTasks(); + assertThat(counter.get(), equalTo(1)); + } finally { + latch.countDown(); + terminate(threadPool); + } + } +}