Skip to content

Commit 9069561

Browse files
committed
Adding physical plan and operator for Rerank
1 parent 7787efe commit 9069561

File tree

15 files changed

+403
-3
lines changed

15 files changed

+403
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childOutput) {
731731
boolean changed = false;
732732
for (Alias field : rerank.rerankFields()) {
733733
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childOutput));
734+
newFields.add(result);
734735
changed |= result != field;
735736
}
736737

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.client.internal.Client;
1212
import org.elasticsearch.client.internal.OriginSettingClient;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1314
import org.elasticsearch.inference.TaskType;
1415
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1516
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -30,6 +31,10 @@ public InferenceService(Client client) {
3031
this.client = new OriginSettingClient(client, ML_ORIGIN);
3132
}
3233

34+
public ThreadContext getThreadContext() {
35+
return client.threadPool().getThreadContext();
36+
}
37+
3338
public void resolveInferences(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
3439

3540
if (plans.isEmpty()) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
4949
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
5050
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
51+
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
5152

5253
import java.util.ArrayList;
5354
import java.util.List;
@@ -105,6 +106,7 @@ public static List<NamedWriteableRegistry.Entry> physical() {
105106
LocalSourceExec.ENTRY,
106107
MvExpandExec.ENTRY,
107108
ProjectExec.ENTRY,
109+
RerankExec.ENTRY,
108110
ShowExec.ENTRY,
109111
SubqueryExec.ENTRY,
110112
TopNExec.ENTRY
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.tree.Source;
13+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
14+
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
15+
16+
import java.io.IOException;
17+
import java.util.Objects;
18+
19+
public abstract class InferenceExec extends UnaryExec {
20+
private final Expression inferenceId;
21+
22+
protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
23+
super(source, child);
24+
this.inferenceId = inferenceId;
25+
}
26+
27+
public Expression inferenceId() {
28+
return inferenceId;
29+
}
30+
31+
@Override
32+
public void writeTo(StreamOutput out) throws IOException {
33+
Source.EMPTY.writeTo(out);
34+
out.writeNamedWriteable(child());
35+
out.writeNamedWriteable(inferenceId());
36+
}
37+
38+
@Override
39+
public boolean equals(Object o) {
40+
if (this == o) return true;
41+
if (o == null || getClass() != o.getClass()) return false;
42+
if (super.equals(o) == false) return false;
43+
InferenceExec that = (InferenceExec) o;
44+
return inferenceId.equals(that.inferenceId);
45+
}
46+
47+
@Override
48+
public int hashCode() {
49+
return Objects.hash(super.hashCode(), inferenceId());
50+
}
51+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.core.expression.Alias;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
16+
import org.elasticsearch.xpack.esql.core.tree.Source;
17+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
18+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
19+
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
20+
21+
import java.io.IOException;
22+
import java.util.List;
23+
import java.util.Objects;
24+
25+
public class RerankExec extends InferenceExec {
26+
27+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
28+
PhysicalPlan.class,
29+
"RerankExec",
30+
RerankExec::new
31+
);
32+
33+
private final Expression queryText;
34+
private final List<Alias> rerankFields;
35+
36+
public RerankExec(Source source, PhysicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
37+
super(source, child, inferenceId);
38+
this.queryText = queryText;
39+
this.rerankFields = rerankFields;
40+
}
41+
42+
public RerankExec(StreamInput in) throws IOException {
43+
this(
44+
Source.readFrom((PlanStreamInput) in),
45+
in.readNamedWriteable(PhysicalPlan.class),
46+
in.readNamedWriteable(Expression.class),
47+
in.readNamedWriteable(Expression.class),
48+
in.readCollectionAsList(Alias::new)
49+
);
50+
}
51+
52+
public Expression queryText() {
53+
return queryText;
54+
}
55+
56+
public List<Alias> rerankFields() {
57+
return rerankFields;
58+
}
59+
60+
@Override
61+
public String getWriteableName() {
62+
return ENTRY.name;
63+
}
64+
65+
@Override
66+
public void writeTo(StreamOutput out) throws IOException {
67+
super.writeTo(out);
68+
out.writeNamedWriteable(queryText());
69+
out.writeCollection(rerankFields());
70+
}
71+
72+
@Override
73+
protected NodeInfo<? extends PhysicalPlan> info() {
74+
return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields);
75+
}
76+
77+
@Override
78+
public UnaryExec replaceChild(PhysicalPlan newChild) {
79+
return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields);
80+
}
81+
82+
@Override
83+
public boolean equals(Object o) {
84+
if (this == o) return true;
85+
if (o == null || getClass() != o.getClass()) return false;
86+
if (super.equals(o) == false) return false;
87+
RerankExec rerank = (RerankExec) o;
88+
return Objects.equals(queryText, rerank.queryText) && Objects.equals(rerankFields, rerank.rerankFields);
89+
}
90+
91+
@Override
92+
public int hashCode() {
93+
return Objects.hash(super.hashCode(), queryText, rerankFields);
94+
}
95+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.compute.data.BlockFactory;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.operator.AsyncOperator;
14+
import org.elasticsearch.compute.operator.DriverContext;
15+
import org.elasticsearch.compute.operator.EvalOperator;
16+
import org.elasticsearch.compute.operator.Operator;
17+
import org.elasticsearch.logging.LogManager;
18+
import org.elasticsearch.logging.Logger;
19+
import org.elasticsearch.xpack.esql.inference.InferenceService;
20+
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
24+
public class RerankOperator extends AsyncOperator<Page> {
25+
26+
// Move to a setting.
27+
private static final int MAX_INFERENCE_WORKER = 10;
28+
29+
private static final Logger logger = LogManager.getLogger(RerankOperator.class);
30+
31+
public record Factory(
32+
InferenceService inferenceService,
33+
String inferenceId,
34+
String queryText,
35+
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers,
36+
int scoreChannel
37+
) implements OperatorFactory {
38+
39+
@Override
40+
public String describe() {
41+
return "RerankOperator[inference_id="
42+
+ inferenceId
43+
+ " query="
44+
+ queryText
45+
+ " rerank_fields="
46+
+ rerankFieldsEvaluatorSuppliers.keySet()
47+
+ " scoreChannel="
48+
+ scoreChannel
49+
+ "]";
50+
}
51+
52+
@Override
53+
public Operator get(DriverContext driverContext) {
54+
return new RerankOperator(
55+
driverContext,
56+
inferenceService,
57+
inferenceId,
58+
queryText,
59+
buildRerankFieldEvaluator(rerankFieldsEvaluatorSuppliers, driverContext),
60+
scoreChannel
61+
);
62+
}
63+
64+
private Map<String, EvalOperator.ExpressionEvaluator> buildRerankFieldEvaluator(
65+
Map<String, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers,
66+
DriverContext driverContext
67+
) {
68+
Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluators = new HashMap<>();
69+
70+
for (var entry : rerankFieldsEvaluatorSuppliers.entrySet()) {
71+
rerankFieldsEvaluators.put(entry.getKey(), entry.getValue().get(driverContext));
72+
}
73+
74+
return rerankFieldsEvaluators;
75+
}
76+
}
77+
78+
private final InferenceService inferenceService;
79+
private final BlockFactory blockFactory;
80+
private final String inferenceId;
81+
private final String queryText;
82+
private final Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluator;
83+
private final int scoreChannel;
84+
85+
public RerankOperator(
86+
DriverContext driverContext,
87+
InferenceService inferenceService,
88+
String inferenceId,
89+
String queryText,
90+
Map<String, EvalOperator.ExpressionEvaluator> rerankFieldsEvaluator,
91+
int scoreChannel
92+
) {
93+
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
94+
this.blockFactory = driverContext.blockFactory();
95+
this.inferenceService = inferenceService;
96+
this.inferenceId = inferenceId;
97+
this.queryText = queryText;
98+
this.rerankFieldsEvaluator = rerankFieldsEvaluator;
99+
this.scoreChannel = scoreChannel;
100+
}
101+
102+
@Override
103+
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
104+
listener.onResponse(inputPage);
105+
}
106+
107+
@Override
108+
protected void doClose() {
109+
110+
}
111+
112+
@Override
113+
protected void releaseFetchedOnAnyThread(Page page) {
114+
releasePageOnAnyThread(page);
115+
}
116+
117+
@Override
118+
public Page getOutput() {
119+
return fetchFromBuffer();
120+
}
121+
}

0 commit comments

Comments
 (0)