Skip to content

Commit 1901bde

Browse files
committed
Add comments.
1 parent e64b81c commit 1901bde

File tree

10 files changed

+344
-35
lines changed

10 files changed

+344
-35
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.compute.operator.AsyncOperator;
1414
import org.elasticsearch.compute.operator.DriverContext;
1515
import org.elasticsearch.core.Releasable;
16+
import org.elasticsearch.core.Releasables;
1617
import org.elasticsearch.inference.InferenceServiceResults;
1718
import org.elasticsearch.threadpool.ThreadPool;
1819
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -24,70 +25,137 @@
2425

2526
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2627

27-
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInference> {
28+
/**
29+
* An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}.
30+
* <p>
31+
* The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It
32+
* transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}.
33+
* </p>
34+
*/
35+
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInferenceResult> {
2836
private final String inferenceId;
2937
private final BlockFactory blockFactory;
3038
private final BulkInferenceExecutor bulkInferenceExecutor;
3139

40+
/**
41+
* Constructs a new {@code InferenceOperator}.
42+
*
43+
* @param driverContext The driver context.
44+
* @param inferenceRunner The runner used to execute inference requests.
45+
* @param bulkExecutionConfig Configuration for inference execution.
46+
* @param threadPool The thread pool used for executing async inference.
47+
* @param inferenceId The ID of the inference model to use.
48+
*/
3249
public InferenceOperator(
3350
DriverContext driverContext,
3451
InferenceRunner inferenceRunner,
3552
BulkInferenceExecutionConfig bulkExecutionConfig,
3653
ThreadPool threadPool,
3754
String inferenceId
3855
) {
39-
super(driverContext, threadPool.getThreadContext(), bulkExecutionConfig.workers());
56+
super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers());
4057
this.blockFactory = driverContext.blockFactory();
4158
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
4259
this.inferenceId = inferenceId;
4360
}
4461

62+
/**
63+
* Returns the {@link BlockFactory} used to create output data blocks.
64+
*/
4565
protected BlockFactory blockFactory() {
4666
return blockFactory;
4767
}
4868

69+
/**
70+
* Returns the inference model ID used for this operator.
71+
*/
4972
protected String inferenceId() {
5073
return inferenceId;
5174
}
5275

76+
/**
77+
* Initiates asynchronous inferences for the given input page.
78+
*/
5379
@Override
54-
protected void releaseFetchedOnAnyThread(OngoingInference result) {
55-
releasePageOnAnyThread(result.inputPage);
56-
}
57-
58-
@Override
59-
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
80+
protected void performAsync(Page input, ActionListener<OngoingInferenceResult> listener) {
6081
try {
6182
BulkInferenceRequestIterator requests = requests(input);
6283
listener = ActionListener.releaseBefore(requests, listener);
63-
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInference(input, responses)));
84+
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));
6485
} catch (Exception e) {
6586
listener.onFailure(e);
6687
}
6788
}
6889

90+
/**
91+
* Releases resources associated with an ongoing inference.
92+
*/
93+
@Override
94+
protected void releaseFetchedOnAnyThread(OngoingInferenceResult ongoingInferenceResult) {
95+
Releasables.close(ongoingInferenceResult);
96+
}
97+
98+
/**
99+
* Returns the next available output page constructed from completed inference results.
100+
*/
69101
@Override
70102
public Page getOutput() {
71-
OngoingInference ongoingInference = fetchFromBuffer();
72-
if (ongoingInference == null) {
103+
OngoingInferenceResult ongoingInferenceResult = fetchFromBuffer();
104+
if (ongoingInferenceResult == null) {
73105
return null;
74106
}
75107

76-
try (OutputBuilder outputBuilder = outputBuilder(ongoingInference.inputPage)) {
77-
ongoingInference.responses.forEach(outputBuilder::addInferenceResponse);
108+
try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) {
109+
assert ongoingInferenceResult.inputPage.getPositionCount() == ongoingInferenceResult.responses.size();
110+
for (InferenceAction.Response response : ongoingInferenceResult.responses) {
111+
try {
112+
outputBuilder.addInferenceResponse(response);
113+
} catch (IllegalArgumentException e) {
114+
throw new IllegalStateException("Invalid inference response", e);
115+
}
116+
}
78117
return outputBuilder.buildOutput();
118+
79119
} finally {
80-
releaseFetchedOnAnyThread(ongoingInference);
120+
releaseFetchedOnAnyThread(ongoingInferenceResult);
81121
}
82122
}
83123

124+
/**
125+
* Converts the given input page into a sequence of inference requests.
126+
*
127+
* @param input The input page to process.
128+
*/
84129
protected abstract BulkInferenceRequestIterator requests(Page input);
85130

131+
/**
132+
* Creates a new {@link OutputBuilder} instance used to build the output page.
133+
*
134+
* @param input The corresponding input page used to generate the inference requests.
135+
*/
86136
protected abstract OutputBuilder outputBuilder(Page input);
87137

138+
/**
139+
* An interface for accumulating inference responses and constructing a result {@link Page}.
140+
*/
88141
public interface OutputBuilder extends Releasable {
142+
143+
/**
144+
* Adds an inference response to the output.
145+
* <p>
146+
* The responses must be added in the same order as the corresponding inference requests were generated.
147+
* Failing to preserve order may lead to incorrect or misaligned output rows.
148+
* </p>
149+
*
150+
* @param inferenceResponse The inference response to include.
151+
*/
89152
void addInferenceResponse(InferenceAction.Response inferenceResponse);
90153

154+
/**
155+
* Builds the final output page from accumulated inference responses.
156+
*
157+
* @return The constructed output page.
158+
*/
91159
Page buildOutput();
92160

93161
static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.Response inferenceResponse, Class<IR> clazz) {
@@ -102,7 +170,18 @@ static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.
102170
}
103171
}
104172

105-
public record OngoingInference(Page inputPage, List<InferenceAction.Response> responses) {
106-
173+
/**
174+
* Represents the result of an ongoing inference operation, including the original input page
175+
* and the list of inference responses.
176+
*
177+
* @param inputPage The input page used to generate inference requests.
178+
* @param responses The inference responses returned by the inference service.
179+
*/
180+
public record OngoingInferenceResult(Page inputPage, List<InferenceAction.Response> responses) implements Releasable {
181+
182+
@Override
183+
public void close() {
184+
releasePageOnAnyThread(inputPage);
185+
}
107186
}
108187
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,102 @@
1717

1818
import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
1919

20+
/**
21+
* Tracks the state of a bulk inference execution, including sequencing, failure management, and buffering of inference responses for
22+
* ordered output construction.
23+
*/
2024
public class BulkInferenceExecutionState {
2125
private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
2226
private final FailureCollector failureCollector = new FailureCollector();
23-
private final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
27+
private final Map<Long, InferenceAction.Response> bufferedResponses;
2428
private final AtomicBoolean finished = new AtomicBoolean(false);
2529

30+
public BulkInferenceExecutionState(int bufferSize) {
31+
this.bufferedResponses = new ConcurrentHashMap<>(bufferSize);
32+
}
33+
34+
/**
35+
* Generates a new unique sequence number for an inference request.
36+
*/
2637
public long generateSeqNo() {
2738
return checkpoint.generateSeqNo();
2839
}
2940

41+
/**
42+
* Returns the highest sequence number marked as persisted, such that all lower sequence numbers have also been marked as persisted.
43+
*/
3044
public long getPersistedCheckpoint() {
3145
return checkpoint.getPersistedCheckpoint();
3246
}
3347

48+
/**
49+
* Returns the highest sequence number marked as processed, such that all lower sequence numbers have also been marked as processed.
50+
*/
3451
public long getProcessedCheckpoint() {
3552
return checkpoint.getProcessedCheckpoint();
3653
}
3754

55+
/**
56+
* Highest generated sequence number.
57+
*/
3858
public long getMaxSeqNo() {
3959
return checkpoint.getMaxSeqNo();
4060
}
4161

62+
/**
63+
* Marks an inference response as persisted.
64+
*
65+
* @param seqNo The corresponding sequence number
66+
*/
67+
public void markSeqNoAsPersisted(long seqNo) {
68+
checkpoint.markSeqNoAsPersisted(seqNo);
69+
}
70+
71+
/**
72+
* Adds an inference response to the buffer and marks the corresponding sequence number as processed.
73+
*
74+
* @param seqNo The sequence number of the inference request.
75+
* @param response The inference response.
76+
*/
4277
public synchronized void onInferenceResponse(long seqNo, InferenceAction.Response response) {
4378
if (failureCollector.hasFailure() == false) {
4479
bufferedResponses.put(seqNo, response);
4580
}
4681
checkpoint.markSeqNoAsProcessed(seqNo);
4782
}
4883

84+
/**
85+
* Handles an exception thrown during inference execution.
86+
* Records the failure and marks the corresponding sequence number as processed.
87+
*
88+
* @param seqNo The sequence number of the inference request.
89+
* @param e The exception
90+
*/
4991
public synchronized void onInferenceException(long seqNo, Exception e) {
5092
failureCollector.unwrapAndCollect(e);
5193
checkpoint.markSeqNoAsProcessed(seqNo);
5294
bufferedResponses.clear();
5395
}
5496

97+
/**
98+
* Retrieves and removes the buffered response by sequence number.
99+
*
100+
* @param seqNo The sequence number of the response to fetch.
101+
*/
55102
public synchronized InferenceAction.Response fetchBufferedResponse(long seqNo) {
56103
return bufferedResponses.remove(seqNo);
57104
}
58105

59-
public void markSeqNoAsPersisted(long seqNo) {
60-
checkpoint.markSeqNoAsPersisted(seqNo);
61-
}
62-
106+
/**
107+
* Returns whether any failure has been recorded during execution.
108+
*/
63109
public boolean hasFailure() {
64110
return failureCollector.hasFailure();
65111
}
66112

113+
/**
114+
* Returns the recorded failure, if any.
115+
*/
67116
public Exception getFailure() {
68117
return failureCollector.getFailure();
69118
}
@@ -72,10 +121,16 @@ public void addFailure(Exception e) {
72121
failureCollector.unwrapAndCollect(e);
73122
}
74123

124+
/**
125+
* Indicates whether the entire bulk execution is marked as finished and all responses have been successfully persisted.
126+
*/
75127
public boolean finished() {
76128
return finished.get() && getMaxSeqNo() == getPersistedCheckpoint();
77129
}
78130

131+
/**
132+
* Marks the bulk as finished, indicating that all inference requests have been sent.
133+
*/
79134
public void finish() {
80135
this.finished.set(true);
81136
}

0 commit comments

Comments
 (0)