|
65 | 65 | import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime;
|
66 | 66 | import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
|
67 | 67 | import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
|
| 68 | +import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; |
| 69 | +import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; |
68 | 70 | import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
|
69 | 71 | import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
|
70 | 72 | import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
|
|
93 | 95 | import org.elasticsearch.xpack.esql.parser.ParsingException;
|
94 | 96 | import org.elasticsearch.xpack.esql.plan.IndexPattern;
|
95 | 97 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
|
96 |
| -import org.elasticsearch.xpack.esql.plan.logical.Dedup; |
97 | 98 | import org.elasticsearch.xpack.esql.plan.logical.Drop;
|
98 | 99 | import org.elasticsearch.xpack.esql.plan.logical.Enrich;
|
99 | 100 | import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
|
|
107 | 108 | import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
|
108 | 109 | import org.elasticsearch.xpack.esql.plan.logical.Project;
|
109 | 110 | import org.elasticsearch.xpack.esql.plan.logical.Rename;
|
110 |
| -import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; |
111 | 111 | import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
|
| 112 | +import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse; |
| 113 | +import org.elasticsearch.xpack.esql.plan.logical.fuse.FuseScoreEval; |
112 | 114 | import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
|
113 | 115 | import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
|
114 | 116 | import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
|
@@ -525,12 +527,8 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
|
525 | 527 | return resolveInsist(i, childrenOutput, context.indexResolution());
|
526 | 528 | }
|
527 | 529 |
|
528 |
| - if (plan instanceof Dedup dedup) { |
529 |
| - return resolveDedup(dedup, childrenOutput); |
530 |
| - } |
531 |
| - |
532 |
| - if (plan instanceof RrfScoreEval rrf) { |
533 |
| - return resolveRrfScoreEval(rrf, childrenOutput); |
| 530 | + if (plan instanceof Fuse fuse) { |
| 531 | + return resolveFuse(fuse, childrenOutput); |
534 | 532 | }
|
535 | 533 |
|
536 | 534 | if (plan instanceof Rerank r) {
|
@@ -929,52 +927,44 @@ private static FieldAttribute insistKeyword(Attribute attribute) {
|
929 | 927 | );
|
930 | 928 | }
|
931 | 929 |
|
932 |
| - private LogicalPlan resolveDedup(Dedup dedup, List<Attribute> childrenOutput) { |
933 |
| - List<NamedExpression> aggregates = dedup.finalAggs(); |
934 |
| - List<Attribute> groupings = dedup.groupings(); |
935 |
| - List<NamedExpression> newAggs = new ArrayList<>(); |
936 |
| - List<Attribute> newGroupings = new ArrayList<>(); |
937 |
| - |
938 |
| - for (NamedExpression agg : aggregates) { |
939 |
| - var newAgg = (NamedExpression) agg.transformUp(UnresolvedAttribute.class, ua -> { |
940 |
| - Expression ne = ua; |
941 |
| - Attribute maybeResolved = maybeResolveAttribute(ua, childrenOutput); |
942 |
| - if (maybeResolved != null) { |
943 |
| - ne = maybeResolved; |
944 |
| - } |
945 |
| - return ne; |
946 |
| - }); |
947 |
| - newAggs.add(newAgg); |
| 930 | + private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) { |
| 931 | + Source source = fuse.source(); |
| 932 | + Attribute score = fuse.score(); |
| 933 | + if (score instanceof UnresolvedAttribute) { |
| 934 | + score = maybeResolveAttribute((UnresolvedAttribute) score, childrenOutput); |
948 | 935 | }
|
949 | 936 |
|
950 |
| - for (Attribute attr : groupings) { |
951 |
| - if (attr instanceof UnresolvedAttribute ua) { |
952 |
| - newGroupings.add(resolveAttribute(ua, childrenOutput)); |
953 |
| - } else { |
954 |
| - newGroupings.add(attr); |
955 |
| - } |
| 937 | + Attribute discriminator = fuse.discriminator(); |
| 938 | + if (discriminator instanceof UnresolvedAttribute) { |
| 939 | + discriminator = maybeResolveAttribute((UnresolvedAttribute) discriminator, childrenOutput); |
956 | 940 | }
|
957 | 941 |
|
958 |
| - return new Dedup(dedup.source(), dedup.child(), newAggs, newGroupings); |
959 |
| - } |
960 |
| - |
961 |
| - private LogicalPlan resolveRrfScoreEval(RrfScoreEval rrf, List<Attribute> childrenOutput) { |
962 |
| - Attribute scoreAttr = rrf.scoreAttribute(); |
963 |
| - Attribute forkAttr = rrf.forkAttribute(); |
| 942 | + List<NamedExpression> groupings = fuse.groupings() |
| 943 | + .stream() |
| 944 | + .map(attr -> attr instanceof UnresolvedAttribute ? maybeResolveAttribute((UnresolvedAttribute) attr, childrenOutput) : attr) |
| 945 | + .toList(); |
964 | 946 |
|
965 |
| - if (scoreAttr instanceof UnresolvedAttribute ua) { |
966 |
| - scoreAttr = resolveAttribute(ua, childrenOutput); |
| 947 | + // some attributes were unresolved - we return Fuse here so that the Verifier can raise an error message |
| 948 | + if (score instanceof UnresolvedAttribute || discriminator instanceof UnresolvedAttribute) { |
| 949 | + return new Fuse(fuse.source(), fuse.child(), score, discriminator, groupings, fuse.fuseType()); |
967 | 950 | }
|
968 | 951 |
|
969 |
| - if (forkAttr instanceof UnresolvedAttribute ua) { |
970 |
| - forkAttr = resolveAttribute(ua, childrenOutput); |
971 |
| - } |
| 952 | + LogicalPlan scoreEval = new FuseScoreEval(source, fuse.child(), score, discriminator); |
| 953 | + |
| 954 | + // create aggregations |
| 955 | + Expression aggFilter = new Literal(source, true, DataType.BOOLEAN); |
972 | 956 |
|
973 |
| - if (forkAttr != rrf.forkAttribute() || scoreAttr != rrf.scoreAttribute()) { |
974 |
| - return new RrfScoreEval(rrf.source(), rrf.child(), scoreAttr, forkAttr); |
| 957 | + List<NamedExpression> aggregates = new ArrayList<>(); |
| 958 | + aggregates.add(new Alias(source, score.name(), new Sum(source, score, aggFilter, SummationMode.COMPENSATED_LITERAL))); |
| 959 | + |
| 960 | + for (Attribute attr : childrenOutput) { |
| 961 | + if (attr.name().equals(score.name())) { |
| 962 | + continue; |
| 963 | + } |
| 964 | + aggregates.add(new Alias(source, attr.name(), new Values(source, attr, aggFilter))); |
975 | 965 | }
|
976 | 966 |
|
977 |
| - return rrf; |
| 967 | + return resolveAggregate(new Aggregate(source, scoreEval, new ArrayList<>(groupings), aggregates), childrenOutput); |
978 | 968 | }
|
979 | 969 |
|
980 | 970 | private Attribute maybeResolveAttribute(UnresolvedAttribute ua, List<Attribute> childrenOutput) {
|
|
0 commit comments