Skip to content

Commit f5cb6ea

Browse files
authored
Make MutableSearchResponse ref-counted to prevent use-after-close in async search (#134359)
1 parent 54c35e0 commit f5cb6ea

File tree

7 files changed

+677
-43
lines changed

7 files changed

+677
-43
lines changed

docs/changelog/134359.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134359
2+
summary: Make `MutableSearchResponse` ref-counted to prevent use-after-close in async
3+
search
4+
area: Search
5+
type: bug
6+
issues: []
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.search;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.ExceptionsHelper;
12+
import org.elasticsearch.action.index.IndexRequestBuilder;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.search.aggregations.AggregationBuilders;
17+
import org.elasticsearch.search.builder.SearchSourceBuilder;
18+
import org.elasticsearch.test.ESIntegTestCase.SuiteScopeTestCase;
19+
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
20+
21+
import java.util.ArrayList;
22+
import java.util.HashSet;
23+
import java.util.List;
24+
import java.util.Queue;
25+
import java.util.Set;
26+
import java.util.concurrent.CompletableFuture;
27+
import java.util.concurrent.ConcurrentLinkedQueue;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.ExecutionException;
30+
import java.util.concurrent.ExecutorService;
31+
import java.util.concurrent.Executors;
32+
import java.util.concurrent.Future;
33+
import java.util.concurrent.TimeUnit;
34+
import java.util.concurrent.TimeoutException;
35+
import java.util.concurrent.atomic.AtomicBoolean;
36+
import java.util.concurrent.atomic.AtomicReference;
37+
import java.util.concurrent.atomic.LongAdder;
38+
import java.util.stream.IntStream;
39+
40+
@SuiteScopeTestCase
41+
public class AsyncSearchConcurrentStatusIT extends AsyncSearchIntegTestCase {
42+
private static String indexName;
43+
private static int numShards;
44+
45+
private static int numKeywords;
46+
47+
@Override
48+
public void setupSuiteScopeCluster() {
49+
indexName = "test-async";
50+
numShards = randomIntBetween(1, 20);
51+
int numDocs = randomIntBetween(100, 1000);
52+
createIndex(indexName, Settings.builder().put("index.number_of_shards", numShards).build());
53+
numKeywords = randomIntBetween(50, 100);
54+
Set<String> keywordSet = new HashSet<>();
55+
for (int i = 0; i < numKeywords; i++) {
56+
keywordSet.add(randomAlphaOfLengthBetween(10, 20));
57+
}
58+
numKeywords = keywordSet.size();
59+
String[] keywords = keywordSet.toArray(String[]::new);
60+
List<IndexRequestBuilder> reqs = new ArrayList<>();
61+
for (int i = 0; i < numDocs; i++) {
62+
float metric = randomFloat();
63+
String keyword = keywords[randomIntBetween(0, numKeywords - 1)];
64+
reqs.add(prepareIndex(indexName).setSource("terms", keyword, "metric", metric));
65+
}
66+
indexRandom(true, true, reqs);
67+
}
68+
69+
/**
70+
* This test spins up a set of poller threads that repeatedly call
71+
* {@code _async_search/{id}}. Each poller starts immediately, and once enough
72+
* requests have been issued they signal a latch to indicate the group is "warmed up".
73+
* The test waits on this latch to deterministically ensure pollers are active.
74+
* In parallel, a consumer thread drives the async search to completion using the
75+
* blocking iterator. This coordinated overlap exercises the window where the task
76+
* is closing and some status calls may return {@code 410 GONE}.
77+
*/
78+
public void testConcurrentStatusFetchWhileTaskCloses() throws Exception {
79+
final TimeValue timeout = TimeValue.timeValueSeconds(3);
80+
final String aggName = "terms";
81+
final SearchSourceBuilder source = new SearchSourceBuilder().aggregation(
82+
AggregationBuilders.terms(aggName).field("terms.keyword").size(Math.max(1, numKeywords))
83+
);
84+
85+
final int progressStep = (numShards > 2) ? randomIntBetween(2, numShards) : 2;
86+
try (SearchResponseIterator it = assertBlockingIterator(indexName, numShards, source, 0, progressStep)) {
87+
String id = getAsyncId(it);
88+
89+
PollStats stats = new PollStats();
90+
91+
// Pick a random number of status-poller threads, at least 1, up to (4×numShards)
92+
int pollerThreads = randomIntBetween(1, 4 * numShards);
93+
94+
// Wait for pollers to be active
95+
CountDownLatch warmed = new CountDownLatch(1);
96+
97+
// Executor and coordination for pollers
98+
ExecutorService pollerExec = Executors.newFixedThreadPool(pollerThreads);
99+
AtomicBoolean running = new AtomicBoolean(true);
100+
Queue<Throwable> failures = new ConcurrentLinkedQueue<>();
101+
102+
CompletableFuture<Void> pollers = createPollers(id, pollerThreads, stats, warmed, pollerExec, running, failures);
103+
104+
// Wait until pollers are issuing requests (warming period)
105+
assertTrue("pollers did not warm up in time", warmed.await(timeout.millis(), TimeUnit.MILLISECONDS));
106+
107+
// Start consumer on a separate thread and capture errors
108+
var consumerExec = Executors.newSingleThreadExecutor();
109+
AtomicReference<Throwable> consumerError = new AtomicReference<>();
110+
Future<?> consumer = consumerExec.submit(() -> {
111+
try {
112+
consumeAllResponses(it, aggName);
113+
} catch (Throwable t) {
114+
consumerError.set(t);
115+
}
116+
});
117+
118+
// Join consumer & surface errors
119+
try {
120+
consumer.get(timeout.millis(), TimeUnit.MILLISECONDS);
121+
122+
if (consumerError.get() != null) {
123+
fail("consumeAllResponses failed: " + consumerError.get());
124+
}
125+
} catch (TimeoutException e) {
126+
consumer.cancel(true);
127+
fail(e, "Consumer thread did not finish within timeout");
128+
} catch (Exception ignored) {
129+
// ignored
130+
} finally {
131+
// Stop pollers
132+
running.set(false);
133+
try {
134+
pollers.get(timeout.millis(), TimeUnit.MILLISECONDS);
135+
} catch (TimeoutException te) {
136+
// The finally block will shut down the pollers forcibly
137+
} catch (ExecutionException ee) {
138+
failures.add(ExceptionsHelper.unwrapCause(ee.getCause()));
139+
} catch (InterruptedException ie) {
140+
Thread.currentThread().interrupt();
141+
} finally {
142+
pollerExec.shutdownNow();
143+
try {
144+
pollerExec.awaitTermination(timeout.millis(), TimeUnit.MILLISECONDS);
145+
} catch (InterruptedException ie) {
146+
Thread.currentThread().interrupt();
147+
fail("Interrupted while stopping pollers: " + ie.getMessage());
148+
}
149+
}
150+
151+
// Shut down the consumer executor
152+
consumerExec.shutdown();
153+
try {
154+
consumerExec.awaitTermination(timeout.millis(), TimeUnit.MILLISECONDS);
155+
} catch (InterruptedException ie) {
156+
Thread.currentThread().interrupt();
157+
}
158+
}
159+
160+
assertNoWorkerFailures(failures);
161+
assertStats(stats);
162+
}
163+
}
164+
165+
private void assertNoWorkerFailures(Queue<Throwable> failures) {
166+
assertTrue(
167+
"Unexpected worker failures:\n" + failures.stream().map(ExceptionsHelper::stackTrace).reduce("", (a, b) -> a + "\n---\n" + b),
168+
failures.isEmpty()
169+
);
170+
}
171+
172+
private void assertStats(PollStats stats) {
173+
assertEquals(stats.totalCalls.sum(), stats.runningResponses.sum() + stats.completedResponses.sum());
174+
assertEquals("There should be no exceptions other than GONE", 0, stats.exceptions.sum());
175+
}
176+
177+
private String getAsyncId(SearchResponseIterator it) {
178+
AsyncSearchResponse response = it.next();
179+
try {
180+
assertNotNull(response.getId());
181+
return response.getId();
182+
} finally {
183+
response.decRef();
184+
}
185+
}
186+
187+
private void consumeAllResponses(SearchResponseIterator it, String aggName) throws Exception {
188+
while (it.hasNext()) {
189+
AsyncSearchResponse response = it.next();
190+
try {
191+
if (response.getSearchResponse() != null && response.getSearchResponse().getAggregations() != null) {
192+
assertNotNull(response.getSearchResponse().getAggregations().get(aggName));
193+
}
194+
} finally {
195+
response.decRef();
196+
}
197+
}
198+
}
199+
200+
private CompletableFuture<Void> createPollers(
201+
String id,
202+
int threads,
203+
PollStats stats,
204+
CountDownLatch warmed,
205+
ExecutorService pollerExec,
206+
AtomicBoolean running,
207+
Queue<Throwable> failures
208+
) {
209+
@SuppressWarnings("unchecked")
210+
final CompletableFuture<Void>[] tasks = IntStream.range(0, threads).mapToObj(i -> CompletableFuture.runAsync(() -> {
211+
while (running.get()) {
212+
AsyncSearchResponse resp = null;
213+
try {
214+
resp = getAsyncSearch(id);
215+
stats.totalCalls.increment();
216+
217+
// Once enough requests have been sent, consider pollers "warmed".
218+
if (stats.totalCalls.sum() >= threads) {
219+
warmed.countDown();
220+
}
221+
222+
if (resp.isRunning()) {
223+
stats.runningResponses.increment();
224+
} else {
225+
// Success-only assertions: if reported completed, we must have a proper search response
226+
assertNull("Async search reported completed with failure", resp.getFailure());
227+
assertNotNull("Completed async search must carry a SearchResponse", resp.getSearchResponse());
228+
assertNotNull("Completed async search must have aggregations", resp.getSearchResponse().getAggregations());
229+
assertNotNull(
230+
"Completed async search must contain the expected aggregation",
231+
resp.getSearchResponse().getAggregations().get("terms")
232+
);
233+
stats.completedResponses.increment();
234+
}
235+
} catch (Exception e) {
236+
Throwable cause = ExceptionsHelper.unwrapCause(e);
237+
if (cause instanceof ElasticsearchStatusException) {
238+
RestStatus status = ExceptionsHelper.status(cause);
239+
if (status == RestStatus.GONE) {
240+
stats.gone410.increment();
241+
} else {
242+
stats.exceptions.increment();
243+
failures.add(cause);
244+
}
245+
} else {
246+
stats.exceptions.increment();
247+
failures.add(cause);
248+
}
249+
} finally {
250+
if (resp != null) {
251+
resp.decRef();
252+
}
253+
}
254+
}
255+
}, pollerExec).whenComplete((v, ex) -> {
256+
if (ex != null) {
257+
failures.add(ExceptionsHelper.unwrapCause(ex));
258+
}
259+
})).toArray(CompletableFuture[]::new);
260+
261+
return CompletableFuture.allOf(tasks);
262+
}
263+
264+
static final class PollStats {
265+
final LongAdder totalCalls = new LongAdder();
266+
final LongAdder runningResponses = new LongAdder();
267+
final LongAdder completedResponses = new LongAdder();
268+
final LongAdder exceptions = new LongAdder();
269+
final LongAdder gone410 = new LongAdder();
270+
}
271+
}

0 commit comments

Comments
 (0)