diff --git a/server/src/main/java/org/elasticsearch/common/cache/Cache.java b/server/src/main/java/org/elasticsearch/common/cache/Cache.java index 779c9f2882286..422d674b02cc1 100644 --- a/server/src/main/java/org/elasticsearch/common/cache/Cache.java +++ b/server/src/main/java/org/elasticsearch/common/cache/Cache.java @@ -10,6 +10,9 @@ package org.elasticsearch.common.cache; import org.elasticsearch.core.Tuple; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.tasks.TaskCancelledException; import java.lang.reflect.Array; import java.util.HashMap; @@ -17,14 +20,16 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.BiConsumer; -import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.ToLongBiFunction; @@ -58,6 +63,7 @@ * @param The type of the values */ public class Cache { + private static final Logger logger = LogManager.getLogger(Cache.class); private final LongAdder hits = new LongAdder(); @@ -187,15 +193,17 @@ private final class CacheSegment { Map>> map; /** - * get an entry from the segment; expired entries will be returned as null but not removed from the cache until the LRU list is - * pruned or a manual {@link Cache#refresh()} is performed however a caller can take action using the provided callback + * get an entry from the segment with cancellation support; expired entries will be returned as null but not removed from the + * cache until the LRU 
list is pruned or a manual {@link Cache#refresh()} is performed however a caller can take action using + * the provided callback * - * @param key the key of the entry to get from the cache - * @param now the access time of this entry - * @param eagerEvict whether entries should be eagerly evicted on expiration + * @param key the key of the entry to get from the cache + * @param now the access time of this entry + * @param eagerEvict whether entries should be eagerly evicted on expiration + * @param cancellationRegistrar if non-null, accepts a Runnable to be called on cancellation * @return the entry if there was one, otherwise null */ - Entry get(K key, long now, boolean eagerEvict) { + Entry get(K key, long now, boolean eagerEvict, Consumer cancellationRegistrar) { CompletableFuture> future; readLock.lock(); try { @@ -206,12 +214,12 @@ Entry get(K key, long now, boolean eagerEvict) { if (future != null) { Entry entry; try { - entry = future.get(); + entry = blockOnFuture(future, cancellationRegistrar); } catch (ExecutionException e) { - assert future.isCompletedExceptionally(); misses.increment(); return null; } catch (InterruptedException e) { + Thread.currentThread().interrupt(); throw new IllegalStateException(e); } if (isExpired(entry, now)) { @@ -335,6 +343,57 @@ void remove(K key, V value, boolean notify) { } + /** + * Block on a CompletableFuture with cancellation support. 
+ * + * @param future the future to wait on + * @param cancellationRegistrar if non-null, accepts a Runnable to be called on cancellation + * @return the result of the future + * + * @throws ExecutionException if the future completed exceptionally + * @throws InterruptedException if the thread was interrupted + */ + private static T blockOnFuture(CompletableFuture future, Consumer cancellationRegistrar) throws ExecutionException, + InterruptedException { + if (future.isDone()) { + return future.get(); + } + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(false); + + future.whenComplete((value, throwable) -> { + if (throwable != null) { + error.set(throwable); + } else { + result.set(value); + } + latch.countDown(); + }); + + if (cancellationRegistrar != null) { + cancellationRegistrar.accept(() -> { + cancelled.set(true); + latch.countDown(); + }); + } + + latch.await(); + + if (future.isDone()) { + return future.get(); + } + if (cancelled.get()) { + throw new TaskCancelledException("Cache wait cancelled"); + } + if (error.get() != null) { + throw new ExecutionException(error.get()); + } + return result.get(); + } + public static final int NUMBER_OF_SEGMENTS = 256; @SuppressWarnings("unchecked") private final CacheSegment[] segments = (CacheSegment[]) Array.newInstance(CacheSegment.class, NUMBER_OF_SEGMENTS); @@ -362,8 +421,12 @@ public V get(K key) { } private V get(K key, long now, boolean eagerEvict) { + return get(key, now, eagerEvict, null); + } + + private V get(K key, long now, boolean eagerEvict, Consumer cancellationRegistrar) { CacheSegment segment = getCacheSegment(key); - Entry entry = segment.get(key, now, eagerEvict); + Entry entry = segment.get(key, now, eagerEvict, cancellationRegistrar); if (entry == null) { return null; } else { @@ -387,9 +450,27 @@ private V get(K key, long now, boolean eagerEvict) { 
* @throws ExecutionException thrown if loader throws an exception or returns a null value */ public V computeIfAbsent(K key, CacheLoader loader) throws ExecutionException { + return computeIfAbsent(key, loader, null); + } + + /** + * This variant supports cancellation - if a cancellation callback is provided and triggered while waiting for + * another thread to compute the value, a TaskCancelledException will be thrown. + *

 + * Waiting can happen at multiple points:
 + * <ul>
 + * <li>during the initial eager lookup when another thread already has an in-flight computation for the key, and</li>
 + * <li>after this thread loses the put-if-absent race and must wait on the winner's computation.</li>
 + * </ul>
+ * + * @param cancellationRegistrar if non-null, accepts a Runnable to be called when this wait should be cancelled + * @throws TaskCancelledException thrown if the operation is cancelled at any cache wait point + */ + public V computeIfAbsent(K key, CacheLoader loader, Consumer cancellationRegistrar) throws ExecutionException { long now = now(); // we have to eagerly evict expired entries or our putIfAbsent call below will fail - V value = get(key, now, true); + // this can block on an existing in-flight computation and may throw TaskCancelledException + V value = get(key, now, true, cancellationRegistrar); if (value == null) { // we need to synchronize loading of a value for a given key; however, holding the segment lock while // invoking load can lead to deadlock against another thread due to dependent key loading; therefore, we @@ -410,61 +491,65 @@ public V computeIfAbsent(K key, CacheLoader loader) throws ExecutionExcept segment.writeLock.unlock(); } - BiFunction, Throwable, ? extends V> handler = (ok, ex) -> { - if (ok != null) { - promote(ok, now); - return ok.value; - } else { - segment.writeLock.lock(); - try { - CompletableFuture> sanity = segment.map == null ? 
null : segment.map.get(key); - if (sanity != null && sanity.isCompletedExceptionally()) { - segment.map.remove(key); - if (segment.map.isEmpty()) { - segment.map = null; - } - } - } finally { - segment.writeLock.unlock(); - } - return null; - } - }; - - CompletableFuture completableValue; - if (future == null) { + final boolean isComputing = (future == null); + if (isComputing) { future = completableFuture; - completableValue = future.handle(handler); V loaded; try { loaded = loader.load(key); } catch (Exception e) { future.completeExceptionally(e); + cleanupFailedFuture(segment, key, future); throw new ExecutionException(e); } if (loaded == null) { NullPointerException npe = new NullPointerException("loader returned a null value"); future.completeExceptionally(npe); + cleanupFailedFuture(segment, key, future); throw new ExecutionException(npe); - } else { - future.complete(new Entry<>(key, loaded, now)); } + Entry entry = new Entry<>(key, loaded, now); + future.complete(entry); + promote(entry, now); + return loaded; } else { - completableValue = future.handle(handler); + try { + Entry entry = blockOnFuture(future, cancellationRegistrar); + if (entry == null) { + future.get(); + throw new IllegalStateException("future completed exceptionally but no exception thrown"); + } + promote(entry, now); + return entry.value; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ExecutionException(e); + } } + } + return value; + } - try { - value = completableValue.get(); - // check to ensure the future hasn't been completed with an exception - if (future.isCompletedExceptionally()) { - future.get(); // call get to force the exception to be thrown for other concurrent callers - throw new IllegalStateException("the future was completed exceptionally but no exception was thrown"); + /** + * Clean up a failed future from the segment map. 
+ */ + private void cleanupFailedFuture(CacheSegment segment, K key, CompletableFuture> future) { + segment.writeLock.lock(); + try { + if (segment.map != null) { + CompletableFuture> current = segment.map.get(key); + if (current == future) { + segment.map.remove(key); + if (segment.map.isEmpty()) { + segment.map = null; + } + } else if (current != null) { + logger.debug("Skipped cleanup for key [{}] because the future was replaced", key); } - } catch (InterruptedException e) { - throw new IllegalStateException(e); } + } finally { + segment.writeLock.unlock(); } - return value; } /** diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesRequestCache.java b/server/src/main/java/org/elasticsearch/indices/IndicesRequestCache.java index 10c2dbf01da0f..faca181ee2e26 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesRequestCache.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesRequestCache.java @@ -34,6 +34,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentMap; +import java.util.function.Consumer; /** * The indices request cache allows to cache a shard level request stage responses, helping with improving @@ -97,18 +98,46 @@ void clear(CacheEntity entity) { cleanCache(); } + /** + * Get or compute a cache entry without cancellation support (backward compatible). + */ BytesReference getOrCompute( CacheEntity cacheEntity, CheckedSupplier loader, MappingLookup.CacheKey mappingCacheKey, DirectoryReader reader, BytesReference cacheKey + ) throws Exception { + return getOrCompute(cacheEntity, loader, mappingCacheKey, reader, cacheKey, null); + } + + /** + * Get or compute a cache entry with cancellation support. 
+ * + * @param cacheEntity the cache entity + * @param loader the loader to compute the value if not cached + * @param mappingCacheKey the mapping cache key + * @param reader the directory reader + * @param cacheKey the cache key + * @param cancellationRegistrar if non-null, accepts a Runnable to be called when the operation should be cancelled. + * This allows waiting threads to be notified instantly when their task is cancelled, + * rather than polling. + * @return the cached or computed value + * @throws Exception if the computation fails or the operation is cancelled + */ + BytesReference getOrCompute( + CacheEntity cacheEntity, + CheckedSupplier loader, + MappingLookup.CacheKey mappingCacheKey, + DirectoryReader reader, + BytesReference cacheKey, + Consumer cancellationRegistrar ) throws Exception { final ESCacheHelper cacheHelper = ElasticsearchDirectoryReader.getESReaderCacheHelper(reader); assert cacheHelper != null; final Key key = new Key(cacheEntity, mappingCacheKey, cacheHelper.getKey(), cacheKey); Loader cacheLoader = new Loader(cacheEntity, loader); - BytesReference value = cache.computeIfAbsent(key, cacheLoader); + BytesReference value = cache.computeIfAbsent(key, cacheLoader, cancellationRegistrar); if (cacheLoader.isLoaded()) { key.entity.onMiss(); // see if its the first time we see this reader, and make sure to register a cleanup key diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 9db89a4d26a37..3a79c55bdac57 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -1698,13 +1698,27 @@ public static boolean canCache(ShardSearchRequest request, SearchContext context } /** - * Loads the cache result, computing it if needed by executing the query phase and otherwise deserializing the cached - * value into the {@link SearchContext#queryResult() 
context's query result}. The combination of load + compute allows - * to have a single load operation that will cause other requests with the same key to wait till its loaded an reuse - * the same cache. - */ + * Equivalent to {@link #loadIntoContext(ShardSearchRequest, SearchContext, Consumer)} with + * {@code cancellationRegistrar == null}. + */ public void loadIntoContext(ShardSearchRequest request, SearchContext context) throws Exception { - assert canCache(request, context); + loadIntoContext(request, context, null); + } + + /** + * Loads the cache result, computing it if needed by executing the query phase and otherwise deserializing the cached + * value into the {@link SearchContext#queryResult() context's query result}. The combination of load + compute allows + * to have a single load operation that will cause other requests with the same key to wait till its loaded an reuse + * the same cache. + * + * @param request the shard search request + * @param context the search context to populate + * @param cancellationRegistrar registers cancellation handling for the underlying work; may be {@code null} + * @throws Exception if loading, computing, or deserialization fails + */ + public void loadIntoContext(ShardSearchRequest request, SearchContext context, Consumer cancellationRegistrar) + throws Exception { + assert IndicesService.canCache(request, context); final DirectoryReader directoryReader = context.searcher().getDirectoryReader(); boolean[] loadedFromCache = new boolean[] { true }; @@ -1718,7 +1732,8 @@ public void loadIntoContext(ShardSearchRequest request, SearchContext context) t QueryPhase.execute(context); context.queryResult().writeToNoId(out); loadedFromCache[0] = false; - } + }, + cancellationRegistrar ); if (loadedFromCache[0]) { @@ -1761,6 +1776,7 @@ public long getTotalIndexingBufferBytes() { * @param reader a reader for this shard. Used to invalidate the cache when there are changes. 
* @param cacheKey key for the thing being cached within this shard * @param loader loads the data into the cache if needed + * @param cancellationRegistrar if non-null, accepts a Runnable to be called when the operation should be cancelled * @return the contents of the cache or the result of calling the loader */ private BytesReference cacheShardLevelResult( @@ -1768,7 +1784,8 @@ private BytesReference cacheShardLevelResult( MappingLookup.CacheKey mappingCacheKey, DirectoryReader reader, BytesReference cacheKey, - CheckedConsumer loader + CheckedConsumer loader, + Consumer cancellationRegistrar ) throws Exception { IndexShardCacheEntity cacheEntity = new IndexShardCacheEntity(shard); CheckedSupplier supplier = () -> { @@ -1787,7 +1804,7 @@ private BytesReference cacheShardLevelResult( return out.bytes(); } }; - return indicesRequestCache.getOrCompute(cacheEntity, supplier, mappingCacheKey, reader, cacheKey); + return indicesRequestCache.getOrCompute(cacheEntity, supplier, mappingCacheKey, reader, cacheKey, cancellationRegistrar); } static final class IndexShardCacheEntity extends AbstractIndexShardCacheEntity { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 2bf8ab3313719..31cd79cc6ce76 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -159,6 +159,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.LongSupplier; import java.util.function.Supplier; @@ -715,7 +716,12 @@ private void loadOrExecuteQueryPhase(final ShardSearchRequest request, final Sea final boolean canCache = IndicesService.canCache(request, context); context.getSearchExecutionContext().freezeContext(); if (canCache) 
{ - indicesService.loadIntoContext(request, context); + CancellableTask task = context.getTask(); + Consumer cancellationRegistrar = null; + if (task != null) { + cancellationRegistrar = cancellationCallback -> { task.addListener(cancellationCallback::run); }; + } + indicesService.loadIntoContext(request, context, cancellationRegistrar); } else { QueryPhase.execute(context); } diff --git a/server/src/test/java/org/elasticsearch/common/cache/CacheTests.java b/server/src/test/java/org/elasticsearch/common/cache/CacheTests.java index a9741b44a1e45..03e62627f82eb 100644 --- a/server/src/test/java/org/elasticsearch/common/cache/CacheTests.java +++ b/server/src/test/java/org/elasticsearch/common/cache/CacheTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.common.cache; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -32,10 +33,12 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; @@ -749,7 +752,7 @@ public void testCachePollution() throws InterruptedException { } catch (ExecutionException e) { assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf(Exception.class)); - assertEquals(e.getCause().getMessage(), "testCachePollution"); + assertEquals("testCachePollution", e.getCause().getMessage()); } } else if (second) { cache.invalidate(key); @@ -820,4 +823,425 @@ public void testRemoveUsingValuesIterator() { assertEquals(RemovalNotification.RemovalReason.INVALIDATED, removalNotifications.get(i).getRemovalReason()); } } + + public 
void testComputeIfAbsentWithoutCancellationRegistrar() throws ExecutionException { + final Cache cache = CacheBuilder.builder().build(); + + String result = cache.computeIfAbsent(1, k -> "value-" + k, null); + assertEquals("value-1", result); + assertEquals("value-1", cache.get(1)); + } + + public void testComputeIfAbsentCancellationOnAlreadyCompletedFuture() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + cache.put(1, "existing-value"); + + AtomicBoolean callbackCalled = new AtomicBoolean(false); + String result = cache.computeIfAbsent(1, k -> "new-value", callback -> { + callbackCalled.set(true); + // Even if we call the callback, it shouldn't matter since value exists + callback.run(); + }); + + assertEquals("existing-value", result); + assertFalse("Cancellation callback should not be registered on already-completed future", callbackCalled.get()); + } + + public void testComputeIfAbsentWithCancellationDuringInitialLookupWait() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + + CountDownLatch computeStarted = new CountDownLatch(1); + CountDownLatch cancelTriggered = new CountDownLatch(1); + CountDownLatch computeComplete = new CountDownLatch(1); + + // Thread 1: Start computing a value but block + AtomicReference threadException = new AtomicReference<>(); + Thread computingThread = new Thread(() -> { + try { + cache.computeIfAbsent(1, k -> { + computeStarted.countDown(); + safeAwait(cancelTriggered); + return "computed-value"; + }); + } catch (ExecutionException e) { + threadException.set(e); + } finally { + computeComplete.countDown(); + } + }); + computingThread.start(); + safeAwait(computeStarted); + + // Thread 2: Get the same key with cancellation + AtomicBoolean wasCancelled = new AtomicBoolean(false); + AtomicBoolean waitingLoaderInvoked = new AtomicBoolean(false); + Thread waitingThread = new Thread(() -> { + try { + AtomicBoolean cancelled = new AtomicBoolean(false); + cache.computeIfAbsent(1, k -> { + 
waitingLoaderInvoked.set(true); + return "should-not-be-called"; + }, cancellationCallback -> { + new Thread(() -> { + try { + Thread.sleep(50); + cancelled.set(true); + cancellationCallback.run(); + } catch (InterruptedException e) { + // ignore + } + }).start(); + }); + fail("Expected TaskCancelledException"); + } catch (TaskCancelledException e) { + wasCancelled.set(true); + assertThat(e.getMessage(), containsString("Cache wait cancelled")); + } catch (ExecutionException e) { + threadException.set(e); + } + }); + waitingThread.start(); + waitingThread.join(5000); + + assertFalse("Waiting thread should have completed", waitingThread.isAlive()); + assertTrue("Waiting thread should have been cancelled", wasCancelled.get()); + assertFalse("Waiting thread must not invoke loader when waiting on existing in-flight computation", waitingLoaderInvoked.get()); + + cancelTriggered.countDown(); + safeAwait(computeComplete); + computingThread.join(5000); + + assertEquals("computed-value", cache.get(1)); + assertNull("No exception should have been thrown by computing thread", threadException.get()); + } + + public void testComputeIfAbsentPropagatesLoaderExceptionToWaitingThreadWithCancellationRegistrar() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + + CountDownLatch computeStarted = new CountDownLatch(1); + CountDownLatch allowFailure = new CountDownLatch(1); + CountDownLatch cancellationRegistered = new CountDownLatch(1); + + AtomicReference computingThreadResult = new AtomicReference<>(); + Thread computingThread = new Thread(() -> { + try { + cache.computeIfAbsent(1, k -> { + computeStarted.countDown(); + safeAwait(allowFailure); + throw new IllegalStateException("failed to load"); + }); + computingThreadResult.set(new AssertionError("expected ExecutionException")); + } catch (ExecutionException e) { + computingThreadResult.set(e); + } + }); + computingThread.start(); + safeAwait(computeStarted); + + AtomicReference waitingThreadResult = new 
AtomicReference<>(); + Thread waitingThread = new Thread(() -> { + try { + cache.computeIfAbsent( + 1, + k -> { throw new IllegalStateException("failed to load"); }, + cancellationCallback -> cancellationRegistered.countDown() + ); + waitingThreadResult.set(new AssertionError("expected ExecutionException")); + } catch (ExecutionException | TaskCancelledException e) { + waitingThreadResult.set(e); + } + }); + waitingThread.start(); + + safeAwait(cancellationRegistered); + allowFailure.countDown(); + + computingThread.join(5000); + waitingThread.join(5000); + + assertFalse("Computing thread should have completed", computingThread.isAlive()); + assertFalse("Waiting thread should have completed", waitingThread.isAlive()); + assertThat(computingThreadResult.get(), instanceOf(ExecutionException.class)); + assertThat(computingThreadResult.get().getCause(), instanceOf(IllegalStateException.class)); + assertEquals("failed to load", computingThreadResult.get().getCause().getMessage()); + + assertThat(waitingThreadResult.get(), instanceOf(ExecutionException.class)); + assertThat(waitingThreadResult.get().getCause(), instanceOf(IllegalStateException.class)); + assertEquals("failed to load", waitingThreadResult.get().getCause().getMessage()); + } + + public void testConcurrentComputeIfAbsentWithCancellation() throws InterruptedException { + final Cache cache = CacheBuilder.builder().build(); + int numberOfThreads = randomIntBetween(4, 16); + int keysToCompute = randomIntBetween(10, 50); + + CopyOnWriteArrayList failures = new CopyOnWriteArrayList<>(); + AtomicLong cancellations = new AtomicLong(); + AtomicLong successes = new AtomicLong(); + + startInParallel(numberOfThreads, threadIndex -> { + Random random = new Random(random().nextLong()); + for (int j = 0; j < keysToCompute; j++) { + int key = random.nextInt(keysToCompute); + try { + AtomicBoolean shouldCancel = new AtomicBoolean(random.nextInt(10) == 0); // 10% chance + cache.computeIfAbsent(key, k -> { + if 
(random.nextBoolean()) { + try { + Thread.sleep(random.nextInt(5)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + return "value-" + k; + }, callback -> { + if (shouldCancel.get()) { + new Thread(() -> { + try { + Thread.sleep(random.nextInt(10)); + callback.run(); + } catch (InterruptedException e) { + // ignore + } + }).start(); + } + }); + successes.incrementAndGet(); + } catch (TaskCancelledException e) { + cancellations.incrementAndGet(); + } catch (ExecutionException e) { + failures.add(e); + } + } + }); + + assertThat("No unexpected failures", failures, is(empty())); + + // Verify cache + for (int key = 0; key < keysToCompute; key++) { + String value = cache.get(key); + if (value != null) { + assertEquals("value-" + key, value); + } + } + } + + public void testNoCancellationRegistrarDoesNotDeadlock() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + var computation = new SlowComputation(cache, 1, "computed-value"); + + // thread2: Waits with null cancellationRegistrar — no early-exit path + CountDownLatch waitingFinished = new CountDownLatch(1); + AtomicReference waitingResult = new AtomicReference<>(); + AtomicReference waitingError = new AtomicReference<>(); + Thread waitingThread = new Thread(() -> { + try { + waitingResult.set(cache.computeIfAbsent(1, k -> "should-not-run", null)); + } catch (ExecutionException | TaskCancelledException e) { + waitingError.set(e); + } finally { + waitingFinished.countDown(); + } + }); + waitingThread.start(); + + // thread1: finish computation + computation.complete(); + + // thread2: no deadlock + assertTrue("Thread must unblock once future completes (no deadlock)", waitingFinished.await(5, TimeUnit.SECONDS)); + assertNull("No exception expected", waitingError.get()); + assertEquals("computed-value", waitingResult.get()); + } + + public void testBlockOnFutureSynchronousCancellation() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + 
var computation = new SlowComputation(cache, 1, "computed-value"); + + // thread2: Waits with a registrar that fires the cancellation callback + AtomicReference waitingError = new AtomicReference<>(); + CountDownLatch waitingFinished = new CountDownLatch(1); + Thread waitingThread = new Thread(() -> { + try { + cache.computeIfAbsent(1, k -> "should-not-run", Runnable::run); + fail("Expected TaskCancelledException"); + } catch (TaskCancelledException | ExecutionException e) { + waitingError.set(e); + } finally { + waitingFinished.countDown(); + } + }); + waitingThread.start(); + + assertTrue("Waiting thread must exit on synchronous cancellation", waitingFinished.await(5, TimeUnit.SECONDS)); + assertThat(waitingError.get(), instanceOf(TaskCancelledException.class)); + assertThat(waitingError.get().getMessage(), containsString("Cache wait cancelled")); + + computation.completeAndJoin(); + } + + public void testBlockOnFutureNullRegistrarNoDeadlockOnLoaderFailure() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + + CountDownLatch computeStarted = new CountDownLatch(1); + CountDownLatch allowFailure = new CountDownLatch(1); + CountDownLatch waitingBlocked = new CountDownLatch(1); + + AtomicReference computingError = new AtomicReference<>(); + + // thread1: computing thread throws + Thread computingThread = new Thread(() -> { + try { + cache.computeIfAbsent(1, k -> { + computeStarted.countDown(); + safeAwait(allowFailure); + throw new IllegalStateException("loader-failure"); + }); + } catch (ExecutionException e) { + computingError.set(e); + } + }); + computingThread.start(); + safeAwait(computeStarted); + + // thread2: null registrar, blocks until future resolves (or retries on failure) + AtomicReference waitingError = new AtomicReference<>(); + AtomicReference waitingResult = new AtomicReference<>(); + CountDownLatch waitingFinished = new CountDownLatch(1); + Thread waitingThread = new Thread(() -> { + waitingBlocked.countDown(); + try { + 
waitingResult.set(cache.computeIfAbsent(1, k -> "retried-value", null)); + } catch (ExecutionException e) { + waitingError.set(e); + } finally { + waitingFinished.countDown(); + } + }); + waitingThread.start(); + allowFailure.countDown(); + + assertTrue("Waiting thread must unblock after loader failure (no deadlock)", waitingFinished.await(5, TimeUnit.SECONDS)); + computingThread.join(5000); + + assertThat(computingError.get(), instanceOf(ExecutionException.class)); + assertThat(computingError.get().getCause(), instanceOf(IllegalStateException.class)); + assertEquals("loader-failure", computingError.get().getCause().getMessage()); + + if (waitingError.get() != null) { + assertThat(waitingError.get(), instanceOf(ExecutionException.class)); + assertThat(waitingError.get().getCause(), instanceOf(IllegalStateException.class)); + } else { + assertNotNull("Waiting thread must have a result if no exception", waitingResult.get()); + } + } + + public void testBlockOnFutureInterruptedWhileWaiting() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + var computation = new SlowComputation(cache, 1, "computed-value"); + + // thread 2: waits with no cancellationRegistrar, will be interrupted externally + AtomicBoolean interruptFlagPreserved = new AtomicBoolean(false); + AtomicBoolean threwIllegalState = new AtomicBoolean(false); + CountDownLatch waitingStarted = new CountDownLatch(1); + CountDownLatch waitingFinished = new CountDownLatch(1); + Thread waitingThread = new Thread(() -> { + try { + waitingStarted.countDown(); + cache.computeIfAbsent(1, k -> "should-not-run", null); + } catch (IllegalStateException e) { + threwIllegalState.set(true); + interruptFlagPreserved.set(Thread.currentThread().isInterrupted()); + } catch (ExecutionException e) { + // unexpected + } finally { + waitingFinished.countDown(); + } + }); + waitingThread.start(); + safeAwait(waitingStarted); + waitingThread.interrupt(); + + assertTrue("Waiting thread must exit after 
interrupt", waitingFinished.await(5, TimeUnit.SECONDS)); + assertTrue("InterruptedException path must have been taken", threwIllegalState.get()); + assertTrue("Interrupt flag must be preserved", interruptFlagPreserved.get()); + + computation.completeAndJoin(); + } + + public void testBlockOnFutureLateRegistrarDoesNotOverrideResult() throws Exception { + final Cache cache = CacheBuilder.builder().build(); + var computation = new SlowComputation(cache, 1, "computed-value"); + + CountDownLatch waitingBlocked = new CountDownLatch(1); + AtomicReference capturedCancellationCallback = new AtomicReference<>(); + + // thread 2: waits and captures the cancellation callback without invoking it yet + AtomicReference waitingResult = new AtomicReference<>(); + AtomicReference waitingError = new AtomicReference<>(); + CountDownLatch waitingFinished = new CountDownLatch(1); + Thread waitingThread = new Thread(() -> { + try { + waitingResult.set(cache.computeIfAbsent(1, k -> "should-not-run", callback -> { + capturedCancellationCallback.set(callback); + waitingBlocked.countDown(); + })); + } catch (TaskCancelledException | ExecutionException e) { + waitingError.set(e); + } finally { + waitingFinished.countDown(); + } + }); + waitingThread.start(); + safeAwait(waitingBlocked); + computation.complete(); + + assertTrue("Waiting thread must unblock when future completes", waitingFinished.await(5, TimeUnit.SECONDS)); + assertNotNull(capturedCancellationCallback.get()); + + assertNull("No exception expected — future completed before cancellation fired", waitingError.get()); + assertEquals("computed-value", waitingResult.get()); + + computation.join(); + } + + /** + * A slow in-flight computation on a background thread. The loader blocks until + * {@link #complete()} is called, letting tests set up waiting threads first. + * Construction starts the thread and blocks the caller until the loader is entered. 
+ */ + private final class SlowComputation { + private final CountDownLatch started = new CountDownLatch(1); + private final CountDownLatch allowComplete = new CountDownLatch(1); + final Thread thread; + + SlowComputation(Cache cache, int key, String value) { + thread = new Thread(() -> { + try { + cache.computeIfAbsent(key, k -> { + started.countDown(); + safeAwait(allowComplete); + return value; + }); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + }); + thread.start(); + safeAwait(started); + } + + void complete() { + allowComplete.countDown(); + } + + void join() throws InterruptedException { + thread.join(5000); + } + + void completeAndJoin() throws InterruptedException { + complete(); + join(); + } + } } diff --git a/server/src/test/java/org/elasticsearch/indices/IndicesRequestCacheTests.java b/server/src/test/java/org/elasticsearch/indices/IndicesRequestCacheTests.java index 069e7424a6338..63e82ca47a997 100644 --- a/server/src/test/java/org/elasticsearch/indices/IndicesRequestCacheTests.java +++ b/server/src/test/java/org/elasticsearch/indices/IndicesRequestCacheTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; @@ -45,7 +46,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyList; @@ -567,6 +573,141 @@ public void testKeyEqualsAndHashCode() throws IOException { assertEquals(key1.hashCode(), key2.hashCode()); } + public 
void testComputingThreadDoesNotRegisterForCancellation() throws Exception { + ShardRequestCache requestCacheStats = new ShardRequestCache(); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY); + Directory dir = newDirectory(); + IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); + + writer.addDocument(newDoc(0, "foo")); + DirectoryReader reader = ElasticsearchDirectoryReader.wrap(DirectoryReader.open(writer), new ShardId("foo", "bar", 1)); + MappingLookup.CacheKey mappingKey = MappingLookup.EMPTY.cacheKey(); + TermQueryBuilder termQuery = new TermQueryBuilder("id", "0"); + BytesReference termBytes = XContentHelper.toXContent(termQuery, XContentType.JSON, false); + AtomicBoolean indexShard = new AtomicBoolean(true); + + TestEntity entity = new TestEntity(requestCacheStats, indexShard); + AtomicBoolean loaderExecuted = new AtomicBoolean(false); + AtomicReference capturedCallback = new AtomicReference<>(); + CheckedSupplier loader = () -> { + loaderExecuted.set(true); + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.writeString("computed_value"); + return out.bytes(); + } + }; + + BytesReference value = cache.getOrCompute(entity, loader, mappingKey, reader, termBytes, capturedCallback::set); + + assertEquals("computed_value", value.streamInput().readString()); + assertTrue("Loader should have been executed", loaderExecuted.get()); + assertEquals(1, cache.count()); + assertNull("Callback should NOT be registered for computing thread", capturedCallback.get()); + + IOUtils.close(reader, writer, dir, cache); + } + + public void testMultipleWaitingThreadsCanBeCancelledIndependently() throws Exception { + ShardRequestCache requestCacheStats = new ShardRequestCache(); + IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY); + Directory dir = newDirectory(); + IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); + + writer.addDocument(newDoc(0, "foo")); + DirectoryReader reader = 
ElasticsearchDirectoryReader.wrap(DirectoryReader.open(writer), new ShardId("foo", "bar", 1)); + MappingLookup.CacheKey mappingKey = MappingLookup.EMPTY.cacheKey(); + TermQueryBuilder termQuery = new TermQueryBuilder("id", "0"); + BytesReference termBytes = XContentHelper.toXContent(termQuery, XContentType.JSON, false); + AtomicBoolean indexShard = new AtomicBoolean(true); + + ExecutorService executor = Executors.newFixedThreadPool(3); + try { + // Computing thread + CountDownLatch loaderStarted = new CountDownLatch(1); + CountDownLatch allowLoaderToComplete = new CountDownLatch(1); + CheckedSupplier slowLoader = () -> { + loaderStarted.countDown(); + try { + allowLoaderToComplete.await(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.writeString("result"); + return out.bytes(); + } + }; + + executor.submit(() -> { + try { + TestEntity entity = new TestEntity(requestCacheStats, indexShard); + cache.getOrCompute(entity, slowLoader, mappingKey, reader, termBytes, callback -> {}); + } catch (Exception e) { + // ignore + } + }); + assertTrue("Thread should have started", loaderStarted.await(10, TimeUnit.SECONDS)); + + // Waiting thread - cancelled + AtomicBoolean threadCancelled = new AtomicBoolean(false); + AtomicReference cancelCallback = new AtomicReference<>(); + CountDownLatch waiter1Ready = new CountDownLatch(1); + executor.submit(() -> { + try { + TestEntity entity = new TestEntity(requestCacheStats, indexShard); + cache.getOrCompute(entity, () -> { + fail("Should not call loader"); + return null; + }, mappingKey, reader, termBytes, callback -> { + cancelCallback.set(callback); + waiter1Ready.countDown(); + }); + } catch (TaskCancelledException e) { + threadCancelled.set(true); + } catch (Exception e) { + // ignore + } + }); + + // Waiting thread - completed + AtomicBoolean threadCompleted = new AtomicBoolean(false); + CountDownLatch waiter2Ready 
= new CountDownLatch(1); + executor.submit(() -> { + try { + TestEntity entity = new TestEntity(requestCacheStats, indexShard); + BytesReference value = cache.getOrCompute(entity, () -> { + fail("Should not call loader"); + return null; + }, mappingKey, reader, termBytes, callback -> { waiter2Ready.countDown(); }); + if (value != null) { + threadCompleted.set(true); + } + } catch (Exception e) { + // ignore + } + }); + + assertTrue("Thread should have started", waiter1Ready.await(10, TimeUnit.SECONDS)); + assertTrue("Thread should have started", waiter2Ready.await(10, TimeUnit.SECONDS)); + + Runnable callback = cancelCallback.get(); + assertNotNull(callback); + callback.run(); + + allowLoaderToComplete.countDown(); + + assertBusy(() -> assertTrue("Waiter 1 should have been cancelled", threadCancelled.get())); + assertBusy(() -> assertTrue("Waiter 2 should have completed successfully", threadCompleted.get())); + } finally { + executor.shutdown(); + boolean done = executor.awaitTermination(10, TimeUnit.SECONDS); + if (done == false) { + executor.shutdownNow(); + } + IOUtils.close(reader, writer, dir, cache); + } + } + private static class TestBytesReference extends AbstractBytesReference { int dummyValue; diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java index cc31052327a86..1595ba5a34a6b 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java @@ -150,8 +150,10 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; 
import java.util.function.Function; import java.util.function.IntConsumer; @@ -2916,6 +2918,66 @@ public void testSeqNoAndPrimaryTermReturnsSentinelsWhenSequenceNumbersDisabled() }); } + public void testCancelledTaskFailsFastWithCaching() throws Exception { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + SearchRequest searchRequest = new SearchRequest("index").allowPartialSearchResults(true); + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); + searchRequest.requestCache(true); + + long nowInMillis = System.currentTimeMillis(); + OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); + ShardSearchRequest request = new ShardSearchRequest( + originalIndices, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + nowInMillis, + null + ); + + // Create a task and cancel it before execution + SearchShardTask task = new SearchShardTask(1L, "", "", "", null, emptyMap()); + TaskCancelHelper.cancel(task, "pre-cancelled for test"); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference caughtException = new AtomicReference<>(); + AtomicBoolean succeeded = new AtomicBoolean(false); + + service.executeQueryPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult result) { + try { + service.freeReaderContext(result.getContextId()); + succeeded.set(true); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + caughtException.set(e); + latch.countDown(); + } + }); + + 
assertTrue("Should complete", latch.await(10, TimeUnit.SECONDS)); + assertFalse("Should not succeed", succeeded.get()); + assertNotNull("Should have exception", caughtException.get()); + assertThat(caughtException.get(), instanceOf(TaskCancelledException.class)); + assertThat(caughtException.get().getMessage(), containsString("pre-cancelled for test")); + } + private static ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) { return new ReaderContext( new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()),