Skip to content

Commit 47e9f1d

Browse files
committed
Reranker multiple fields implementation.
1 parent 7e1f367 commit 47e9f1d

File tree

10 files changed

+277
-195
lines changed

10 files changed

+277
-195
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.compute.data.Block;
1314
import org.elasticsearch.compute.data.BlockFactory;
14-
import org.elasticsearch.compute.data.BytesRefBlock;
15+
import org.elasticsearch.compute.data.BlockUtils;
1516
import org.elasticsearch.compute.data.DoubleBlock;
1617
import org.elasticsearch.compute.data.ElementType;
1718
import org.elasticsearch.compute.data.Page;
@@ -21,10 +22,15 @@
2122
import org.elasticsearch.inference.TaskType;
2223
import org.elasticsearch.logging.LogManager;
2324
import org.elasticsearch.logging.Logger;
25+
import org.elasticsearch.xcontent.XContentBuilder;
26+
import org.elasticsearch.xcontent.XContentFactory;
2427
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2528
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2629

30+
import java.io.IOException;
31+
import java.util.HashMap;
2732
import java.util.List;
33+
import java.util.Map;
2834

2935
public class RerankOperator extends AsyncOperator<Page> {
3036

@@ -34,23 +40,39 @@ public record Factory(
3440
InferenceService inferenceService,
3541
String inferenceId,
3642
String queryText,
37-
EvalOperator.ExpressionEvaluator.Factory inputEvaluatorFactory,
43+
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers,
3844
int scoreChannel,
3945
int maxOutstandingRequests
4046
) implements OperatorFactory {
4147
@Override
4248
public RerankOperator get(DriverContext driverContext) {
49+
50+
4351
return new RerankOperator(
4452
inferenceService,
4553
inferenceId,
4654
queryText,
47-
inputEvaluatorFactory.get(driverContext),
55+
buildRerankFieldEvaluator(rerankFieldsEvaluatorSuppliers, driverContext),
4856
scoreChannel,
4957
driverContext,
5058
maxOutstandingRequests
5159
);
5260
}
5361

62+
63+
private Map<String, EvalOperator.ExpressionEvaluator> buildRerankFieldEvaluator(
64+
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers,
65+
DriverContext driverContext
66+
) {
67+
Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluators = new HashMap<>();
68+
69+
for (var entry: rerankFieldsEvaluatorSuppliers.entrySet()) {
70+
rerankFieldsEvaluators.put(entry.getKey(), entry.getValue().get(driverContext));
71+
}
72+
73+
return rerankFieldsEvaluators;
74+
}
75+
5476
@Override
5577
public String describe() {
5678
return "RerankOperator[maxOutstandingRequests = " + maxOutstandingRequests + "]";
@@ -61,14 +83,14 @@ public String describe() {
6183
private final BlockFactory blockFactory;
6284
private final String inferenceId;
6385
private final String queryText;
64-
private final EvalOperator.ExpressionEvaluator inputEvaluator;
86+
private final Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluator;
6587
private final int scoreChannel;
6688

6789
public RerankOperator(
6890
InferenceService inferenceService,
6991
String inferenceId,
7092
String queryText,
71-
EvalOperator.ExpressionEvaluator inputEvaluator,
93+
Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluator,
7294
int scoreChannel,
7395
DriverContext driverContext,
7496
int maxOutstandingRequests
@@ -78,7 +100,7 @@ public RerankOperator(
78100
this.blockFactory = driverContext.blockFactory();
79101
this.inferenceId = inferenceId;
80102
this.queryText = queryText;
81-
this.inputEvaluator = inputEvaluator;
103+
this.rerankFieldsEvaluator = rerankFieldsEvaluator;
82104
this.scoreChannel = scoreChannel;
83105
}
84106

@@ -100,9 +122,13 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
100122
queryText,
101123
inputPage.getPositionCount()
102124
);
103-
inferenceService.infer(buildInferenceRequest(inputPage), ActionListener.wrap((inferenceResponse) -> {
104-
listener.onResponse(buildOutput(inputPage, inferenceResponse));
105-
}, listener::onFailure));
125+
try {
126+
inferenceService.infer(buildInferenceRequest(inputPage), ActionListener.wrap((inferenceResponse) -> {
127+
listener.onResponse(buildOutput(inputPage, inferenceResponse));
128+
}, listener::onFailure));
129+
} catch (IOException e) {
130+
listener.onFailure(e);
131+
}
106132
}
107133

108134
@Override
@@ -151,19 +177,40 @@ private Page buildOutput(Page inputPage, InferenceAction.Response inferenceRespo
151177
);
152178
}
153179

154-
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
155-
BytesRef scratch = new BytesRef();
180+
private InferenceAction.Request buildInferenceRequest(Page inputPage) throws IOException {
156181
String[] inputs = new String[inputPage.getPositionCount()];
157-
BytesRefBlock inputBlock = (BytesRefBlock) inputEvaluator.eval(inputPage);
182+
Map<String, Block> inputBlocks = new HashMap<>();
183+
184+
185+
for (var entry :rerankFieldsEvaluator.entrySet()) {
186+
inputBlocks.put(entry.getKey(), entry.getValue().eval(inputPage));
187+
};
158188

159189
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
160-
if (inputBlock.isNull(pos) || inputBlock.getValueCount(pos) > 1) {
161-
inputs[pos] = "";
162-
} else {
163-
inputs[pos] = inputBlock.getBytesRef(pos, scratch).utf8ToString();
190+
try (XContentBuilder yamlBuilder = XContentFactory.yamlBuilder().startObject()) {
191+
for (var blockEntry: inputBlocks.entrySet()) {
192+
String fieldName = blockEntry.getKey();
193+
Block currentBlock = blockEntry.getValue();
194+
if (currentBlock.isNull(pos)) {
195+
continue;
196+
}
197+
Object value = BlockUtils.toJavaObject(currentBlock, pos);
198+
yamlBuilder.field(fieldName, toYamlValue(value));
199+
}
200+
yamlBuilder.endObject();
201+
inputs[pos] = Strings.toString(yamlBuilder);
164202
}
165203
}
166204

205+
167206
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
168207
}
208+
209+
private Object toYamlValue(Object value) {
210+
return switch (value) {
211+
case BytesRef b -> b.utf8ToString();
212+
case List<?> l -> l.stream().map(this::toYamlValue);
213+
default -> value;
214+
};
215+
}
169216
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)