Skip to content

Commit cd4466e

Browse files
committed
Make Vector ops LogicalPlans
Previously Vector operations were defined as Expressions which didn't match the actual semantics not the rest of the parsing infra. This has now been addressed by making the nodes LogicalPlans.
1 parent f9deb4f commit cd4466e

File tree

13 files changed

+3485
-2838
lines changed

13 files changed

+3485
-2838
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 10 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -517,72 +517,11 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
517517
case Insist i -> resolveInsist(i, childrenOutput, context);
518518
case Fuse fuse -> resolveFuse(fuse, childrenOutput);
519519
case Rerank r -> resolveRerank(r, childrenOutput);
520+
case PromqlCommand promql -> resolvePromql(promql, childrenOutput);
520521
default -> plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
521522
};
522523
}
523524

524-
if (plan instanceof Completion c) {
525-
return resolveCompletion(c, childrenOutput);
526-
}
527-
528-
if (plan instanceof Drop d) {
529-
return resolveDrop(d, childrenOutput);
530-
}
531-
532-
if (plan instanceof Rename r) {
533-
return resolveRename(r, childrenOutput);
534-
}
535-
536-
if (plan instanceof Keep p) {
537-
return resolveKeep(p, childrenOutput);
538-
}
539-
540-
if (plan instanceof Fork f) {
541-
return resolveFork(f, context);
542-
}
543-
544-
if (plan instanceof Eval p) {
545-
return resolveEval(p, childrenOutput);
546-
}
547-
548-
if (plan instanceof Enrich p) {
549-
return resolveEnrich(p, childrenOutput);
550-
}
551-
552-
if (plan instanceof MvExpand p) {
553-
return resolveMvExpand(p, childrenOutput);
554-
}
555-
556-
if (plan instanceof Lookup l) {
557-
return resolveLookup(l, childrenOutput);
558-
}
559-
560-
if (plan instanceof LookupJoin j) {
561-
return resolveLookupJoin(j, context);
562-
}
563-
564-
if (plan instanceof Insist i) {
565-
return resolveInsist(i, childrenOutput, context);
566-
}
567-
568-
if (plan instanceof Fuse fuse) {
569-
return resolveFuse(fuse, childrenOutput);
570-
}
571-
572-
if (plan instanceof Rerank r) {
573-
return resolveRerank(r, childrenOutput);
574-
}
575-
576-
if (plan instanceof PromqlCommand p) {
577-
LogicalPlan nested = p.promqlPlan();
578-
return p.withPromqlPlan(
579-
nested.transformExpressionsDown(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput))
580-
);
581-
}
582-
583-
return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
584-
}
585-
586525
private Aggregate resolveAggregate(Aggregate aggregate, List<Attribute> childrenOutput) {
587526
// if the grouping is resolved but the aggs are not, use the former to resolve the latter
588527
// e.g. STATS a ... GROUP BY a = x + 1
@@ -1161,6 +1100,15 @@ private LogicalPlan resolveFuse(Fuse fuse, List<Attribute> childrenOutput) {
11611100
return resolveAggregate(new Aggregate(source, scoreEval, new ArrayList<>(keys), aggregates), childrenOutput);
11621101
}
11631102

1103+
private LogicalPlan resolvePromql(PromqlCommand promql, List<Attribute> childrenOutput) {
1104+
LogicalPlan promqlPlan = promql.promqlPlan();
1105+
Function<UnresolvedAttribute, Expression> lambda = ua -> maybeResolveAttribute(ua, childrenOutput);
1106+
// resolve the nested plan
1107+
return promql.withPromqlPlan(promqlPlan.transformExpressionsDown(UnresolvedAttribute.class, lambda))
1108+
// but also any unresolved expressions
1109+
.transformExpressionsOnly(UnresolvedAttribute.class, lambda);
1110+
}
1111+
11641112
private Attribute maybeResolveAttribute(UnresolvedAttribute ua, List<Attribute> childrenOutput) {
11651113
return maybeResolveAttribute(ua, childrenOutput, log);
11661114
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/promql/predicate/operator/VectorBinaryOperator.java

Lines changed: 95 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,31 @@
88
package org.elasticsearch.xpack.esql.expression.promql.predicate.operator;
99

1010
import org.elasticsearch.common.io.stream.StreamOutput;
11-
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
11+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1212
import org.elasticsearch.xpack.esql.core.expression.Expression;
13-
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
14-
import org.elasticsearch.xpack.esql.core.expression.Nullability;
13+
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
1514
import org.elasticsearch.xpack.esql.core.expression.function.Function;
1615
import org.elasticsearch.xpack.esql.core.tree.Source;
1716
import org.elasticsearch.xpack.esql.core.type.DataType;
18-
import org.elasticsearch.xpack.esql.expression.promql.types.PromqlDataTypes;
17+
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
18+
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
19+
import org.elasticsearch.xpack.esql.plan.logical.BinaryPlan;
20+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
21+
import org.elasticsearch.xpack.esql.plan.logical.promql.selector.LabelMatcher;
1922

2023
import java.io.IOException;
24+
import java.util.ArrayList;
25+
import java.util.HashSet;
2126
import java.util.List;
2227
import java.util.Objects;
28+
import java.util.Set;
2329

24-
import static java.util.Arrays.asList;
30+
public abstract class VectorBinaryOperator extends BinaryPlan {
2531

26-
public abstract class VectorBinaryOperator extends Expression {
27-
28-
private final Expression left, right;
2932
private final VectorMatch match;
3033
private final boolean dropMetricName;
31-
32-
private DataType dataType;
33-
34-
private BinaryOp binaryOp;
34+
private final BinaryOp binaryOp;
35+
private List<Attribute> output;
3536

3637
/**
3738
* Underlying binary operation (e.g. +, -, *, /, etc.) being performed
@@ -49,28 +50,18 @@ public interface ScalarFunctionFactory {
4950

5051
protected VectorBinaryOperator(
5152
Source source,
52-
Expression left,
53-
Expression right,
53+
LogicalPlan left,
54+
LogicalPlan right,
5455
VectorMatch match,
5556
boolean dropMetricName,
5657
BinaryOp binaryOp
5758
) {
58-
super(source, asList(left, right));
59-
this.left = left;
60-
this.right = right;
59+
super(source, left, right);
6160
this.match = match;
6261
this.dropMetricName = dropMetricName;
6362
this.binaryOp = binaryOp;
6463
}
6564

66-
public Expression left() {
67-
return left;
68-
}
69-
70-
public Expression right() {
71-
return right;
72-
}
73-
7465
public VectorMatch match() {
7566
return match;
7667
}
@@ -84,55 +75,113 @@ public BinaryOp binaryOp() {
8475
}
8576

8677
@Override
87-
public DataType dataType() {
88-
if (dataType == null) {
89-
dataType = PromqlDataTypes.operationType(left.dataType(), right.dataType());
78+
public List<Attribute> output() {
79+
if (output == null) {
80+
output = computeOutputAttributes();
81+
}
82+
return output;
83+
}
84+
85+
private List<Attribute> computeOutputAttributes() {
86+
// TODO: this isn't tested and should be revised
87+
List<Attribute> leftAttrs = left().output();
88+
List<Attribute> rightAttrs = right().output();
89+
90+
Set<String> leftLabels = extractLabelNames(leftAttrs);
91+
Set<String> rightLabels = extractLabelNames(rightAttrs);
92+
93+
Set<String> outputLabels;
94+
95+
if (match != null) {
96+
if (match.filter() == VectorMatch.Filter.ON) {
97+
outputLabels = new HashSet<>(match.filterLabels());
98+
} else if (match.filter() == VectorMatch.Filter.IGNORING) {
99+
outputLabels = new HashSet<>(leftLabels);
100+
outputLabels.addAll(rightLabels);
101+
outputLabels.removeAll(match.filterLabels());
102+
} else {
103+
outputLabels = new HashSet<>(leftLabels);
104+
outputLabels.retainAll(rightLabels);
105+
}
106+
} else {
107+
outputLabels = new HashSet<>(leftLabels);
108+
outputLabels.retainAll(rightLabels);
90109
}
91-
return dataType;
92-
}
93110

94-
@Override
95-
public VectorBinaryOperator replaceChildren(List<Expression> newChildren) {
96-
return replaceChildren(left, right);
111+
if (dropMetricName) {
112+
outputLabels.remove(LabelMatcher.NAME);
113+
}
114+
115+
List<Attribute> result = new ArrayList<>();
116+
for (String label : outputLabels) {
117+
Attribute attr = findAttribute(label, leftAttrs, rightAttrs);
118+
if (attr != null) {
119+
result.add(attr);
120+
}
121+
}
122+
123+
result.add(new ReferenceAttribute(source(), "value", DataType.DOUBLE));
124+
return result;
97125
}
98126

99-
protected abstract VectorBinaryOperator replaceChildren(Expression left, Expression right);
127+
private Set<String> extractLabelNames(List<Attribute> attrs) {
128+
Set<String> labels = new HashSet<>();
129+
for (Attribute attr : attrs) {
130+
String name = attr.name();
131+
if (name.equals("value") == false) {
132+
labels.add(name);
133+
}
134+
}
135+
return labels;
136+
}
100137

101-
@Override
102-
public boolean foldable() {
103-
return left.foldable() && right.foldable();
138+
private Attribute findAttribute(String name, List<Attribute> left, List<Attribute> right) {
139+
for (Attribute attr : left) {
140+
if (attr.name().equals(name)) {
141+
return attr;
142+
}
143+
}
144+
for (Attribute attr : right) {
145+
if (attr.name().equals(name)) {
146+
return attr;
147+
}
148+
}
149+
return null;
104150
}
105151

106152
@Override
107-
public Object fold(FoldContext ctx) {
108-
return binaryOp.asFunction().create(source(), left(), right()).fold(ctx);
109-
}
153+
public abstract VectorBinaryOperator replaceChildren(LogicalPlan newLeft, LogicalPlan newRight);
110154

111155
@Override
112-
public Nullability nullable() {
113-
return Nullability.TRUE;
156+
public boolean expressionsResolved() {
157+
return true;
114158
}
115159

116160
@Override
117161
public boolean equals(Object o) {
162+
if (this == o) return true;
163+
if (o == null || getClass() != o.getClass()) return false;
118164
if (super.equals(o)) {
119165
VectorBinaryOperator that = (VectorBinaryOperator) o;
120-
return dropMetricName == that.dropMetricName && Objects.equals(match, that.match) && Objects.equals(binaryOp, that.binaryOp);
166+
return dropMetricName == that.dropMetricName
167+
&& Objects.equals(match, that.match)
168+
&& Objects.equals(binaryOp, that.binaryOp);
121169
}
122170
return false;
123171
}
124172

125173
@Override
126174
public int hashCode() {
127-
return Objects.hash(left, right, match, dropMetricName, binaryOp);
175+
return Objects.hash(super.hashCode(), match, dropMetricName, binaryOp);
128176
}
129177

178+
@Override
130179
public String getWriteableName() {
131-
throw new EsqlIllegalArgumentException("should not be serialized");
180+
throw new UnsupportedOperationException("PromQL plans should not be serialized");
132181
}
133182

134183
@Override
135184
public void writeTo(StreamOutput out) throws IOException {
136-
throw new EsqlIllegalArgumentException("should not be serialized");
185+
throw new UnsupportedOperationException("PromQL plans should not be serialized");
137186
}
138187
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/promql/predicate/operator/arithmetic/VectorBinaryArithmetic.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.expression.promql.predicate.operator.arithmetic;
99

10-
import org.elasticsearch.xpack.esql.core.expression.Expression;
1110
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1211
import org.elasticsearch.xpack.esql.core.tree.Source;
1312
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
@@ -18,6 +17,7 @@
1817
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
1918
import org.elasticsearch.xpack.esql.expression.promql.predicate.operator.VectorBinaryOperator;
2019
import org.elasticsearch.xpack.esql.expression.promql.predicate.operator.VectorMatch;
20+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2121

2222
public class VectorBinaryArithmetic extends VectorBinaryOperator {
2323

@@ -44,7 +44,7 @@ public ScalarFunctionFactory asFunction() {
4444

4545
private final ArithmeticOp op;
4646

47-
public VectorBinaryArithmetic(Source source, Expression left, Expression right, VectorMatch match, ArithmeticOp op) {
47+
public VectorBinaryArithmetic(Source source, LogicalPlan left, LogicalPlan right, VectorMatch match, ArithmeticOp op) {
4848
super(source, left, right, match, true, op);
4949
this.op = op;
5050
}
@@ -54,12 +54,12 @@ public ArithmeticOp op() {
5454
}
5555

5656
@Override
57-
protected VectorBinaryOperator replaceChildren(Expression left, Expression right) {
58-
return new VectorBinaryArithmetic(source(), left, right, match(), op());
57+
public VectorBinaryOperator replaceChildren(LogicalPlan newLeft, LogicalPlan newRight) {
58+
return new VectorBinaryArithmetic(source(), newLeft, newRight, match(), op());
5959
}
6060

6161
@Override
62-
protected NodeInfo<? extends Expression> info() {
62+
protected NodeInfo<VectorBinaryArithmetic> info() {
6363
return NodeInfo.create(this, VectorBinaryArithmetic::new, left(), right(), match(), op());
6464
}
6565
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/promql/predicate/operator/comparison/VectorBinaryComparison.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.expression.promql.predicate.operator.comparison;
99

10-
import org.elasticsearch.xpack.esql.core.expression.Expression;
1110
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1211
import org.elasticsearch.xpack.esql.core.tree.Source;
1312
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
@@ -18,6 +17,7 @@
1817
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
1918
import org.elasticsearch.xpack.esql.expression.promql.predicate.operator.VectorBinaryOperator;
2019
import org.elasticsearch.xpack.esql.expression.promql.predicate.operator.VectorMatch;
20+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2121

2222
import java.util.Objects;
2323

@@ -47,7 +47,7 @@ public ScalarFunctionFactory asFunction() {
4747
private final ComparisonOp op;
4848
private final boolean boolMode;
4949

50-
public VectorBinaryComparison(Source source, Expression left, Expression right, VectorMatch match, boolean boolMode, ComparisonOp op) {
50+
public VectorBinaryComparison(Source source, LogicalPlan left, LogicalPlan right, VectorMatch match, boolean boolMode, ComparisonOp op) {
5151
super(source, left, right, match, boolMode == false, op);
5252
this.op = op;
5353
this.boolMode = boolMode;
@@ -62,12 +62,12 @@ public boolean boolMode() {
6262
}
6363

6464
@Override
65-
protected VectorBinaryOperator replaceChildren(Expression left, Expression right) {
66-
return new VectorBinaryComparison(source(), left, right, match(), boolMode, op());
65+
public VectorBinaryOperator replaceChildren(LogicalPlan newLeft, LogicalPlan newRight) {
66+
return new VectorBinaryComparison(source(), newLeft, newRight, match(), boolMode, op());
6767
}
6868

6969
@Override
70-
protected NodeInfo<? extends Expression> info() {
70+
protected NodeInfo<VectorBinaryComparison> info() {
7171
return NodeInfo.create(this, VectorBinaryComparison::new, left(), right(), match(), boolMode(), op());
7272
}
7373

0 commit comments

Comments
 (0)