Skip to content

Commit 3280482

Browse files
committed
add rule to analyzer
1 parent 8c25295 commit 3280482

File tree

8 files changed

+140
-19
lines changed

8 files changed

+140
-19
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.common.logging.HeaderWarning;
1111
import org.elasticsearch.common.logging.LoggerMessageFormat;
12+
import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder;
1213
import org.elasticsearch.compute.data.Block;
1314
import org.elasticsearch.core.Strings;
1415
import org.elasticsearch.index.IndexMode;
@@ -52,6 +53,12 @@
5253
import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
5354
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
5455
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
56+
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
57+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
58+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
59+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
60+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
61+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
5562
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
5663
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
5764
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
@@ -60,6 +67,8 @@
6067
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
6168
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ConvertFunction;
6269
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FoldablesConvertFunction;
70+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
71+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
6372
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
6473
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
6574
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -132,6 +141,7 @@
132141
import static java.util.Collections.emptyList;
133142
import static java.util.Collections.singletonList;
134143
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
144+
import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE;
135145
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
136146
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
137147
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
@@ -179,7 +189,8 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
179189
"Resolution",
180190
new ResolveRefs(),
181191
new ImplicitCasting(),
182-
new ResolveUnionTypes() // Must be after ResolveRefs, so union types can be found
192+
new ResolveUnionTypes(), // Must be after ResolveRefs, so union types can be found
193+
new ImplicitCastAggregateMetricDoubles()
183194
),
184195
new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new AddImplicitForkLimit(), new UnionTypesCleanup())
185196
);
@@ -1642,9 +1653,15 @@ private LogicalPlan doRule(LogicalPlan plan) {
16421653
return plan;
16431654
}
16441655

1645-
// And add generated fields to EsRelation, so these new attributes will appear in the OutputExec of the Fragment
1646-
// and thereby get used in FieldExtractExec
1647-
plan = plan.transformDown(EsRelation.class, esr -> {
1656+
return addGeneratedFieldsToEsRelations(plan, unionFieldAttributes);
1657+
}
1658+
1659+
/**
1660+
* Add generated fields to EsRelation, so these new attributes will appear in the OutputExec of the Fragment
1661+
* and thereby get used in FieldExtractExec
1662+
*/
1663+
private static LogicalPlan addGeneratedFieldsToEsRelations(LogicalPlan plan, List<FieldAttribute> unionFieldAttributes) {
1664+
return plan.transformDown(EsRelation.class, esr -> {
16481665
List<Attribute> missing = new ArrayList<>();
16491666
for (FieldAttribute fa : unionFieldAttributes) {
16501667
// Using outputSet().contains looks by NameId, resp. uses semanticEquals.
@@ -1664,7 +1681,6 @@ private LogicalPlan doRule(LogicalPlan plan) {
16641681
}
16651682
return esr;
16661683
});
1667-
return plan;
16681684
}
16691685

16701686
private Expression resolveConvertFunction(ConvertFunction convert, List<FieldAttribute> unionFieldAttributes) {
@@ -1734,7 +1750,7 @@ private Expression resolveConvertFunction(ConvertFunction convert, List<FieldAtt
17341750
return convertExpression;
17351751
}
17361752

1737-
private Expression createIfDoesNotAlreadyExist(
1753+
private static Expression createIfDoesNotAlreadyExist(
17381754
FieldAttribute fa,
17391755
MultiTypeEsField resolvedField,
17401756
List<FieldAttribute> unionFieldAttributes
@@ -1896,4 +1912,102 @@ private static void typeResolutions(
18961912
var concreteConvert = ResolveUnionTypes.typeSpecificConvert(convert, fieldAttribute.source(), type, imf);
18971913
typeResolutions.put(key, concreteConvert);
18981914
}
1915+
1916+
/**
1917+
* Take InvalidMappedFields in specific aggregations (min, max, sum, count, and avg) and if all original data types
1918+
* are aggregate metric double + any combination of numerics, implicitly cast them to the same type: aggregate metric
1919+
* double for count, and double for min, max, and sum. Avg gets replaced with its surrogate (Div(Sum, Count))
1920+
*/
1921+
private static class ImplicitCastAggregateMetricDoubles extends Rule<LogicalPlan, LogicalPlan> {
1922+
1923+
private List<FieldAttribute> unionFieldAttributes;
1924+
1925+
@Override
1926+
public LogicalPlan apply(LogicalPlan plan) {
1927+
unionFieldAttributes = new ArrayList<>();
1928+
return plan.transformUp(LogicalPlan.class, p -> p.childrenResolved() == false ? p : doRule(p));
1929+
}
1930+
1931+
private LogicalPlan doRule(LogicalPlan plan) {
1932+
int alreadyAddedUnionFieldAttributes = unionFieldAttributes.size();
1933+
plan = plan.transformExpressionsOnly(e -> switch (e) {
1934+
case Max max -> resolveMetricFunction(max, AggregateMetricDoubleBlockBuilder.Metric.MAX);
1935+
case Min min -> resolveMetricFunction(min, AggregateMetricDoubleBlockBuilder.Metric.MIN);
1936+
case Sum sum -> resolveMetricFunction(sum, AggregateMetricDoubleBlockBuilder.Metric.SUM);
1937+
case Count count -> resolveMetricFunction(count, AggregateMetricDoubleBlockBuilder.Metric.COUNT);
1938+
case Avg avg -> substituteSurrogates(avg);
1939+
default -> e;
1940+
});
1941+
1942+
if (unionFieldAttributes.size() == alreadyAddedUnionFieldAttributes) {
1943+
return plan;
1944+
}
1945+
return ResolveUnionTypes.addGeneratedFieldsToEsRelations(plan, unionFieldAttributes);
1946+
}
1947+
1948+
private Expression resolveMetricFunction(Expression expression, AggregateMetricDoubleBlockBuilder.Metric metric) {
1949+
AggregateFunction aggregateFunction = (AggregateFunction) expression;
1950+
if (aggregateFunction.field() instanceof FieldAttribute fa && fa.field() instanceof InvalidMappedField imf) {
1951+
HashMap<ResolveUnionTypes.TypeResolutionKey, Expression> typeResolutions = new HashMap<>();
1952+
if (typesShouldBeConverted(imf.types()) == false) {
1953+
return expression;
1954+
}
1955+
for (DataType type : imf.types()) {
1956+
// Effectively the contents of ResolveUnionTypes::typeSpecificConvert(...)
1957+
// except convertFunction is not necessarily a ConvertFunction (as in the case of Sum's FromAggregateMetricDouble)
1958+
// and we do not substitute surrogates because Count does have a surrogate in the case of aggregate metric double
1959+
ResolveUnionTypes.TypeResolutionKey key = new ResolveUnionTypes.TypeResolutionKey(fa.name(), type);
1960+
EsField field = new EsField(imf.getName(), type, imf.getProperties(), imf.isAggregatable());
1961+
FieldAttribute originalFieldAttr = (FieldAttribute) aggregateFunction.field();
1962+
FieldAttribute resolved = new FieldAttribute(
1963+
fa.source(),
1964+
originalFieldAttr.parentName(),
1965+
originalFieldAttr.name(),
1966+
field,
1967+
originalFieldAttr.nullable(),
1968+
originalFieldAttr.id(),
1969+
true
1970+
);
1971+
1972+
Expression convertExpression;
1973+
if (metric == AggregateMetricDoubleBlockBuilder.Metric.COUNT) {
1974+
convertExpression = new ToAggregateMetricDouble(fa.source(), resolved);
1975+
} else if (type == AGGREGATE_METRIC_DOUBLE) {
1976+
convertExpression = FromAggregateMetricDouble.withMetric(fa.source(), resolved, metric);
1977+
} else {
1978+
convertExpression = new ToDouble(fa.source(), resolved);
1979+
}
1980+
Expression e = expression.replaceChildren(List.of(convertExpression, expression.children().get(1)));
1981+
typeResolutions.put(key, e.children().getFirst());
1982+
}
1983+
var resolvedField = ResolveUnionTypes.resolvedMultiTypeEsField(fa, typeResolutions);
1984+
var newFieldAttribute = ResolveUnionTypes.createIfDoesNotAlreadyExist(fa, resolvedField, unionFieldAttributes);
1985+
return expression.replaceChildren(List.of(newFieldAttribute, expression.children().get(1)));
1986+
}
1987+
return expression;
1988+
}
1989+
1990+
private Expression substituteSurrogates(Expression expression) {
1991+
AggregateFunction aggregateFunction = (AggregateFunction) expression;
1992+
if (aggregateFunction.field() instanceof FieldAttribute fa && fa.field() instanceof InvalidMappedField imf) {
1993+
if (typesShouldBeConverted(imf.types()) == false) {
1994+
return expression;
1995+
}
1996+
return SubstituteSurrogateExpressions.rule(expression);
1997+
}
1998+
return expression;
1999+
}
2000+
2001+
private boolean typesShouldBeConverted(Set<DataType> types) {
2002+
if (types.contains(AGGREGATE_METRIC_DOUBLE) == false) {
2003+
return false;
2004+
}
2005+
for (DataType type : types) {
2006+
if (type.isNumeric() == false && type != AGGREGATE_METRIC_DOUBLE) {
2007+
return false;
2008+
}
2009+
}
2010+
return true;
2011+
}
2012+
}
18992013
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public Avg(
5050
Source source,
5151
@Param(
5252
name = "number",
53-
type = { "double", "integer", "long" },
53+
type = { "aggregate_metric_double", "double", "integer", "long" },
5454
description = "Expression that outputs values to average."
5555
) Expression field
5656
) {
@@ -65,10 +65,10 @@ public Avg(Source source, Expression field, Expression filter) {
6565
protected Expression.TypeResolution resolveType() {
6666
return isType(
6767
field(),
68-
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
68+
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG || dt == DataType.AGGREGATE_METRIC_DOUBLE,
6969
sourceText(),
7070
DEFAULT,
71-
"numeric except unsigned_long or counter types"
71+
"aggregate_metric_double or numeric except unsigned_long or counter types"
7272
);
7373
}
7474

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
import org.elasticsearch.xpack.esql.core.type.KeywordEsField;
6161
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
6262
import org.elasticsearch.xpack.esql.core.type.PotentiallyUnmappedKeywordEsField;
63-
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
63+
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
6464
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
6565
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
6666
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.Sort;
@@ -170,7 +170,7 @@ private BlockLoader getBlockLoaderFor(int shardId, Attribute attr, MappedFieldTy
170170
Expression conversion = unionTypes.getConversionExpressionForIndex(indexName);
171171
return conversion == null
172172
? BlockLoader.CONSTANT_NULLS
173-
: new TypeConvertingBlockLoader(blockLoader, (AbstractConvertFunction) conversion);
173+
: new TypeConvertingBlockLoader(blockLoader, (EsqlScalarFunction) conversion);
174174
}
175175
return blockLoader;
176176
}
@@ -479,9 +479,9 @@ private static class TypeConvertingBlockLoader implements BlockLoader {
479479
private final BlockLoader delegate;
480480
private final TypeConverter typeConverter;
481481

482-
protected TypeConvertingBlockLoader(BlockLoader delegate, AbstractConvertFunction convertFunction) {
482+
protected TypeConvertingBlockLoader(BlockLoader delegate, EsqlScalarFunction convertFunction) {
483483
this.delegate = delegate;
484-
this.typeConverter = TypeConverter.fromConvertFunction(convertFunction);
484+
this.typeConverter = TypeConverter.fromScalarFunction(convertFunction);
485485
}
486486

487487
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.elasticsearch.xpack.esql.core.expression.Expression;
1818
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1919
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
20-
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
20+
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
2121

2222
class TypeConverter {
2323
private final String evaluatorName;
@@ -28,7 +28,7 @@ private TypeConverter(String evaluatorName, ExpressionEvaluator convertEvaluator
2828
this.convertEvaluator = convertEvaluator;
2929
}
3030

31-
public static TypeConverter fromConvertFunction(AbstractConvertFunction convertFunction) {
31+
public static TypeConverter fromScalarFunction(EsqlScalarFunction convertFunction) {
3232
DriverContext driverContext1 = new DriverContext(
3333
BigArrays.NON_RECYCLING_INSTANCE,
3434
new org.elasticsearch.compute.data.BlockFactory(

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2005,7 +2005,7 @@ public void testUnsupportedTypesInStats() {
20052005
| stats avg(x), count_distinct(x), max(x), median(x), median_absolute_deviation(x), min(x), percentile(x, 10), sum(x)
20062006
""", """
20072007
Found 8 problems
2008-
line 2:12: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\
2008+
line 2:12: argument of [avg(x)] must be [aggregate_metric_double or numeric except unsigned_long or counter types],\
20092009
found value [x] type [unsigned_long]
20102010
line 2:20: argument of [count_distinct(x)] must be [any exact type except unsigned_long, _source, or counter types],\
20112011
found value [x] type [unsigned_long]

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ public void testAggsExpressionsInStatsAggs() {
358358
error("from test | stats max(max(salary)) by first_name")
359359
);
360360
assertEquals(
361-
"1:25: argument of [avg(first_name)] must be [numeric except unsigned_long or counter types],"
361+
"1:25: argument of [avg(first_name)] must be [aggregate_metric_double or numeric except unsigned_long or counter types],"
362362
+ " found value [first_name] type [keyword]",
363363
error("from test | stats count(avg(first_name)) by first_name")
364364
);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgErrorTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ protected Expression build(Source source, List<Expression> args) {
3232

3333
@Override
3434
protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
35-
return equalTo(typeErrorMessage(false, validPerPosition, signature, (v, p) -> "numeric except unsigned_long or counter types"));
35+
return equalTo(
36+
typeErrorMessage(
37+
false,
38+
validPerPosition,
39+
signature,
40+
(v, p) -> "aggregate_metric_double or numeric except unsigned_long or counter types"
41+
)
42+
);
3643
}
3744
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ private Block getBlockForMultiType(DocBlock indexDoc, MultiTypeEsField multiType
330330
}
331331
return switch (extractBlockForSingleDoc(indexDoc, ((FieldAttribute) conversion.field()).fieldName(), blockCopier)) {
332332
case BlockResultMissing unused -> getNullsBlock(indexDoc);
333-
case BlockResultSuccess success -> TypeConverter.fromConvertFunction(conversion).convert(success.block);
333+
case BlockResultSuccess success -> TypeConverter.fromScalarFunction(conversion).convert(success.block);
334334
};
335335
}
336336

0 commit comments

Comments
 (0)