Skip to content

Commit 3bb8ff5

Browse files
committed
Move inference result type check to the InferenceOperator
1 parent 29ff576 commit 3bb8ff5

File tree

3 files changed

+33
-43
lines changed

3 files changed

+33
-43
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionOperator.java

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import java.util.List;
2525

26-
public class CompletionOperator extends InferenceOperator<Page> {
26+
public class CompletionOperator extends InferenceOperator<Page, ChatCompletionResults> {
2727

2828
public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
2929
implements
@@ -48,7 +48,7 @@ public CompletionOperator(
4848
String inferenceId,
4949
ExpressionEvaluator promptEvaluator
5050
) {
51-
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId);
51+
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId, ChatCompletionResults.class);
5252
this.promptEvaluator = promptEvaluator;
5353
this.blockFactory = driverContext.blockFactory();
5454
}
@@ -61,7 +61,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
6161
CountDownActionListener countDownListener = new CountDownActionListener(
6262
inputPage.getPositionCount(),
6363
listener.delegateFailureIgnoreResponseAndWrap(l -> {
64-
try(BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(pageSize)) {
64+
try (BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(pageSize)) {
6565
BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
6666
for (int pos = 0; pos < pageSize; pos++) {
6767
if (responses[pos] == null) {
@@ -91,21 +91,12 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
9191
}
9292

9393
InferenceAction.Request request = InferenceAction.Request.builder(inferenceId(), TaskType.COMPLETION)
94-
.setInput(List.of(promptBuilder.toString())).build();
94+
.setInput(List.of(promptBuilder.toString()))
95+
.build();
9596

9697
doInference(request, countDownListener.delegateFailureAndWrap((l, r) -> {
97-
if (r.getResults() instanceof ChatCompletionResults completionResults) {
98-
responses[currentPos] = completionResults.results().getFirst().content();
99-
l.onResponse(null);
100-
} else {
101-
l.onFailure(new IllegalStateException(
102-
"Inference result has wrong type. Got ["
103-
+ r.getResults().getClass()
104-
+ "] while expecting ["
105-
+ ChatCompletionResults.class
106-
+ "]"
107-
));
108-
}
98+
responses[currentPos] = r.results().getFirst().content();
99+
l.onResponse(null);
109100
}));
110101
}
111102
}
@@ -131,5 +122,4 @@ public Page getOutput() {
131122
public String toString() {
132123
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
133124
}
134-
135125
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,52 @@
1111
import org.elasticsearch.common.util.concurrent.ThreadContext;
1212
import org.elasticsearch.compute.operator.AsyncOperator;
1313
import org.elasticsearch.compute.operator.DriverContext;
14+
import org.elasticsearch.inference.InferenceServiceResults;
1415
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1516

16-
abstract public class InferenceOperator<Fetched> extends AsyncOperator<Fetched> {
17+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
18+
19+
abstract public class InferenceOperator<Fetched, InferenceResult extends InferenceServiceResults> extends AsyncOperator<Fetched> {
1720

1821
// Move to a setting.
1922
private static final int MAX_INFERENCE_WORKER = 10;
20-
2123
private final InferenceRunner inferenceRunner;
2224
private final String inferenceId;
25+
private final Class<InferenceResult> inferenceResultClass;
2326

2427
public InferenceOperator(
2528
DriverContext driverContext,
2629
ThreadContext threadContext,
2730
InferenceRunner inferenceRunner,
28-
String inferenceId
31+
String inferenceId,
32+
Class<InferenceResult> inferenceResultClass
2933
) {
3034
super(driverContext, threadContext, MAX_INFERENCE_WORKER);
3135
this.inferenceRunner = inferenceRunner;
3236
this.inferenceId = inferenceId;
37+
this.inferenceResultClass = inferenceResultClass;
3338

3439
assert inferenceRunner.getThreadContext() != null;
3540
}
3641

37-
protected void doInference(InferenceAction.Request inferenceRequest, ActionListener<InferenceAction.Response> listener) {
38-
inferenceRunner.doInference(inferenceRequest, listener);
42+
protected final void doInference(InferenceAction.Request inferenceRequest, ActionListener<InferenceResult> listener) {
43+
inferenceRunner.doInference(inferenceRequest, listener.map(this::checkedInferenceResults));
3944
}
4045

4146
protected String inferenceId() {
4247
return inferenceId;
4348
}
49+
50+
private InferenceResult checkedInferenceResults(InferenceAction.Response inferenceResponse) {
51+
if (inferenceResultClass.isInstance(inferenceResponse.getResults())) {
52+
return inferenceResultClass.cast(inferenceResponse.getResults());
53+
}
54+
throw new IllegalStateException(
55+
format(
56+
"Inference result has wrong type. Got [{}] while expecting [{}]",
57+
inferenceResponse.getResults().getClass().getName(),
58+
inferenceResultClass.getName()
59+
)
60+
);
61+
}
4462
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import java.util.List;
2727

28-
public class RerankOperator extends InferenceOperator<Page> {
28+
public class RerankOperator extends InferenceOperator<Page, RankedDocsResults> {
2929
public record Factory(
3030
InferenceRunner inferenceRunner,
3131
String inferenceId,
@@ -65,7 +65,7 @@ public RerankOperator(
6565
ExpressionEvaluator rowEncoder,
6666
int scoreChannel
6767
) {
68-
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId);
68+
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId, RankedDocsResults.class);
6969

7070
this.blockFactory = driverContext.blockFactory();
7171
this.queryText = queryText;
@@ -81,10 +81,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
8181
try {
8282
doInference(
8383
buildInferenceRequest(inputPage),
84-
ActionListener.wrap(
85-
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
86-
outputListener::onFailure
87-
)
84+
outputListener.delegateFailureAndWrap((l, r) -> l.onResponse(buildOutput(inputPage, r)))
8885
);
8986
} catch (Exception e) {
9087
outputListener.onFailure(e);
@@ -111,21 +108,6 @@ public String toString() {
111108
return "RerankOperator[inference_id=[" + inferenceId() + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
112109
}
113110

114-
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
115-
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
116-
return buildOutput(inputPage, rankedDocsResults);
117-
118-
}
119-
120-
throw new IllegalStateException(
121-
"Inference result has wrong type. Got ["
122-
+ inferenceResponse.getResults().getClass()
123-
+ "] while expecting ["
124-
+ RankedDocsResults.class
125-
+ "]"
126-
);
127-
}
128-
129111
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
130112
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
131113
Block[] blocks = new Block[blockCount];

0 commit comments

Comments
 (0)