Skip to content

Commit b0129a3

Browse files
committed
RereankOperator refactoring
1 parent ef688f7 commit b0129a3

File tree

2 files changed

+119
-120
lines changed

2 files changed

+119
-120
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 70 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.elasticsearch.compute.data.Page;
1818
import org.elasticsearch.compute.operator.AsyncOperator;
1919
import org.elasticsearch.compute.operator.DriverContext;
20-
import org.elasticsearch.compute.operator.EvalOperator;
20+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
2121
import org.elasticsearch.compute.operator.Operator;
2222
import org.elasticsearch.core.Releasables;
2323
import org.elasticsearch.inference.TaskType;
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2828

2929
import java.io.IOException;
30+
import java.io.UncheckedIOException;
3031
import java.util.List;
3132
import java.util.Map;
3233

@@ -39,7 +40,7 @@ public record Factory(
3940
InferenceService inferenceService,
4041
String inferenceId,
4142
String queryText,
42-
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorFactories,
43+
Map<String, ExpressionEvaluator.Factory> fieldsEvaluatorFactories,
4344
int scoreChannel
4445
) implements OperatorFactory {
4546

@@ -50,7 +51,7 @@ public String describe() {
5051
+ " query="
5152
+ queryText
5253
+ " rerank_fields="
53-
+ rerankFieldsEvaluatorFactories.keySet()
54+
+ fieldsEvaluatorFactories.keySet()
5455
+ " score_channel="
5556
+ scoreChannel
5657
+ "]";
@@ -63,67 +64,73 @@ public Operator get(DriverContext driverContext) {
6364
inferenceService,
6465
inferenceId,
6566
queryText,
66-
rerankFieldsEvaluatorFactories.keySet().toArray(new String[0]),
67-
rerankFieldsEvaluatorFactories.values()
68-
.stream()
69-
.map(factory -> factory.get(driverContext))
70-
.toArray(EvalOperator.ExpressionEvaluator[]::new),
67+
fieldNames(),
68+
fieldsEvaluators(driverContext),
7169
scoreChannel
7270
);
7371
}
72+
73+
private String[] fieldNames() {
74+
return fieldsEvaluatorFactories.keySet().toArray(String[]::new);
75+
}
76+
77+
private ExpressionEvaluator[] fieldsEvaluators(DriverContext context) {
78+
return fieldsEvaluatorFactories.values().stream().map(factory -> factory.get(context)).toArray(ExpressionEvaluator[]::new);
79+
}
7480
}
7581

7682
private final InferenceService inferenceService;
7783
private final BlockFactory blockFactory;
7884
private final String inferenceId;
7985
private final String queryText;
80-
private final String[] rerankFieldNames;
81-
private final EvalOperator.ExpressionEvaluator[] rerankFieldsEvaluators;
86+
private final String[] fieldNames;
87+
private final ExpressionEvaluator[] fieldsEvaluators;
8288
private final int scoreChannel;
8389

8490
public RerankOperator(
8591
DriverContext driverContext,
8692
InferenceService inferenceService,
8793
String inferenceId,
8894
String queryText,
89-
String[] rerankFieldNames,
90-
EvalOperator.ExpressionEvaluator[] rerankFieldsEvaluators,
95+
String[] fieldNames,
96+
ExpressionEvaluator[] fieldsEvaluators,
9197
int scoreChannel
9298
) {
9399
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
94100

95101
assert inferenceService.getThreadContext() != null;
102+
assert fieldNames.length == fieldsEvaluators.length;
96103

97104
this.blockFactory = driverContext.blockFactory();
98105
this.inferenceService = inferenceService;
99106
this.inferenceId = inferenceId;
100107
this.queryText = queryText;
101-
this.rerankFieldNames = rerankFieldNames;
102-
this.rerankFieldsEvaluators = rerankFieldsEvaluators;
108+
this.fieldNames = fieldNames;
109+
this.fieldsEvaluators = fieldsEvaluators;
103110
this.scoreChannel = scoreChannel;
104111
}
105112

106113
@Override
107114
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
108115
// Ensure input page blocks are released when the listener is called.
109-
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, inputPage::releaseBlocks);
116+
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { inputPage.releaseBlocks(); });
110117

111-
final ActionListener<InferenceAction.Response> inferenceResonseListener = ActionListener.wrap(
112-
inferenceResponse -> buildOutput(inputPage, inferenceResponse, outputListener),
113-
outputListener::onFailure
114-
);
115-
116-
final ActionListener<InferenceAction.Request> buildInferenceRequestListener = ActionListener.wrap(
117-
(inferenceRequest) -> inferenceService.doInference(inferenceRequest, inferenceResonseListener),
118-
outputListener::onFailure
119-
);
120-
121-
buildInferenceRequest(inputPage, buildInferenceRequestListener);
118+
try {
119+
inferenceService.doInference(
120+
buildInferenceRequest(inputPage),
121+
ActionListener.wrap(
122+
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
123+
outputListener::onFailure
124+
)
125+
);
126+
} catch (Exception e) {
127+
outputListener.onFailure(e);
128+
}
122129
}
123130

124131
@Override
125132
protected void doClose() {
126-
Releasables.closeExpectNoException(this.rerankFieldsEvaluators);
133+
Releasables.closeExpectNoException(this.fieldsEvaluators);
127134
}
128135

129136
@Override
@@ -143,30 +150,28 @@ public String toString() {
143150
+ " query="
144151
+ queryText
145152
+ " rerank_fields="
146-
+ List.of(rerankFieldNames)
153+
+ List.of(fieldNames)
147154
+ " score_channel="
148155
+ scoreChannel
149156
+ "]";
150157
}
151158

152-
private void buildOutput(Page inputPage, InferenceAction.Response inferenceResponse, ActionListener<Page> listener) {
159+
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
153160
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
154-
buildOutput(inputPage, rankedDocsResults, listener);
155-
return;
161+
return buildOutput(inputPage, rankedDocsResults);
162+
156163
}
157164

158-
listener.onFailure(
159-
new IllegalStateException(
160-
"Inference result has wrong type. Got ["
161-
+ inferenceResponse.getResults().getClass()
162-
+ "] while expecting ["
163-
+ RankedDocsResults.class
164-
+ "]"
165-
)
165+
throw new IllegalStateException(
166+
"Inference result has wrong type. Got ["
167+
+ inferenceResponse.getResults().getClass()
168+
+ "] while expecting ["
169+
+ RankedDocsResults.class
170+
+ "]"
166171
);
167172
}
168173

169-
private void buildOutput(Page inputPage, RankedDocsResults rankedDocsResults, ActionListener<Page> listener) {
174+
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
170175
int blockCount = inputPage.getBlockCount();
171176
Block[] blocks = new Block[blockCount];
172177

@@ -179,9 +184,10 @@ private void buildOutput(Page inputPage, RankedDocsResults rankedDocsResults, Ac
179184
blocks[b].incRef();
180185
}
181186
}
182-
listener.onResponse(new Page(blocks));
187+
return new Page(blocks);
183188
} catch (Exception e) {
184189
Releasables.closeExpectNoException(blocks);
190+
throw (e);
185191
}
186192
}
187193

@@ -205,51 +211,34 @@ private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResult
205211
}
206212
}
207213

208-
private void buildInferenceRequest(Page inputPage, ActionListener<InferenceAction.Request> listener) {
209-
210-
Block[] inputBlocks = inputBlocks(inputPage);
214+
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
215+
Block[] inputBlocks = new Block[fieldsEvaluators.length];
211216

212217
try {
213-
String[] inputs = new String[inputPage.getPositionCount()];
214-
if (inputBlocks.length > 0) for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
215-
inputs[pos] = toYaml(inputBlocks, pos);
218+
for (int b = 0; b < inputBlocks.length; b++) {
219+
inputBlocks[b] = fieldsEvaluators[b].eval(inputPage);
216220
}
217-
listener.onResponse(
218-
InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build()
219-
);
220-
} catch (Exception e) {
221-
listener.onFailure(e);
222-
} finally {
223-
Releasables.closeExpectNoException(inputBlocks);
224-
}
225-
}
226221

227-
private Block[] inputBlocks(Page inputPage) {
228-
Block[] blocks = new Block[rerankFieldsEvaluators.length];
229-
230-
try {
231-
for (int i = 0; i < rerankFieldsEvaluators.length; i++) {
232-
blocks[i] = rerankFieldsEvaluators[i].eval(inputPage);
233-
}
234-
235-
return blocks;
236-
} catch (Exception e) {
237-
Releasables.closeExpectNoException(blocks);
238-
throw e;
239-
}
240-
}
241-
242-
private String toYaml(Block[] inputBlocks, int position) throws IOException {
243-
try (XContentBuilder yamlBuilder = XContentFactory.yamlBuilder().startObject()) {
244-
for (int i = 0; i < inputBlocks.length; i++) {
245-
String fieldName = rerankFieldNames[i];
246-
Block currentBlock = inputBlocks[i];
247-
if (currentBlock.isNull(position)) {
248-
continue;
222+
String[] inputs = new String[inputPage.getPositionCount()];
223+
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
224+
try (XContentBuilder yamlBuilder = XContentFactory.yamlBuilder().startObject()) {
225+
for (int i = 0; i < inputBlocks.length; i++) {
226+
String fieldName = fieldNames[i];
227+
Block currentBlock = inputBlocks[i];
228+
if (currentBlock.isNull(pos)) {
229+
continue;
230+
}
231+
yamlBuilder.field(fieldName, toYaml(BlockUtils.toJavaObject(currentBlock, pos)));
232+
}
233+
inputs[pos] = Strings.toString(yamlBuilder.endObject());
234+
} catch (IOException e) {
235+
throw new UncheckedIOException(e);
249236
}
250-
yamlBuilder.field(fieldName, toYaml(BlockUtils.toJavaObject(currentBlock, position)));
251237
}
252-
return Strings.toString(yamlBuilder.endObject());
238+
239+
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
240+
} finally {
241+
Releasables.closeExpectNoException(inputBlocks);
253242
}
254243
}
255244

@@ -262,7 +251,7 @@ private Object toYaml(Object value) {
262251
};
263252
} catch (Error | Exception e) {
264253
// Swallow errors caused by invalid byteref.
265-
return null;
254+
return "";
266255
}
267256
}
268257
}

0 commit comments

Comments
 (0)