Skip to content

Commit db55fec

Browse files
committed
InferenceOperator refactoring.
1 parent 6f622e8 commit db55fec

File tree

2 files changed

+49
-16
lines changed

2 files changed

+49
-16
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.common.util.concurrent.ThreadContext;
12+
import org.elasticsearch.compute.operator.AsyncOperator;
13+
import org.elasticsearch.compute.operator.DriverContext;
14+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
16+
abstract public class InferenceOperator<Fetched> extends AsyncOperator<Fetched> {
17+
18+
// Move to a setting.
19+
private static final int MAX_INFERENCE_WORKER = 10;
20+
21+
private final InferenceRunner inferenceRunner;
22+
private final String inferenceId;
23+
24+
public InferenceOperator(
25+
DriverContext driverContext,
26+
ThreadContext threadContext,
27+
InferenceRunner inferenceRunner,
28+
String inferenceId
29+
) {
30+
super(driverContext, threadContext, MAX_INFERENCE_WORKER);
31+
this.inferenceRunner = inferenceRunner;
32+
this.inferenceId = inferenceId;
33+
34+
assert inferenceRunner.getThreadContext() != null;
35+
}
36+
37+
protected void doInference(InferenceAction.Request inferenceRequest, ActionListener<InferenceAction.Response> listener) {
38+
inferenceRunner.doInference(inferenceRequest, listener);
39+
}
40+
41+
protected String inferenceId() {
42+
return inferenceId;
43+
}
44+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.compute.data.BytesRefBlock;
1616
import org.elasticsearch.compute.data.DoubleBlock;
1717
import org.elasticsearch.compute.data.Page;
18-
import org.elasticsearch.compute.operator.AsyncOperator;
1918
import org.elasticsearch.compute.operator.DriverContext;
2019
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2120
import org.elasticsearch.compute.operator.Operator;
@@ -26,11 +25,7 @@
2625

2726
import java.util.List;
2827

29-
public class RerankOperator extends AsyncOperator<Page> {
30-
31-
// Move to a setting.
32-
private static final int MAX_INFERENCE_WORKER = 10;
33-
28+
public class RerankOperator extends InferenceOperator<Page> {
3429
public record Factory(
3530
InferenceRunner inferenceRunner,
3631
String inferenceId,
@@ -57,9 +52,7 @@ public Operator get(DriverContext driverContext) {
5752
}
5853
}
5954

60-
private final InferenceRunner inferenceRunner;
6155
private final BlockFactory blockFactory;
62-
private final String inferenceId;
6356
private final String queryText;
6457
private final ExpressionEvaluator rowEncoder;
6558
private final int scoreChannel;
@@ -72,13 +65,9 @@ public RerankOperator(
7265
ExpressionEvaluator rowEncoder,
7366
int scoreChannel
7467
) {
75-
super(driverContext, inferenceRunner.getThreadContext(), MAX_INFERENCE_WORKER);
76-
77-
assert inferenceRunner.getThreadContext() != null;
68+
super(driverContext, inferenceRunner.getThreadContext(), inferenceRunner, inferenceId);
7869

7970
this.blockFactory = driverContext.blockFactory();
80-
this.inferenceRunner = inferenceRunner;
81-
this.inferenceId = inferenceId;
8271
this.queryText = queryText;
8372
this.rowEncoder = rowEncoder;
8473
this.scoreChannel = scoreChannel;
@@ -90,7 +79,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
9079
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
9180

9281
try {
93-
inferenceRunner.doInference(
82+
doInference(
9483
buildInferenceRequest(inputPage),
9584
ActionListener.wrap(
9685
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
@@ -119,7 +108,7 @@ public Page getOutput() {
119108

120109
@Override
121110
public String toString() {
122-
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
111+
return "RerankOperator[inference_id=[" + inferenceId() + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
123112
}
124113

125114
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
@@ -192,7 +181,7 @@ private InferenceAction.Request buildInferenceRequest(Page inputPage) {
192181
}
193182
}
194183

195-
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
184+
return InferenceAction.Request.builder(inferenceId(), TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
196185
}
197186
}
198187
}

0 commit comments

Comments
 (0)