|
65 | 65 | import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime; |
66 | 66 | import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; |
67 | 67 | import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime; |
| 68 | +import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; |
| 69 | +import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; |
68 | 70 | import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; |
69 | 71 | import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; |
70 | 72 | import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; |
|
93 | 95 | import org.elasticsearch.xpack.esql.parser.ParsingException; |
94 | 96 | import org.elasticsearch.xpack.esql.plan.IndexPattern; |
95 | 97 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate; |
96 | | -import org.elasticsearch.xpack.esql.plan.logical.Dedup; |
97 | 98 | import org.elasticsearch.xpack.esql.plan.logical.Drop; |
98 | 99 | import org.elasticsearch.xpack.esql.plan.logical.Enrich; |
99 | 100 | import org.elasticsearch.xpack.esql.plan.logical.EsRelation; |
|
107 | 108 | import org.elasticsearch.xpack.esql.plan.logical.MvExpand; |
108 | 109 | import org.elasticsearch.xpack.esql.plan.logical.Project; |
109 | 110 | import org.elasticsearch.xpack.esql.plan.logical.Rename; |
110 | | -import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; |
111 | 111 | import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; |
| 112 | +import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse; |
| 113 | +import org.elasticsearch.xpack.esql.plan.logical.fuse.FuseScoreEval; |
112 | 114 | import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; |
113 | 115 | import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; |
114 | 116 | import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; |
@@ -526,12 +528,8 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) { |
526 | 528 | return resolveInsist(i, childrenOutput, context.indexResolution()); |
527 | 529 | } |
528 | 530 |
|
529 | | - if (plan instanceof Dedup dedup) { |
530 | | - return resolveDedup(dedup, childrenOutput); |
531 | | - } |
532 | | - |
533 | | - if (plan instanceof RrfScoreEval rrf) { |
534 | | - return resolveRrfScoreEval(rrf, childrenOutput); |
| 531 | + if (plan instanceof Fuse fuse) { |
| 532 | + return resolveFuse(fuse, childrenOutput); |
535 | 533 | } |
536 | 534 |
|
537 | 535 | if (plan instanceof Rerank r) { |
@@ -930,52 +928,44 @@ private static FieldAttribute insistKeyword(Attribute attribute) { |
930 | 928 | ); |
931 | 929 | } |
932 | 930 |
|
933 | | - private LogicalPlan resolveDedup(Dedup dedup, List<Attribute> childrenOutput) { |
934 | | - List<NamedExpression> aggregates = dedup.finalAggs(); |
935 | | - List<Attribute> groupings = dedup.groupings(); |
936 | | - List<NamedExpression> newAggs = new ArrayList<>(); |
937 | | - List<Attribute> newGroupings = new ArrayList<>(); |
938 | | - |
939 | | - for (NamedExpression agg : aggregates) { |
940 | | - var newAgg = (NamedExpression) agg.transformUp(UnresolvedAttribute.class, ua -> { |
941 | | - Expression ne = ua; |
942 | | - Attribute maybeResolved = maybeResolveAttribute(ua, childrenOutput); |
943 | | - if (maybeResolved != null) { |
944 | | - ne = maybeResolved; |
945 | | - } |
946 | | - return ne; |
947 | | - }); |
948 | | - newAggs.add(newAgg); |
| 931 | + private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) { |
| 932 | + Source source = fuse.source(); |
| 933 | + Attribute score = fuse.score(); |
| 934 | + if (score instanceof UnresolvedAttribute) { |
| 935 | + score = maybeResolveAttribute((UnresolvedAttribute) score, childrenOutput); |
949 | 936 | } |
950 | 937 |
|
951 | | - for (Attribute attr : groupings) { |
952 | | - if (attr instanceof UnresolvedAttribute ua) { |
953 | | - newGroupings.add(resolveAttribute(ua, childrenOutput)); |
954 | | - } else { |
955 | | - newGroupings.add(attr); |
956 | | - } |
| 938 | + Attribute discriminator = fuse.discriminator(); |
| 939 | + if (discriminator instanceof UnresolvedAttribute) { |
| 940 | + discriminator = maybeResolveAttribute((UnresolvedAttribute) discriminator, childrenOutput); |
957 | 941 | } |
958 | 942 |
|
959 | | - return new Dedup(dedup.source(), dedup.child(), newAggs, newGroupings); |
960 | | - } |
961 | | - |
962 | | - private LogicalPlan resolveRrfScoreEval(RrfScoreEval rrf, List<Attribute> childrenOutput) { |
963 | | - Attribute scoreAttr = rrf.scoreAttribute(); |
964 | | - Attribute forkAttr = rrf.forkAttribute(); |
| 943 | + List<NamedExpression> groupings = fuse.groupings() |
| 944 | + .stream() |
| 945 | + .map(attr -> attr instanceof UnresolvedAttribute ? maybeResolveAttribute((UnresolvedAttribute) attr, childrenOutput) : attr) |
| 946 | + .toList(); |
965 | 947 |
|
966 | | - if (scoreAttr instanceof UnresolvedAttribute ua) { |
967 | | - scoreAttr = resolveAttribute(ua, childrenOutput); |
| 948 | + // some attributes were unresolved - we return Fuse here so that the Verifier can raise an error message |
| 949 | + if (score instanceof UnresolvedAttribute || discriminator instanceof UnresolvedAttribute) { |
| 950 | + return new Fuse(fuse.source(), fuse.child(), score, discriminator, groupings, fuse.fuseType()); |
968 | 951 | } |
969 | 952 |
|
970 | | - if (forkAttr instanceof UnresolvedAttribute ua) { |
971 | | - forkAttr = resolveAttribute(ua, childrenOutput); |
972 | | - } |
| 953 | + LogicalPlan scoreEval = new FuseScoreEval(source, fuse.child(), score, discriminator); |
| 954 | + |
| 955 | + // create aggregations |
| 956 | + Expression aggFilter = new Literal(source, true, DataType.BOOLEAN); |
973 | 957 |
|
974 | | - if (forkAttr != rrf.forkAttribute() || scoreAttr != rrf.scoreAttribute()) { |
975 | | - return new RrfScoreEval(rrf.source(), rrf.child(), scoreAttr, forkAttr); |
| 958 | + List<NamedExpression> aggregates = new ArrayList<>(); |
| 959 | + aggregates.add(new Alias(source, score.name(), new Sum(source, score, aggFilter, SummationMode.COMPENSATED_LITERAL))); |
| 960 | + |
| 961 | + for (Attribute attr : childrenOutput) { |
| 962 | + if (attr.name().equals(score.name())) { |
| 963 | + continue; |
| 964 | + } |
| 965 | + aggregates.add(new Alias(source, attr.name(), new Values(source, attr, aggFilter))); |
976 | 966 | } |
977 | 967 |
|
978 | | - return rrf; |
| 968 | + return resolveAggregate(new Aggregate(source, scoreEval, new ArrayList<>(groupings), aggregates), childrenOutput); |
979 | 969 | } |
980 | 970 |
|
981 | 971 | private Attribute maybeResolveAttribute(UnresolvedAttribute ua, List<Attribute> childrenOutput) { |
|
0 commit comments