Skip to content

Commit 3566ff5

Browse files
committed
Approximate ESQL stats execution using 1000 documents
1 parent 75ca874 commit 3566ff5

File tree

11 files changed

+244
-5
lines changed

11 files changed

+244
-5
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/esql/action/EsqlQueryRequest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ protected EsqlQueryRequest(StreamInput in) throws IOException {
2424
public abstract String query();
2525

2626
public abstract QueryBuilder filter();
27+
28+
/** Returns whether approximate execution was requested for this query. */
public abstract boolean approximate();
2729
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/esql/action/EsqlQueryRequestBuilder.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ public final ActionType<Response> action() {
3939

4040
public abstract EsqlQueryRequestBuilder<Request, Response> filter(QueryBuilder filter);
4141

42+
/**
 * Sets whether the query being built should be executed approximately.
 *
 * @param approximate {@code true} to request approximate execution
 * @return a builder for further configuration (implementations are expected to return this builder)
 */
public abstract EsqlQueryRequestBuilder<Request, Response> approximate(boolean approximate);
43+
4244
public abstract EsqlQueryRequestBuilder<Request, Response> allowPartialResults(boolean allowPartialResults);
4345

4446
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public class EsqlQueryRequest extends org.elasticsearch.xpack.core.esql.action.E
4848
private boolean includeCCSMetadata;
4949
private Locale locale;
5050
private QueryBuilder filter;
51+
// Whether approximate execution was requested; stays false unless set through approximate(boolean).
private boolean approximate;
5152
private QueryPragmas pragmas = new QueryPragmas(Settings.EMPTY);
5253
private QueryParams params = new QueryParams();
5354
private TimeValue waitForCompletionTimeout = DEFAULT_WAIT_FOR_COMPLETION;
@@ -167,6 +168,16 @@ public QueryBuilder filter() {
167168
return filter;
168169
}
169170

171+
public EsqlQueryRequest approximate(boolean approximate) {
172+
this.approximate = approximate;
173+
return this;
174+
}
175+
176+
@Override
177+
public boolean approximate() {
178+
return approximate;
179+
}
180+
170181
public EsqlQueryRequest pragmas(QueryPragmas pragmas) {
171182
this.pragmas = pragmas;
172183
return this;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequestBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ public EsqlQueryRequestBuilder filter(QueryBuilder filter) {
4646
return this;
4747
}
4848

49+
@Override
50+
public EsqlQueryRequestBuilder approximate(boolean approximate) {
51+
request.approximate(approximate);
52+
return this;
53+
}
54+
4955
public EsqlQueryRequestBuilder pragmas(QueryPragmas pragmas) {
5056
request.pragmas(pragmas);
5157
return this;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RequestXContent.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ String fields() {
7474
static final ParseField QUERY_FIELD = new ParseField("query");
7575
private static final ParseField COLUMNAR_FIELD = new ParseField("columnar");
7676
private static final ParseField FILTER_FIELD = new ParseField("filter");
77+
// Request-body field that toggles EsqlQueryRequest::approximate.
private static final ParseField APPROXIMATE_FIELD = new ParseField("approximate");
7778
static final ParseField PRAGMA_FIELD = new ParseField("pragma");
7879
private static final ParseField PARAMS_FIELD = new ParseField("params");
7980
private static final ParseField LOCALE_FIELD = new ParseField("locale");
@@ -103,6 +104,7 @@ private static void objectParserCommon(ObjectParser<EsqlQueryRequest, ?> parser)
103104
parser.declareString(EsqlQueryRequest::query, QUERY_FIELD);
104105
parser.declareBoolean(EsqlQueryRequest::columnar, COLUMNAR_FIELD);
105106
parser.declareObject(EsqlQueryRequest::filter, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), FILTER_FIELD);
107+
parser.declareBoolean(EsqlQueryRequest::approximate, APPROXIMATE_FIELD);
106108
parser.declareBoolean(EsqlQueryRequest::acceptedPragmaRisks, ACCEPT_PRAGMA_RISKS);
107109
parser.declareBoolean(EsqlQueryRequest::includeCCSMetadata, INCLUDE_CCS_METADATA_FIELD);
108110
parser.declareObject(
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.approximate;
9+
10+
import org.elasticsearch.compute.data.LongBlock;
11+
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
12+
import org.elasticsearch.xpack.esql.core.expression.Alias;
13+
import org.elasticsearch.xpack.esql.core.expression.Expression;
14+
import org.elasticsearch.xpack.esql.core.expression.Literal;
15+
import org.elasticsearch.xpack.esql.core.tree.Source;
16+
import org.elasticsearch.xpack.esql.core.type.DataType;
17+
import org.elasticsearch.xpack.esql.core.util.Holder;
18+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
19+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
20+
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
21+
import org.elasticsearch.xpack.esql.plan.logical.Drop;
22+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
23+
import org.elasticsearch.xpack.esql.plan.logical.Filter;
24+
import org.elasticsearch.xpack.esql.plan.logical.Grok;
25+
import org.elasticsearch.xpack.esql.plan.logical.Keep;
26+
import org.elasticsearch.xpack.esql.plan.logical.LeafPlan;
27+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
28+
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
29+
import org.elasticsearch.xpack.esql.plan.logical.Rename;
30+
import org.elasticsearch.xpack.esql.plan.logical.Sample;
31+
import org.elasticsearch.xpack.esql.session.Result;
32+
33+
import java.util.List;
34+
import java.util.Locale;
35+
import java.util.Set;
36+
37+
public class Approximate {
38+
39+
private static final Set<Class<? extends LogicalPlan>> SWAPPABLE_WITH_SAMPLE = Set.of(
40+
Dissect.class,
41+
Drop.class,
42+
Eval.class,
43+
Filter.class,
44+
Grok.class,
45+
Keep.class,
46+
OrderBy.class,
47+
Rename.class,
48+
Sample.class
49+
);
50+
51+
private static final int SAMPLE_ROW_COUNT = 1000;
52+
53+
private final LogicalPlan logicalPlan;
54+
55+
public Approximate(LogicalPlan logicalPlan) {
56+
this.logicalPlan = logicalPlan;
57+
verifyPlan();
58+
}
59+
60+
/**
61+
* Verifies that a plan is suitable for approximation.
62+
*
63+
* To be so, the plan must contain at least one STATS function, and all
64+
* functions between the source and the leftmost STATS function must be
65+
* swappable with STATS.
66+
*
67+
* In that case, the STATS can be replaced by SAMPLE, STATS and stats
68+
* correction terms, and the SAMPLE can be moved to the source and
69+
* executed inside Lucene.
70+
*/
71+
private void verifyPlan() {
72+
if (logicalPlan.preOptimized() == false) {
73+
throw new IllegalStateException("Expected pre-optimized plan");
74+
}
75+
76+
if (logicalPlan.anyMatch(plan -> plan instanceof Aggregate) == false) {
77+
throw new InvalidArgumentException("query without [STATS] function cannot be approximated");
78+
}
79+
80+
Holder<Boolean> encounteredStats = new Holder<>(false);
81+
logicalPlan.transformUp(plan -> {
82+
if (plan instanceof LeafPlan) {
83+
encounteredStats.set(false);
84+
} else if (encounteredStats.get() == false) {
85+
if (plan instanceof Aggregate) {
86+
encounteredStats.set(true);
87+
} else if (SWAPPABLE_WITH_SAMPLE.contains(plan.getClass()) == false) {
88+
throw new InvalidArgumentException(
89+
"query with [" + plan.nodeName().toUpperCase(Locale.ROOT) + "] before [STATS] function cannot be approximated"
90+
);
91+
}
92+
}
93+
return plan;
94+
});
95+
}
96+
97+
/**
98+
* Returns a plan that counts the number of rows of the original plan that
99+
* would reach the leftmost STATS function. So it's the original plan cut
100+
* off at the leftmost STATS function, followed by "| STATS COUNT(*)".
101+
* This value can be used to pick a good sample probability.
102+
*/
103+
public LogicalPlan countPlan() {
104+
Holder<Boolean> encounteredStats = new Holder<>(false);
105+
LogicalPlan countPlan = logicalPlan.transformUp(plan -> {
106+
if (plan instanceof LeafPlan) {
107+
encounteredStats.set(false);
108+
} else if (encounteredStats.get() == false) {
109+
if (plan instanceof Aggregate aggregate) {
110+
encounteredStats.set(true);
111+
plan = new Aggregate(
112+
Source.EMPTY,
113+
aggregate.child(),
114+
List.of(),
115+
List.of(new Alias(Source.EMPTY, "approximate-count", new Count(Source.EMPTY, Literal.keyword(Source.EMPTY, "*"))))
116+
);
117+
}
118+
} else {
119+
plan = plan.children().getFirst();
120+
}
121+
return plan;
122+
});
123+
124+
countPlan.setPreOptimized();
125+
return countPlan;
126+
}
127+
128+
129+
/**
130+
* Returns a plan that approximates the original plan. It consists of the
131+
* original plan, with the leftmost STATS function replaced by:
132+
* "SAMPLE probability | STATS sample_corrected_aggs".
133+
*
134+
* The sample probability is based on the total row count that would reach
135+
* the STATS function, which is obtained by executing the countPlan.
136+
*/
137+
public LogicalPlan approximatePlan(Result countResult) {
138+
long rowCount = ((LongBlock) (countResult.pages().getFirst().getBlock(0))).getLong(0);
139+
if (rowCount <= 1000) {
140+
return logicalPlan;
141+
}
142+
double sampleProbability = (double) SAMPLE_ROW_COUNT / rowCount;
143+
144+
Holder<Boolean> encounteredStats = new Holder<>(false);
145+
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
146+
if (plan instanceof LeafPlan) {
147+
encounteredStats.set(false);
148+
} else if (encounteredStats.get() == false) {
149+
if (plan instanceof Aggregate aggregate) {
150+
encounteredStats.set(true);
151+
Expression sampleProbabilityExpr = new Literal(Source.EMPTY, sampleProbability, DataType.DOUBLE);
152+
Sample sample = new Sample(Source.EMPTY, sampleProbabilityExpr, aggregate.child());
153+
plan = aggregate.replaceChild(sample);
154+
plan = plan.transformExpressionsOnlyUp(
155+
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(sampleProbabilityExpr) : expr
156+
);
157+
}
158+
}
159+
return plan;
160+
});
161+
162+
approximatePlan.setPreOptimized();
163+
return approximatePlan;
164+
}
165+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.approximate;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
12+
public interface NeedsSampleCorrection {
13+
Expression sampleCorrection(Expression sampleProbability);
14+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
1313
import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
1414
import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder;
15+
import org.elasticsearch.xpack.esql.approximate.NeedsSampleCorrection;
1516
import org.elasticsearch.xpack.esql.core.expression.Expression;
1617
import org.elasticsearch.xpack.esql.core.expression.Literal;
1718
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -25,8 +26,10 @@
2526
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
2627
import org.elasticsearch.xpack.esql.expression.function.Param;
2728
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
29+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
2830
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
2931
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
32+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3033
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3134
import org.elasticsearch.xpack.esql.planner.ToAggregator;
3235

@@ -37,7 +40,7 @@
3740
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
3841
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
3942

40-
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
43+
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression, NeedsSampleCorrection {
4144
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);
4245

4346
@FunctionInfo(
@@ -169,4 +172,9 @@ public Expression surrogate() {
169172

170173
return null;
171174
}
175+
176+
@Override
177+
public Expression sampleCorrection(Expression sampleProbability) {
178+
return new ToLong(source(), new Div(source(), new Count(source(), field(), filter()), sampleProbability));
179+
}
172180
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier;
1414
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
1515
import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder;
16+
import org.elasticsearch.xpack.esql.approximate.NeedsSampleCorrection;
1617
import org.elasticsearch.xpack.esql.core.expression.Expression;
1718
import org.elasticsearch.xpack.esql.core.expression.Literal;
1819
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
@@ -26,7 +27,9 @@
2627
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
2728
import org.elasticsearch.xpack.esql.expression.function.Param;
2829
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
30+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
2931
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
32+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3033
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3134

3235
import java.io.IOException;
@@ -43,7 +46,7 @@
4346
/**
4447
* Sum all values of a field in matching documents.
4548
*/
46-
public class Sum extends NumericAggregate implements SurrogateExpression {
49+
public class Sum extends NumericAggregate implements SurrogateExpression, NeedsSampleCorrection {
4750
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new);
4851

4952
@FunctionInfo(
@@ -145,4 +148,14 @@ public Expression surrogate() {
145148
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
146149
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
147150
}
151+
152+
@Override
153+
public Expression sampleCorrection(Expression sampleProbability) {
154+
Expression correctedSum = new Div(source(), new Sum(source(), field(), filter()), sampleProbability);
155+
return switch (dataType()) {
156+
case DOUBLE -> correctedSum;
157+
case LONG -> new ToLong(source(), correctedSum);
158+
default -> throw new IllegalStateException("unexpected data type [" + dataType() + "]");
159+
};
160+
}
148161
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ public PlanType transformExpressionsOnly(Function<Expression, ? extends Expressi
102102
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(rule)));
103103
}
104104

105+
public PlanType transformExpressionsOnlyUp(Function<Expression, ? extends Expression> rule) {
106+
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(rule)));
107+
}
108+
105109
public <E extends Expression> PlanType transformExpressionsOnly(Class<E> typeToken, Function<E, ? extends Expression> rule) {
106110
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
107111
}

0 commit comments

Comments
 (0)