Skip to content

Commit cac99e8

Browse files
committed
Prevent concurrent access to local breaker in rerank
1 parent a84dff8 commit cac99e8

File tree

4 files changed

+108
-62
lines changed

4 files changed

+108
-62
lines changed

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import org.elasticsearch.common.util.concurrent.EsExecutors;
2323
import org.elasticsearch.compute.data.Block;
2424
import org.elasticsearch.compute.data.BlockFactory;
25+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
2526
import org.elasticsearch.compute.data.Page;
27+
import org.elasticsearch.compute.operator.AsyncOperator;
2628
import org.elasticsearch.compute.operator.Driver;
2729
import org.elasticsearch.compute.operator.DriverContext;
2830
import org.elasticsearch.compute.operator.DriverRunner;
@@ -77,6 +79,19 @@ public final void testSimpleLargeInput() {
7779
assertSimple(driverContext(), between(1_000, 10_000));
7880
}
7981

82+
/**
83+
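* Runs the simple test case with a {@link LocalCircuitBreaker} wrapping the block factory's breaker, using randomized over-reservation sizes.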
84+
*/
85+
public final void testWithLocalBreaker() {
86+
BlockFactory blockFactory = blockFactory();
87+
final int overReservedBytes = between(0, 1024 * 1024);
88+
final int maxOverReservedBytes = between(overReservedBytes, 1024 * 1024);
89+
var localBreaker = new LocalCircuitBreaker(blockFactory.breaker(), overReservedBytes, maxOverReservedBytes);
90+
BlockFactory localBlockFactory = blockFactory.newChildFactory(localBreaker);
91+
DriverContext driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory);
92+
assertSimple(driverContext, between(10, 10_000));
93+
}
94+
8095
/**
8196
* Enough memory for {@link #simple} not to throw a {@link CircuitBreakingException}.
8297
* It's fine if this is <strong>much</strong> more memory than {@linkplain #simple} needs.
@@ -247,9 +262,19 @@ public void testSimpleFinishClose() {
247262
try (var operator = simple().get(driverContext)) {
248263
assert operator.needsInput();
249264
for (Page page : input) {
250-
operator.addInput(page);
265+
if (operator.needsInput()) {
266+
operator.addInput(page);
267+
} else {
268+
page.releaseBlocks();
269+
}
251270
}
252271
operator.finish();
272+
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
273+
driverContext.finish();
274+
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
275+
driverContext.waitForAsyncActions(waitForAsync);
276+
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
277+
}
253278
}
254279
}
255280

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.compute.test;
99

10+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
1011
import org.elasticsearch.compute.operator.Driver;
1112
import org.elasticsearch.compute.operator.DriverContext;
1213
import org.elasticsearch.compute.operator.Operator;
@@ -28,7 +29,11 @@ public static Driver create(
2829
List<Operator> intermediateOperators,
2930
SinkOperator sink
3031
) {
31-
return create(driverContext, source, intermediateOperators, sink, () -> {});
32+
return create(driverContext, source, intermediateOperators, sink, () -> {
33+
if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) {
34+
localBreaker.close();
35+
}
36+
});
3237
}
3338

3439
public static Driver create(

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
import org.elasticsearch.compute.operator.DriverContext;
2020
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2121
import org.elasticsearch.compute.operator.Operator;
22+
import org.elasticsearch.core.Releasable;
2223
import org.elasticsearch.core.Releasables;
2324
import org.elasticsearch.inference.TaskType;
2425
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2526
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2627

2728
import java.util.List;
2829

29-
public class RerankOperator extends AsyncOperator<Page> {
30+
public class RerankOperator extends AsyncOperator<RerankOperator.InputPageAndRankedScores> {
3031

3132
// TODO: move this hard-coded value to a setting.
3233
private static final int MAX_INFERENCE_WORKER = 10;
@@ -85,20 +86,30 @@ public RerankOperator(
8586
}
8687

8788
@Override
88-
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
89+
protected void performAsync(Page inputPage, ActionListener<InputPageAndRankedScores> listener) {
8990
// Release the input page blocks if the request fails; on success the page is handed over with the ranked scores and released once the output page is built.
90-
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
91-
91+
listener = listener.delegateResponse((l, e) -> {
92+
releasePageOnAnyThread(inputPage);
93+
l.onFailure(e);
94+
});
9295
try {
93-
inferenceRunner.doInference(
94-
buildInferenceRequest(inputPage),
95-
ActionListener.wrap(
96-
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
97-
outputListener::onFailure
98-
)
99-
);
96+
inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> {
97+
if (resp.getResults() instanceof RankedDocsResults == false) {
98+
releasePageOnAnyThread(inputPage);
99+
throw new IllegalStateException(
100+
"Inference result has wrong type. Got ["
101+
+ resp.getResults().getClass()
102+
+ "] while expecting ["
103+
+ RankedDocsResults.class
104+
+ "]"
105+
);
106+
107+
}
108+
final var results = (RankedDocsResults) resp.getResults();
109+
return new InputPageAndRankedScores(inputPage, extractRankedScores(inputPage.getPositionCount(), results));
110+
}));
100111
} catch (Exception e) {
101-
outputListener.onFailure(e);
112+
listener.onFailure(e);
102113
}
103114
}
104115

@@ -108,71 +119,63 @@ protected void doClose() {
108119
}
109120

110121
@Override
111-
protected void releaseFetchedOnAnyThread(Page page) {
112-
releasePageOnAnyThread(page);
122+
protected void releaseFetchedOnAnyThread(InputPageAndRankedScores result) {
123+
releasePageOnAnyThread(result.inputPage());
113124
}
114125

115126
@Override
116127
public Page getOutput() {
117-
return fetchFromBuffer();
128+
var fetched = fetchFromBuffer();
129+
if (fetched == null) {
130+
return null;
131+
}
132+
return buildOutput(fetched.inputPage(), fetched.rankedScores());
118133
}
119134

120135
@Override
121136
public String toString() {
122137
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
123138
}
124139

125-
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
126-
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
127-
return buildOutput(inputPage, rankedDocsResults);
128-
129-
}
130-
131-
throw new IllegalStateException(
132-
"Inference result has wrong type. Got ["
133-
+ inferenceResponse.getResults().getClass()
134-
+ "] while expecting ["
135-
+ RankedDocsResults.class
136-
+ "]"
137-
);
138-
}
139-
140-
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
140+
private Page buildOutput(Page inputPage, Double[] rankedScores) {
141141
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
142142
Block[] blocks = new Block[blockCount];
143-
144-
try {
143+
Page outputPage = null;
144+
try (Releasable ignored = inputPage::releaseBlocks) {
145145
for (int b = 0; b < blockCount; b++) {
146146
if (b == scoreChannel) {
147-
blocks[b] = buildScoreBlock(inputPage, rankedDocsResults);
147+
blocks[b] = buildScoreBlock(rankedScores);
148148
} else {
149149
blocks[b] = inputPage.getBlock(b);
150150
blocks[b].incRef();
151151
}
152152
}
153-
return new Page(blocks);
154-
} catch (Exception e) {
155-
Releasables.closeExpectNoException(blocks);
156-
throw (e);
153+
outputPage = new Page(blocks);
154+
return outputPage;
155+
} finally {
156+
if (outputPage == null) {
157+
Releasables.closeExpectNoException(blocks);
158+
}
157159
}
158160
}
159161

160-
private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) {
161-
Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()];
162-
163-
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) {
164-
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
165-
sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
166-
}
162+
private Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) {
163+
Double[] sortedRankedDocsScores = new Double[positionCount];
164+
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
165+
sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
166+
}
167+
return sortedRankedDocsScores;
168+
}
167169

168-
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
169-
if (sortedRankedDocsScores[pos] != null) {
170-
scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]);
170+
private Block buildScoreBlock(Double[] rankedScores) {
171+
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) {
172+
for (Double rankedScore : rankedScores) {
173+
if (rankedScore != null) {
174+
scoreBlockFactory.appendDouble(rankedScore);
171175
} else {
172176
scoreBlockFactory.appendNull();
173177
}
174178
}
175-
176179
return scoreBlockFactory.build();
177180
}
178181
}
@@ -195,4 +198,8 @@ private InferenceAction.Request buildInferenceRequest(Page inputPage) {
195198
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
196199
}
197200
}
201+
202+
public record InputPageAndRankedScores(Page inputPage, Double[] rankedScores) {
203+
204+
}
198205
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.compute.test.OperatorTestCase;
3131
import org.elasticsearch.compute.test.RandomBlock;
3232
import org.elasticsearch.core.Releasables;
33+
import org.elasticsearch.core.TimeValue;
3334
import org.elasticsearch.threadpool.FixedExecutorBuilder;
3435
import org.elasticsearch.threadpool.TestThreadPool;
3536
import org.elasticsearch.threadpool.ThreadPool;
@@ -97,16 +98,23 @@ private InferenceRunner mockedSimpleInferenceRunner() {
9798
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
9899
when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
99100
doAnswer(invocation -> {
100-
@SuppressWarnings("unchecked")
101-
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
102-
1,
103-
ActionListener.class
104-
);
105-
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
106-
when(inferenceResponse.getResults()).thenReturn(
107-
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
108-
);
109-
listener.onResponse(inferenceResponse);
101+
Runnable sendResponse = () -> {
102+
@SuppressWarnings("unchecked")
103+
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
104+
1,
105+
ActionListener.class
106+
);
107+
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
108+
when(inferenceResponse.getResults()).thenReturn(
109+
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
110+
);
111+
listener.onResponse(inferenceResponse);
112+
};
113+
if (randomBoolean()) {
114+
sendResponse.run();
115+
} else {
116+
threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor(ESQL_TEST_EXECUTOR));
117+
}
110118
return null;
111119
}).when(inferenceRunner).doInference(any(), any());
112120

@@ -137,7 +145,8 @@ protected Matcher<String> expectedToStringOfSimple() {
137145

138146
@Override
139147
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
140-
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
148+
final int minPageSize = Math.max(1, size / 100);
149+
return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) {
141150
@Override
142151
protected int remaining() {
143152
return size - currentPosition;

0 commit comments

Comments
 (0)