-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Prevent concurrent access to local breaker in rerank #128162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,14 +19,15 @@ | |
| import org.elasticsearch.compute.operator.DriverContext; | ||
| import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; | ||
| import org.elasticsearch.compute.operator.Operator; | ||
| import org.elasticsearch.core.Releasable; | ||
| import org.elasticsearch.core.Releasables; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.xpack.core.inference.action.InferenceAction; | ||
| import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; | ||
|
|
||
| import java.util.List; | ||
|
|
||
| public class RerankOperator extends AsyncOperator<Page> { | ||
| public class RerankOperator extends AsyncOperator<RerankOperator.OngoingRerank> { | ||
|
|
||
| // Move to a setting. | ||
| private static final int MAX_INFERENCE_WORKER = 10; | ||
|
|
@@ -85,20 +86,16 @@ public RerankOperator( | |
| } | ||
|
|
||
| @Override | ||
| protected void performAsync(Page inputPage, ActionListener<Page> listener) { | ||
| protected void performAsync(Page inputPage, ActionListener<OngoingRerank> listener) { | ||
| // Ensure input page blocks are released when the listener is called. | ||
| final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); }); | ||
|
|
||
| listener = listener.delegateResponse((l, e) -> { | ||
| releasePageOnAnyThread(inputPage); | ||
| l.onFailure(e); | ||
| }); | ||
| try { | ||
| inferenceRunner.doInference( | ||
| buildInferenceRequest(inputPage), | ||
| ActionListener.wrap( | ||
| inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it isn't safe to create blocks here (And therefore update the CB), why is it safe to read a block from transport like we do in
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| outputListener::onFailure | ||
| ) | ||
| ); | ||
| inferenceRunner.doInference(buildInferenceRequest(inputPage), listener.map(resp -> new OngoingRerank(inputPage, resp))); | ||
| } catch (Exception e) { | ||
| outputListener.onFailure(e); | ||
| listener.onFailure(e); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -108,91 +105,106 @@ protected void doClose() { | |
| } | ||
|
|
||
| @Override | ||
| protected void releaseFetchedOnAnyThread(Page page) { | ||
| releasePageOnAnyThread(page); | ||
| protected void releaseFetchedOnAnyThread(OngoingRerank result) { | ||
| releasePageOnAnyThread(result.inputPage); | ||
| } | ||
|
|
||
| @Override | ||
| public Page getOutput() { | ||
| return fetchFromBuffer(); | ||
| var fetched = fetchFromBuffer(); | ||
| if (fetched == null) { | ||
| return null; | ||
| } else { | ||
| return fetched.buildOutput(blockFactory, scoreChannel); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]"; | ||
| } | ||
|
|
||
| private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) { | ||
| if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) { | ||
| return buildOutput(inputPage, rankedDocsResults); | ||
|
|
||
| } | ||
|
|
||
| throw new IllegalStateException( | ||
| "Inference result has wrong type. Got [" | ||
| + inferenceResponse.getResults().getClass() | ||
| + "] while expecting [" | ||
| + RankedDocsResults.class | ||
| + "]" | ||
| ); | ||
| } | ||
|
|
||
| private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) { | ||
| int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); | ||
| Block[] blocks = new Block[blockCount]; | ||
| private InferenceAction.Request buildInferenceRequest(Page inputPage) { | ||
| try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) { | ||
| assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount()); | ||
| String[] inputs = new String[inputPage.getPositionCount()]; | ||
| BytesRef buffer = new BytesRef(); | ||
|
|
||
| try { | ||
| for (int b = 0; b < blockCount; b++) { | ||
| if (b == scoreChannel) { | ||
| blocks[b] = buildScoreBlock(inputPage, rankedDocsResults); | ||
| for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { | ||
| if (encodedRowsBlock.isNull(pos)) { | ||
| inputs[pos] = ""; | ||
| } else { | ||
| blocks[b] = inputPage.getBlock(b); | ||
| blocks[b].incRef(); | ||
| buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer); | ||
| inputs[pos] = BytesRefs.toString(buffer); | ||
| } | ||
| } | ||
| return new Page(blocks); | ||
| } catch (Exception e) { | ||
| Releasables.closeExpectNoException(blocks); | ||
| throw (e); | ||
|
|
||
| return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build(); | ||
| } | ||
| } | ||
|
|
||
| private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) { | ||
| Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()]; | ||
| public static final class OngoingRerank { | ||
| final Page inputPage; | ||
| final Double[] rankedScores; | ||
|
|
||
| OngoingRerank(Page inputPage, InferenceAction.Response resp) { | ||
| if (resp.getResults() instanceof RankedDocsResults == false) { | ||
| releasePageOnAnyThread(inputPage); | ||
| throw new IllegalStateException( | ||
| "Inference result has wrong type. Got [" | ||
| + resp.getResults().getClass() | ||
| + "] while expecting [" | ||
| + RankedDocsResults.class | ||
| + "]" | ||
| ); | ||
|
|
||
| try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) { | ||
| } | ||
| final var results = (RankedDocsResults) resp.getResults(); | ||
| this.inputPage = inputPage; | ||
| this.rankedScores = extractRankedScores(inputPage.getPositionCount(), results); | ||
| } | ||
|
|
||
| private static Double[] extractRankedScores(int positionCount, RankedDocsResults rankedDocsResults) { | ||
| Double[] sortedRankedDocsScores = new Double[positionCount]; | ||
| for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) { | ||
| sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore(); | ||
| } | ||
| return sortedRankedDocsScores; | ||
| } | ||
|
|
||
| for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { | ||
| if (sortedRankedDocsScores[pos] != null) { | ||
| scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]); | ||
| } else { | ||
| scoreBlockFactory.appendNull(); | ||
| Page buildOutput(BlockFactory blockFactory, int scoreChannel) { | ||
| int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); | ||
| Block[] blocks = new Block[blockCount]; | ||
| Page outputPage = null; | ||
| try (Releasable ignored = inputPage::releaseBlocks) { | ||
| for (int b = 0; b < blockCount; b++) { | ||
| if (b == scoreChannel) { | ||
| blocks[b] = buildScoreBlock(blockFactory); | ||
| } else { | ||
| blocks[b] = inputPage.getBlock(b); | ||
| blocks[b].incRef(); | ||
| } | ||
| } | ||
| outputPage = new Page(blocks); | ||
| return outputPage; | ||
| } finally { | ||
| if (outputPage == null) { | ||
| Releasables.closeExpectNoException(blocks); | ||
| } | ||
| } | ||
|
|
||
| return scoreBlockFactory.build(); | ||
| } | ||
| } | ||
|
|
||
| private InferenceAction.Request buildInferenceRequest(Page inputPage) { | ||
| try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) { | ||
| assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount()); | ||
| String[] inputs = new String[inputPage.getPositionCount()]; | ||
| BytesRef buffer = new BytesRef(); | ||
|
|
||
| for (int pos = 0; pos < inputPage.getPositionCount(); pos++) { | ||
| if (encodedRowsBlock.isNull(pos)) { | ||
| inputs[pos] = ""; | ||
| } else { | ||
| buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer); | ||
| inputs[pos] = BytesRefs.toString(buffer); | ||
| private Block buildScoreBlock(BlockFactory blockFactory) { | ||
| try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(rankedScores.length)) { | ||
| for (Double rankedScore : rankedScores) { | ||
| if (rankedScore != null) { | ||
| scoreBlockFactory.appendDouble(rankedScore); | ||
| } else { | ||
| scoreBlockFactory.appendNull(); | ||
| } | ||
| } | ||
| return scoreBlockFactory.build(); | ||
| } | ||
|
|
||
| return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build(); | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we wrap this in something that
asserts that we're on the same thread every time?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
++ I pushed ce54583.