Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.LocalCircuitBreaker;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.DriverRunner;
Expand Down Expand Up @@ -77,6 +79,19 @@ public final void testSimpleLargeInput() {
assertSimple(driverContext(), between(1_000, 10_000));
}

/**
* Runs the simple-case assertions with a {@link LocalCircuitBreaker} that wraps the block factory's breaker.
*/
public final void testWithLocalBreaker() {
BlockFactory blockFactory = blockFactory();
final int overReservedBytes = between(0, 1024 * 1024);
final int maxOverReservedBytes = between(overReservedBytes, 1024 * 1024);
var localBreaker = new LocalCircuitBreaker(blockFactory.breaker(), overReservedBytes, maxOverReservedBytes);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we wrap this in something that asserts that we're on the same thread every time?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ I pushed ce54583.

BlockFactory localBlockFactory = blockFactory.newChildFactory(localBreaker);
DriverContext driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory);
assertSimple(driverContext, between(10, 10_000));
}

/**
* Enough memory for {@link #simple} not to throw a {@link CircuitBreakingException}.
* It's fine if this is <strong>much</strong> more memory than {@linkplain #simple} needs.
Expand Down Expand Up @@ -247,9 +262,19 @@ public void testSimpleFinishClose() {
try (var operator = simple().get(driverContext)) {
assert operator.needsInput();
for (Page page : input) {
operator.addInput(page);
if (operator.needsInput()) {
operator.addInput(page);
} else {
page.releaseBlocks();
}
}
operator.finish();
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
driverContext.finish();
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
driverContext.waitForAsyncActions(waitForAsync);
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.compute.test;

import org.elasticsearch.compute.data.LocalCircuitBreaker;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
Expand All @@ -28,7 +29,11 @@ public static Driver create(
List<Operator> intermediateOperators,
SinkOperator sink
) {
return create(driverContext, source, intermediateOperators, sink, () -> {});
return create(driverContext, source, intermediateOperators, sink, () -> {
if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) {
localBreaker.close();
}
});
}

public static Driver create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.List;

public class RerankOperator extends AsyncOperator<Page> {
public class RerankOperator extends AsyncOperator<RerankOperator.OngoingRerank> {

// Move to a setting.
private static final int MAX_INFERENCE_WORKER = 10;
Expand Down Expand Up @@ -85,20 +86,16 @@ public RerankOperator(
}

@Override
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
protected void performAsync(Page inputPage, ActionListener<OngoingRerank> listener) {
// Ensure input page blocks are released when the listener is called.
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });

listener = listener.delegateResponse((l, e) -> {
releasePageOnAnyThread(inputPage);
l.onFailure(e);
});
try {
inferenceRunner.doInference(
buildInferenceRequest(inputPage),
ActionListener.wrap(
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it isn't safe to create blocks here (and therefore update the circuit breaker), why is it safe to read a block from transport as we do in LookupFromIndexOperator? I was wondering which rule we should follow in the AsyncOperators to avoid hitting this issue again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • We use the global breaker, not the local breaker, when reading blocks from an input stream.
  • I have added assertions in ce54583, which should consistently detect such misuse.

outputListener::onFailure
)
);
inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> new OngoingRerank(inputPage, resp)));
} catch (Exception e) {
outputListener.onFailure(e);
listener.onFailure(e);
}
}

Expand All @@ -108,91 +105,106 @@ protected void doClose() {
}

@Override
protected void releaseFetchedOnAnyThread(Page page) {
releasePageOnAnyThread(page);
protected void releaseFetchedOnAnyThread(OngoingRerank result) {
releasePageOnAnyThread(result.inputPage);
}

@Override
public Page getOutput() {
return fetchFromBuffer();
var fetched = fetchFromBuffer();
if (fetched == null) {
return null;
} else {
return fetched.buildOutput(blockFactory, scoreChannel);
}
}

@Override
public String toString() {
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
}

private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
return buildOutput(inputPage, rankedDocsResults);

}

throw new IllegalStateException(
"Inference result has wrong type. Got ["
+ inferenceResponse.getResults().getClass()
+ "] while expecting ["
+ RankedDocsResults.class
+ "]"
);
}

private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
Block[] blocks = new Block[blockCount];
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
String[] inputs = new String[inputPage.getPositionCount()];
BytesRef buffer = new BytesRef();

try {
for (int b = 0; b < blockCount; b++) {
if (b == scoreChannel) {
blocks[b] = buildScoreBlock(inputPage, rankedDocsResults);
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
if (encodedRowsBlock.isNull(pos)) {
inputs[pos] = "";
} else {
blocks[b] = inputPage.getBlock(b);
blocks[b].incRef();
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
inputs[pos] = BytesRefs.toString(buffer);
}
}
return new Page(blocks);
} catch (Exception e) {
Releasables.closeExpectNoException(blocks);
throw (e);

return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
}
}

private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) {
Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()];
public static final class OngoingRerank {
final Page inputPage;
final Double[] rankedScores;

OngoingRerank(Page inputPage, InferenceAction.Response resp) {
if (resp.getResults() instanceof RankedDocsResults == false) {
releasePageOnAnyThread(inputPage);
throw new IllegalStateException(
"Inference result has wrong type. Got ["
+ resp.getResults().getClass()
+ "] while expecting ["
+ RankedDocsResults.class
+ "]"
);

try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) {
}
final var results = (RankedDocsResults) resp.getResults();
this.inputPage = inputPage;
this.rankedScores = extractRankedScores(inputPage.getPositionCount(), results);
}

private static Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) {
Double[] sortedRankedDocsScores = new Double[positionCount];
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
}
return sortedRankedDocsScores;
}

for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
if (sortedRankedDocsScores[pos] != null) {
scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]);
} else {
scoreBlockFactory.appendNull();
Page buildOutput(BlockFactory blockFactory, int scoreChannel) {
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
Block[] blocks = new Block[blockCount];
Page outputPage = null;
try (Releasable ignored = inputPage::releaseBlocks) {
for (int b = 0; b < blockCount; b++) {
if (b == scoreChannel) {
blocks[b] = buildScoreBlock(blockFactory);
} else {
blocks[b] = inputPage.getBlock(b);
blocks[b].incRef();
}
}
outputPage = new Page(blocks);
return outputPage;
} finally {
if (outputPage == null) {
Releasables.closeExpectNoException(blocks);
}
}

return scoreBlockFactory.build();
}
}

private InferenceAction.Request buildInferenceRequest(Page inputPage) {
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
String[] inputs = new String[inputPage.getPositionCount()];
BytesRef buffer = new BytesRef();

for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
if (encodedRowsBlock.isNull(pos)) {
inputs[pos] = "";
} else {
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
inputs[pos] = BytesRefs.toString(buffer);
private Block buildScoreBlock(BlockFactory blockFactory) {
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) {
for (Double rankedScore : rankedScores) {
if (rankedScore != null) {
scoreBlockFactory.appendDouble(rankedScore);
} else {
scoreBlockFactory.appendNull();
}
}
return scoreBlockFactory.build();
}

return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.RandomBlock;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -97,16 +98,23 @@ private InferenceRunner mockedSimpleInferenceRunner() {
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
1,
ActionListener.class
);
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
when(inferenceResponse.getResults()).thenReturn(
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
);
listener.onResponse(inferenceResponse);
Runnable sendResponse = () -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
1,
ActionListener.class
);
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
when(inferenceResponse.getResults()).thenReturn(
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
);
listener.onResponse(inferenceResponse);
};
if (randomBoolean()) {
sendResponse.run();
} else {
threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor(ESQL_TEST_EXECUTOR));
}
return null;
}).when(inferenceRunner).doInference(any(), any());

Expand Down Expand Up @@ -137,7 +145,8 @@ protected Matcher<String> expectedToStringOfSimple() {

@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
final int minPageSize = Math.max(1, size / 100);
return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) {
@Override
protected int remaining() {
return size - currentPosition;
Expand Down