Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
Expand Down Expand Up @@ -93,7 +95,6 @@
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
import org.elasticsearch.xpack.esql.plan.logical.Drop;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
Expand All @@ -107,8 +108,9 @@
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse;
import org.elasticsearch.xpack.esql.plan.logical.fuse.FuseScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
Expand Down Expand Up @@ -526,12 +528,8 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
return resolveInsist(i, childrenOutput, context.indexResolution());
}

if (plan instanceof Dedup dedup) {
return resolveDedup(dedup, childrenOutput);
}

if (plan instanceof RrfScoreEval rrf) {
return resolveRrfScoreEval(rrf, childrenOutput);
if (plan instanceof Fuse fuse) {
return resolveFuse(fuse, childrenOutput);
}

if (plan instanceof Rerank r) {
Expand Down Expand Up @@ -930,52 +928,44 @@ private static FieldAttribute insistKeyword(Attribute attribute) {
);
}

private LogicalPlan resolveDedup(Dedup dedup, List<Attribute> childrenOutput) {
List<NamedExpression> aggregates = dedup.finalAggs();
List<Attribute> groupings = dedup.groupings();
List<NamedExpression> newAggs = new ArrayList<>();
List<Attribute> newGroupings = new ArrayList<>();

for (NamedExpression agg : aggregates) {
var newAgg = (NamedExpression) agg.transformUp(UnresolvedAttribute.class, ua -> {
Expression ne = ua;
Attribute maybeResolved = maybeResolveAttribute(ua, childrenOutput);
if (maybeResolved != null) {
ne = maybeResolved;
}
return ne;
});
newAggs.add(newAgg);
private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) {
Source source = fuse.source();
Attribute score = fuse.score();
if (score instanceof UnresolvedAttribute) {
score = maybeResolveAttribute((UnresolvedAttribute) score, childrenOutput);
}

for (Attribute attr : groupings) {
if (attr instanceof UnresolvedAttribute ua) {
newGroupings.add(resolveAttribute(ua, childrenOutput));
} else {
newGroupings.add(attr);
}
Attribute discriminator = fuse.discriminator();
if (discriminator instanceof UnresolvedAttribute) {
discriminator = maybeResolveAttribute((UnresolvedAttribute) discriminator, childrenOutput);
}

return new Dedup(dedup.source(), dedup.child(), newAggs, newGroupings);
}

private LogicalPlan resolveRrfScoreEval(RrfScoreEval rrf, List<Attribute> childrenOutput) {
Attribute scoreAttr = rrf.scoreAttribute();
Attribute forkAttr = rrf.forkAttribute();
List<NamedExpression> groupings = fuse.groupings()
.stream()
.map(attr -> attr instanceof UnresolvedAttribute ? maybeResolveAttribute((UnresolvedAttribute) attr, childrenOutput) : attr)
.toList();

if (scoreAttr instanceof UnresolvedAttribute ua) {
scoreAttr = resolveAttribute(ua, childrenOutput);
// some attributes were unresolved - we return Fuse here so that the Verifier can raise an error message
if (score instanceof UnresolvedAttribute || discriminator instanceof UnresolvedAttribute) {
return new Fuse(fuse.source(), fuse.child(), score, discriminator, groupings, fuse.fuseType());
}

if (forkAttr instanceof UnresolvedAttribute ua) {
forkAttr = resolveAttribute(ua, childrenOutput);
}
LogicalPlan scoreEval = new FuseScoreEval(source, fuse.child(), score, discriminator);

// create aggregations
Expression aggFilter = new Literal(source, true, DataType.BOOLEAN);

if (forkAttr != rrf.forkAttribute() || scoreAttr != rrf.scoreAttribute()) {
return new RrfScoreEval(rrf.source(), rrf.child(), scoreAttr, forkAttr);
List<NamedExpression> aggregates = new ArrayList<>();
aggregates.add(new Alias(source, score.name(), new Sum(source, score, aggFilter, SummationMode.COMPENSATED_LITERAL)));

for (Attribute attr : childrenOutput) {
if (attr.name().equals(score.name())) {
continue;
}
aggregates.add(new Alias(source, attr.name(), new Values(source, attr, aggFilter)));
}

return rrf;
return resolveAggregate(new Aggregate(source, scoreEval, new ArrayList<>(groupings), aggregates), childrenOutput);
}

private Attribute maybeResolveAttribute(UnresolvedAttribute ua, List<Attribute> childrenOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

import java.io.IOException;
Expand Down Expand Up @@ -174,9 +173,7 @@ public boolean equals(Object obj) {
@Override
public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
return (p, failures) -> {
// `dedup` for now is not exposed as a command,
// so allowing aggregate functions for dedup explicitly is just an internal implementation detail
if ((p instanceof Aggregate) == false && (p instanceof Dedup) == false) {
if ((p instanceof Aggregate) == false) {
p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> {
failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,9 @@
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
import org.elasticsearch.xpack.esql.plan.logical.Dedup;
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
import org.elasticsearch.xpack.esql.plan.logical.Drop;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
Expand All @@ -68,10 +65,10 @@
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
Expand Down Expand Up @@ -778,19 +775,14 @@ public PlanFactory visitFuseCommand(EsqlBaseParser.FuseCommandContext ctx) {
Source source = source(ctx);
return input -> {
Attribute scoreAttr = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
Attribute forkAttr = new UnresolvedAttribute(source, Fork.FORK_FIELD);
Attribute discriminatorAttr = new UnresolvedAttribute(source, Fork.FORK_FIELD);
Attribute idAttr = new UnresolvedAttribute(source, IdFieldMapper.NAME);
Attribute indexAttr = new UnresolvedAttribute(source, MetadataAttribute.INDEX);
List<NamedExpression> aggregates = List.of(
new Alias(
source,
MetadataAttribute.SCORE,
new Sum(source, scoreAttr, new Literal(source, true, DataType.BOOLEAN), SummationMode.COMPENSATED_LITERAL)
)
);
List<Attribute> groupings = List.of(idAttr, indexAttr);

return new Dedup(source, new RrfScoreEval(source, input, scoreAttr, forkAttr), aggregates, groupings);
List<NamedExpression> groupings = List.of(idAttr, indexAttr);
Fuse.FuseType fuseType = Fuse.FuseType.RRF;

return new Fuse(source, input, scoreAttr, discriminatorAttr, groupings, fuseType);
};
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.plan.logical.fuse;

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

import java.io.IOException;
import java.util.List;

public class Fuse extends UnaryPlan implements TelemetryAware {
private final Attribute score;
private final Attribute discriminator;
private final List<NamedExpression> groupings;
private final FuseType fuseType;

public enum FuseType {
RRF,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically only RRF is used for now - but we will add linear combination support soon after this PR

LINEAR
};

public Fuse(
Source source,
LogicalPlan child,
Attribute score,
Attribute discriminator,
List<NamedExpression> groupings,
FuseType fuseType
) {
super(source, child);
this.score = score;
this.discriminator = discriminator;
this.groupings = groupings;
this.fuseType = fuseType;

}

@Override
public String getWriteableName() {
throw new UnsupportedOperationException("not serialized");
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new UnsupportedOperationException("not serialized");
}

@Override
protected NodeInfo<? extends LogicalPlan> info() {
return NodeInfo.create(this, Fuse::new, child(), score, discriminator, groupings, fuseType);
}

@Override
public UnaryPlan replaceChild(LogicalPlan newChild) {
return new Fuse(source(), newChild, score, discriminator, groupings, fuseType);
}

public List<NamedExpression> groupings() {
return groupings;
}

public Attribute discriminator() {
return discriminator;
}

public Attribute score() {
return score;
}

public FuseType fuseType() {
return fuseType;
}

@Override
public boolean expressionsResolved() {
return score.resolved() && discriminator.resolved() && groupings.stream().allMatch(Expression::resolved);
}
}
Loading