Skip to content

Commit 99d9210

Browse files
committed
Refactor FUSE planning
1 parent 16fb8b8 commit 99d9210

File tree

12 files changed

+207
-247
lines changed

12 files changed

+207
-247
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime;
6666
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
6767
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
68+
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
69+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
6870
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
6971
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
7072
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
@@ -93,7 +95,6 @@
9395
import org.elasticsearch.xpack.esql.parser.ParsingException;
9496
import org.elasticsearch.xpack.esql.plan.IndexPattern;
9597
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
96-
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
9798
import org.elasticsearch.xpack.esql.plan.logical.Drop;
9899
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
99100
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -107,8 +108,9 @@
107108
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
108109
import org.elasticsearch.xpack.esql.plan.logical.Project;
109110
import org.elasticsearch.xpack.esql.plan.logical.Rename;
110-
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
111111
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
112+
import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse;
113+
import org.elasticsearch.xpack.esql.plan.logical.fuse.FuseScoreEval;
112114
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
113115
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
114116
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@@ -526,12 +528,8 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
526528
return resolveInsist(i, childrenOutput, context.indexResolution());
527529
}
528530

529-
if (plan instanceof Dedup dedup) {
530-
return resolveDedup(dedup, childrenOutput);
531-
}
532-
533-
if (plan instanceof RrfScoreEval rrf) {
534-
return resolveRrfScoreEval(rrf, childrenOutput);
531+
if (plan instanceof Fuse fuse) {
532+
return resolveFuse(fuse, childrenOutput);
535533
}
536534

537535
if (plan instanceof Rerank r) {
@@ -930,52 +928,44 @@ private static FieldAttribute insistKeyword(Attribute attribute) {
930928
);
931929
}
932930

933-
private LogicalPlan resolveDedup(Dedup dedup, List<Attribute> childrenOutput) {
934-
List<NamedExpression> aggregates = dedup.finalAggs();
935-
List<Attribute> groupings = dedup.groupings();
936-
List<NamedExpression> newAggs = new ArrayList<>();
937-
List<Attribute> newGroupings = new ArrayList<>();
938-
939-
for (NamedExpression agg : aggregates) {
940-
var newAgg = (NamedExpression) agg.transformUp(UnresolvedAttribute.class, ua -> {
941-
Expression ne = ua;
942-
Attribute maybeResolved = maybeResolveAttribute(ua, childrenOutput);
943-
if (maybeResolved != null) {
944-
ne = maybeResolved;
945-
}
946-
return ne;
947-
});
948-
newAggs.add(newAgg);
931+
private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) {
932+
Source source = fuse.source();
933+
Attribute score = fuse.score();
934+
if (score instanceof UnresolvedAttribute) {
935+
score = maybeResolveAttribute((UnresolvedAttribute) score, childrenOutput);
949936
}
950937

951-
for (Attribute attr : groupings) {
952-
if (attr instanceof UnresolvedAttribute ua) {
953-
newGroupings.add(resolveAttribute(ua, childrenOutput));
954-
} else {
955-
newGroupings.add(attr);
956-
}
938+
Attribute discriminator = fuse.discriminator();
939+
if (discriminator instanceof UnresolvedAttribute) {
940+
discriminator = maybeResolveAttribute((UnresolvedAttribute) discriminator, childrenOutput);
957941
}
958942

959-
return new Dedup(dedup.source(), dedup.child(), newAggs, newGroupings);
960-
}
961-
962-
private LogicalPlan resolveRrfScoreEval(RrfScoreEval rrf, List<Attribute> childrenOutput) {
963-
Attribute scoreAttr = rrf.scoreAttribute();
964-
Attribute forkAttr = rrf.forkAttribute();
943+
List<NamedExpression> groupings = fuse.groupings()
944+
.stream()
945+
.map(attr -> attr instanceof UnresolvedAttribute ? maybeResolveAttribute((UnresolvedAttribute) attr, childrenOutput) : attr)
946+
.toList();
965947

966-
if (scoreAttr instanceof UnresolvedAttribute ua) {
967-
scoreAttr = resolveAttribute(ua, childrenOutput);
948+
// some attributes were unresolved - we return Fuse here so that the Verifier can raise an error message
949+
if (score instanceof UnresolvedAttribute || discriminator instanceof UnresolvedAttribute) {
950+
return new Fuse(fuse.source(), fuse.child(), score, discriminator, groupings, fuse.fuseType());
968951
}
969952

970-
if (forkAttr instanceof UnresolvedAttribute ua) {
971-
forkAttr = resolveAttribute(ua, childrenOutput);
972-
}
953+
LogicalPlan scoreEval = new FuseScoreEval(source, fuse.child(), score, discriminator);
954+
955+
// create aggregations
956+
Expression aggFilter = new Literal(source, true, DataType.BOOLEAN);
973957

974-
if (forkAttr != rrf.forkAttribute() || scoreAttr != rrf.scoreAttribute()) {
975-
return new RrfScoreEval(rrf.source(), rrf.child(), scoreAttr, forkAttr);
958+
List<NamedExpression> aggregates = new ArrayList<>();
959+
aggregates.add(new Alias(source, score.name(), new Sum(source, score, aggFilter, SummationMode.COMPENSATED_LITERAL)));
960+
961+
for (Attribute attr : childrenOutput) {
962+
if (attr.name().equals(score.name())) {
963+
continue;
964+
}
965+
aggregates.add(new Alias(source, attr.name(), new Values(source, attr, aggFilter)));
976966
}
977967

978-
return rrf;
968+
return resolveAggregate(new Aggregate(source, scoreEval, new ArrayList<>(groupings), aggregates), childrenOutput);
979969
}
980970

981971
private Attribute maybeResolveAttribute(UnresolvedAttribute ua, List<Attribute> childrenOutput) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
2222
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2323
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
24-
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
2524
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2625

2726
import java.io.IOException;
@@ -174,9 +173,7 @@ public boolean equals(Object obj) {
174173
@Override
175174
public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
176175
return (p, failures) -> {
177-
// `dedup` for now is not exposed as a command,
178-
// so allowing aggregate functions for dedup explicitly is just an internal implementation detail
179-
if ((p instanceof Aggregate) == false && (p instanceof Dedup) == false) {
176+
if ((p instanceof Aggregate) == false) {
180177
p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> {
181178
failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
182179
}));

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,9 @@
4444
import org.elasticsearch.xpack.esql.expression.Order;
4545
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
4646
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
47-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
48-
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
4947
import org.elasticsearch.xpack.esql.plan.IndexPattern;
5048
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
5149
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
52-
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
5350
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
5451
import org.elasticsearch.xpack.esql.plan.logical.Drop;
5552
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -68,10 +65,10 @@
6865
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
6966
import org.elasticsearch.xpack.esql.plan.logical.Rename;
7067
import org.elasticsearch.xpack.esql.plan.logical.Row;
71-
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
7268
import org.elasticsearch.xpack.esql.plan.logical.Sample;
7369
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
7470
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
71+
import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse;
7572
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
7673
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
7774
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@@ -778,19 +775,14 @@ public PlanFactory visitFuseCommand(EsqlBaseParser.FuseCommandContext ctx) {
778775
Source source = source(ctx);
779776
return input -> {
780777
Attribute scoreAttr = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
781-
Attribute forkAttr = new UnresolvedAttribute(source, Fork.FORK_FIELD);
778+
Attribute discriminatorAttr = new UnresolvedAttribute(source, Fork.FORK_FIELD);
782779
Attribute idAttr = new UnresolvedAttribute(source, IdFieldMapper.NAME);
783780
Attribute indexAttr = new UnresolvedAttribute(source, MetadataAttribute.INDEX);
784-
List<NamedExpression> aggregates = List.of(
785-
new Alias(
786-
source,
787-
MetadataAttribute.SCORE,
788-
new Sum(source, scoreAttr, new Literal(source, true, DataType.BOOLEAN), SummationMode.COMPENSATED_LITERAL)
789-
)
790-
);
791-
List<Attribute> groupings = List.of(idAttr, indexAttr);
792781

793-
return new Dedup(source, new RrfScoreEval(source, input, scoreAttr, forkAttr), aggregates, groupings);
782+
List<NamedExpression> groupings = List.of(idAttr, indexAttr);
783+
Fuse.FuseType fuseType = Fuse.FuseType.RRF;
784+
785+
return new Fuse(source, input, scoreAttr, discriminatorAttr, groupings, fuseType);
794786
};
795787
}
796788

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Dedup.java

Lines changed: 0 additions & 111 deletions
This file was deleted.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.fuse;
9+
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
12+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
13+
import org.elasticsearch.xpack.esql.core.expression.Expression;
14+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
15+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
16+
import org.elasticsearch.xpack.esql.core.tree.Source;
17+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
18+
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
19+
20+
import java.io.IOException;
21+
import java.util.List;
22+
23+
public class Fuse extends UnaryPlan implements TelemetryAware {
24+
private final Attribute score;
25+
private final Attribute discriminator;
26+
private final List<NamedExpression> groupings;
27+
private final FuseType fuseType;
28+
29+
public enum FuseType {
30+
RRF,
31+
LINEAR
32+
};
33+
34+
public Fuse(
35+
Source source,
36+
LogicalPlan child,
37+
Attribute score,
38+
Attribute discriminator,
39+
List<NamedExpression> groupings,
40+
FuseType fuseType
41+
) {
42+
super(source, child);
43+
this.score = score;
44+
this.discriminator = discriminator;
45+
this.groupings = groupings;
46+
this.fuseType = fuseType;
47+
48+
}
49+
50+
@Override
51+
public String getWriteableName() {
52+
throw new UnsupportedOperationException("not serialized");
53+
}
54+
55+
@Override
56+
public void writeTo(StreamOutput out) throws IOException {
57+
throw new UnsupportedOperationException("not serialized");
58+
}
59+
60+
@Override
61+
protected NodeInfo<? extends LogicalPlan> info() {
62+
return NodeInfo.create(this, Fuse::new, child(), score, discriminator, groupings, fuseType);
63+
}
64+
65+
@Override
66+
public UnaryPlan replaceChild(LogicalPlan newChild) {
67+
return new Fuse(source(), newChild, score, discriminator, groupings, fuseType);
68+
}
69+
70+
public List<NamedExpression> groupings() {
71+
return groupings;
72+
}
73+
74+
public Attribute discriminator() {
75+
return discriminator;
76+
}
77+
78+
public Attribute score() {
79+
return score;
80+
}
81+
82+
public FuseType fuseType() {
83+
return fuseType;
84+
}
85+
86+
@Override
87+
public boolean expressionsResolved() {
88+
return score.resolved() && discriminator.resolved() && groupings.stream().allMatch(Expression::resolved);
89+
}
90+
}

0 commit comments

Comments
 (0)