Skip to content

Commit 1840e7f

Browse files
committed
Prevent concurrent access to local breaker in rerank
1 parent a84dff8 commit 1840e7f

File tree

4 files changed

+132
-81
lines changed

4 files changed

+132
-81
lines changed

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import org.elasticsearch.common.util.concurrent.EsExecutors;
2323
import org.elasticsearch.compute.data.Block;
2424
import org.elasticsearch.compute.data.BlockFactory;
25+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
2526
import org.elasticsearch.compute.data.Page;
27+
import org.elasticsearch.compute.operator.AsyncOperator;
2628
import org.elasticsearch.compute.operator.Driver;
2729
import org.elasticsearch.compute.operator.DriverContext;
2830
import org.elasticsearch.compute.operator.DriverRunner;
@@ -77,6 +79,19 @@ public final void testSimpleLargeInput() {
7779
assertSimple(driverContext(), between(1_000, 10_000));
7880
}
7981

82+
/**
83+
* Runs {@code assertSimple} with a {@code DriverContext} whose block factory is a child factory backed by a {@code LocalCircuitBreaker}; the breaker wraps the test block factory's breaker and is built with randomized {@code overReservedBytes} and {@code maxOverReservedBytes}.
84+
*/
85+
public final void testWithLocalBreaker() {
86+
BlockFactory blockFactory = blockFactory();
87+
final int overReservedBytes = between(0, 1024 * 1024);
88+
final int maxOverReservedBytes = between(overReservedBytes, 1024 * 1024);
89+
var localBreaker = new LocalCircuitBreaker(blockFactory.breaker(), overReservedBytes, maxOverReservedBytes);
90+
BlockFactory localBlockFactory = blockFactory.newChildFactory(localBreaker);
91+
DriverContext driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory);
92+
assertSimple(driverContext, between(10, 10_000));
93+
}
94+
8095
/**
8196
* Enough memory for {@link #simple} not to throw a {@link CircuitBreakingException}.
8297
* It's fine if this is <strong>much</strong> more memory than {@linkplain #simple} needs.
@@ -247,9 +262,19 @@ public void testSimpleFinishClose() {
247262
try (var operator = simple().get(driverContext)) {
248263
assert operator.needsInput();
249264
for (Page page : input) {
250-
operator.addInput(page);
265+
if (operator.needsInput()) {
266+
operator.addInput(page);
267+
} else {
268+
page.releaseBlocks();
269+
}
251270
}
252271
operator.finish();
272+
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
273+
driverContext.finish();
274+
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
275+
driverContext.waitForAsyncActions(waitForAsync);
276+
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
277+
}
253278
}
254279
}
255280

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.compute.test;
99

10+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
1011
import org.elasticsearch.compute.operator.Driver;
1112
import org.elasticsearch.compute.operator.DriverContext;
1213
import org.elasticsearch.compute.operator.Operator;
@@ -28,7 +29,11 @@ public static Driver create(
2829
List<Operator> intermediateOperators,
2930
SinkOperator sink
3031
) {
31-
return create(driverContext, source, intermediateOperators, sink, () -> {});
32+
return create(driverContext, source, intermediateOperators, sink, () -> {
33+
if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) {
34+
localBreaker.close();
35+
}
36+
});
3237
}
3338

3439
public static Driver create(

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 80 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
import org.elasticsearch.compute.operator.DriverContext;
2020
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2121
import org.elasticsearch.compute.operator.Operator;
22+
import org.elasticsearch.core.Releasable;
2223
import org.elasticsearch.core.Releasables;
2324
import org.elasticsearch.inference.TaskType;
2425
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2526
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2627

2728
import java.util.List;
2829

29-
public class RerankOperator extends AsyncOperator<Page> {
30+
public class RerankOperator extends AsyncOperator<RerankOperator.OngoingRerank> {
3031

3132
// TODO: make this configurable through a setting instead of hard-coding it.
3233
private static final int MAX_INFERENCE_WORKER = 10;
@@ -85,20 +86,16 @@ public RerankOperator(
8586
}
8687

8788
@Override
88-
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
89+
protected void performAsync(Page inputPage, ActionListener<OngoingRerank> listener) {
8990
// Input page ownership. The removed ("-") lines released the page after every listener completion via
// ActionListener.runAfter. The added ("+") lines release it here only on failure; on success the page is
// handed to OngoingRerank, and is released later by OngoingRerank.buildOutput or releaseFetchedOnAnyThread.
// NOTE(review): the OngoingRerank constructor also calls releasePageOnAnyThread before throwing on an
// unexpected result type. If listener.map forwards that exception to onFailure (not visible here), the
// failure handler below releases the same page a second time - confirm a repeated release is harmless,
// or drop one of the two calls.
90-
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
91-
91+
listener = listener.delegateResponse((l, e) -> {
92+
releasePageOnAnyThread(inputPage);
93+
l.onFailure(e);
94+
});
9295
try {
93-
inferenceRunner.doInference(
94-
buildInferenceRequest(inputPage),
95-
ActionListener.wrap(
96-
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
97-
outputListener::onFailure
98-
)
99-
);
96+
inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> new OngoingRerank(inputPage, resp)));
10097
} catch (Exception e) {
101-
outputListener.onFailure(e);
98+
listener.onFailure(e);
10299
}
103100
}
104101

@@ -108,91 +105,106 @@ protected void doClose() {
108105
}
109106

110107
@Override
111-
protected void releaseFetchedOnAnyThread(Page page) {
112-
releasePageOnAnyThread(page);
108+
protected void releaseFetchedOnAnyThread(OngoingRerank result) {
109+
releasePageOnAnyThread(result.inputPage);
113110
}
114111

115112
@Override
116113
public Page getOutput() {
117-
return fetchFromBuffer();
114+
var fetched = fetchFromBuffer();
115+
if (fetched == null) {
116+
return null;
117+
} else {
118+
return fetched.buildOutput(blockFactory, scoreChannel);
119+
}
118120
}
119121

120122
@Override
121123
public String toString() {
122124
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
123125
}
124126

125-
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
126-
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
127-
return buildOutput(inputPage, rankedDocsResults);
128-
129-
}
130-
131-
throw new IllegalStateException(
132-
"Inference result has wrong type. Got ["
133-
+ inferenceResponse.getResults().getClass()
134-
+ "] while expecting ["
135-
+ RankedDocsResults.class
136-
+ "]"
137-
);
138-
}
139-
140-
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
141-
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
142-
Block[] blocks = new Block[blockCount];
127+
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
128+
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
129+
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
130+
String[] inputs = new String[inputPage.getPositionCount()];
131+
BytesRef buffer = new BytesRef();
143132

144-
try {
145-
for (int b = 0; b < blockCount; b++) {
146-
if (b == scoreChannel) {
147-
blocks[b] = buildScoreBlock(inputPage, rankedDocsResults);
133+
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
134+
if (encodedRowsBlock.isNull(pos)) {
135+
inputs[pos] = "";
148136
} else {
149-
blocks[b] = inputPage.getBlock(b);
150-
blocks[b].incRef();
137+
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
138+
inputs[pos] = BytesRefs.toString(buffer);
151139
}
152140
}
153-
return new Page(blocks);
154-
} catch (Exception e) {
155-
Releasables.closeExpectNoException(blocks);
156-
throw (e);
141+
142+
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
157143
}
158144
}
159145

160-
private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) {
161-
Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()];
146+
public static final class OngoingRerank {
147+
final Page inputPage;
148+
final Double[] rankedScores;
149+
150+
OngoingRerank(Page inputPage, InferenceAction.Response resp) {
151+
if (resp.getResults() instanceof RankedDocsResults == false) {
152+
releasePageOnAnyThread(inputPage);
153+
throw new IllegalStateException(
154+
"Inference result has wrong type. Got ["
155+
+ resp.getResults().getClass()
156+
+ "] while expecting ["
157+
+ RankedDocsResults.class
158+
+ "]"
159+
);
162160

163-
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) {
161+
}
162+
final var results = (RankedDocsResults) resp.getResults();
163+
this.inputPage = inputPage;
164+
this.rankedScores = extractRankedScores(inputPage.getPositionCount(), results);
165+
}
166+
167+
private static Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) {
168+
Double[] sortedRankedDocsScores = new Double[positionCount];
164169
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
165170
sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
166171
}
172+
return sortedRankedDocsScores;
173+
}
167174

168-
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
169-
if (sortedRankedDocsScores[pos] != null) {
170-
scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]);
171-
} else {
172-
scoreBlockFactory.appendNull();
175+
Page buildOutput(BlockFactory blockFactory, int scoreChannel) {
176+
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
177+
Block[] blocks = new Block[blockCount];
178+
Page outputPage = null;
179+
try (Releasable ignored = inputPage::releaseBlocks) {
180+
for (int b = 0; b < blockCount; b++) {
181+
if (b == scoreChannel) {
182+
blocks[b] = buildScoreBlock(blockFactory);
183+
} else {
184+
blocks[b] = inputPage.getBlock(b);
185+
blocks[b].incRef();
186+
}
187+
}
188+
outputPage = new Page(blocks);
189+
return outputPage;
190+
} finally {
191+
if (outputPage == null) {
192+
Releasables.closeExpectNoException(blocks);
173193
}
174194
}
175-
176-
return scoreBlockFactory.build();
177195
}
178-
}
179-
180-
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
181-
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
182-
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
183-
String[] inputs = new String[inputPage.getPositionCount()];
184-
BytesRef buffer = new BytesRef();
185196

186-
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
187-
if (encodedRowsBlock.isNull(pos)) {
188-
inputs[pos] = "";
189-
} else {
190-
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
191-
inputs[pos] = BytesRefs.toString(buffer);
197+
private Block buildScoreBlock(BlockFactory blockFactory) {
198+
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) {
199+
for (Double rankedScore : rankedScores) {
200+
if (rankedScore != null) {
201+
scoreBlockFactory.appendDouble(rankedScore);
202+
} else {
203+
scoreBlockFactory.appendNull();
204+
}
192205
}
206+
return scoreBlockFactory.build();
193207
}
194-
195-
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
196208
}
197209
}
198210
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/RerankOperatorTests.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.compute.test.OperatorTestCase;
3131
import org.elasticsearch.compute.test.RandomBlock;
3232
import org.elasticsearch.core.Releasables;
33+
import org.elasticsearch.core.TimeValue;
3334
import org.elasticsearch.threadpool.FixedExecutorBuilder;
3435
import org.elasticsearch.threadpool.TestThreadPool;
3536
import org.elasticsearch.threadpool.ThreadPool;
@@ -97,16 +98,23 @@ private InferenceRunner mockedSimpleInferenceRunner() {
9798
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
9899
when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
99100
doAnswer(invocation -> {
100-
@SuppressWarnings("unchecked")
101-
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
102-
1,
103-
ActionListener.class
104-
);
105-
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
106-
when(inferenceResponse.getResults()).thenReturn(
107-
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
108-
);
109-
listener.onResponse(inferenceResponse);
101+
Runnable sendResponse = () -> {
102+
@SuppressWarnings("unchecked")
103+
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
104+
1,
105+
ActionListener.class
106+
);
107+
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
108+
when(inferenceResponse.getResults()).thenReturn(
109+
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
110+
);
111+
listener.onResponse(inferenceResponse);
112+
};
113+
if (randomBoolean()) {
114+
sendResponse.run();
115+
} else {
116+
threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor(ESQL_TEST_EXECUTOR));
117+
}
110118
return null;
111119
}).when(inferenceRunner).doInference(any(), any());
112120

@@ -137,7 +145,8 @@ protected Matcher<String> expectedToStringOfSimple() {
137145

138146
@Override
139147
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
140-
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
148+
final int minPageSize = Math.max(1, size / 100);
149+
return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) {
141150
@Override
142151
protected int remaining() {
143152
return size - currentPosition;

0 commit comments

Comments
 (0)