Skip to content

Commit ed675f6

Browse files
committed
move sample to front
1 parent a58d98d commit ed675f6

File tree

1 file changed

+66
-67
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate

1 file changed

+66
-67
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ private LogicalPlan countPlan(double sampleProbability) {
225225
} else if (encounteredStats.get() == false) {
226226
if (plan instanceof Aggregate aggregate) {
227227
encounteredStats.set(true);
228-
Expression sampleProbabilityExpr = new Literal(Source.EMPTY, sampleProbability, DataType.DOUBLE);
229-
Sample sample = new Sample(Source.EMPTY, sampleProbabilityExpr, aggregate.child());
228+
Sample sample = new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), aggregate.child());
230229
plan = new Aggregate(
231230
Source.EMPTY,
232231
sample,
@@ -296,78 +295,78 @@ private LogicalPlan approximatePlan(double sampleProbability) {
296295
Holder<Boolean> encounteredStats = new Holder<>(false);
297296
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
298297
if (plan instanceof LeafPlan) {
299-
encounteredStats.set(false);
300-
} else if (encounteredStats.get() == false) {
301-
if (plan instanceof Aggregate aggregate) {
302-
encounteredStats.set(true);
303-
Expression sampleProbabilityExpr = new Literal(Source.EMPTY, sampleProbability, DataType.DOUBLE);
304-
Expression bucketProbabilityExpr = new Literal(Source.EMPTY, sampleProbability / BUCKET_COUNT, DataType.DOUBLE);
305-
Sample sample = new Sample(Source.EMPTY, sampleProbabilityExpr, aggregate.child());
306-
Alias sampleId = new Alias(
298+
return new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
299+
} else if (encounteredStats.get() == false && plan instanceof Aggregate aggregate) {
300+
encounteredStats.set(true);
301+
Alias sampleId = new Alias(
302+
Source.EMPTY,
303+
".sample_id",
304+
new MvAppend(
307305
Source.EMPTY,
308-
".sample_id",
309-
new MvAppend(
310-
Source.EMPTY,
311-
new Literal(Source.EMPTY, -1, DataType.INTEGER),
312-
new Random(Source.EMPTY, new Literal(Source.EMPTY, BUCKET_COUNT, DataType.INTEGER))
313-
)
314-
);
315-
Eval addSampleId = new Eval(Source.EMPTY, sample, List.of(sampleId));
316-
List<NamedExpression> aggregates = new ArrayList<>();
317-
for (NamedExpression aggr : aggregate.aggregates()) {
318-
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
319-
aggregates.add(new Alias(Source.EMPTY, ".sampled-" + alias.name(), alias.child()));
320-
} else {
321-
aggregates.add(aggr);
322-
}
306+
Literal.integer(Source.EMPTY, -1),
307+
new Random(Source.EMPTY, Literal.integer(Source.EMPTY, BUCKET_COUNT))
308+
)
309+
);
310+
Eval addSampleId = new Eval(Source.EMPTY, aggregate.child(), List.of(sampleId));
311+
List<NamedExpression> aggregates = new ArrayList<>();
312+
for (NamedExpression aggr : aggregate.aggregates()) {
313+
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
314+
aggregates.add(new Alias(Source.EMPTY, ".sampled-" + alias.name(), alias.child()));
315+
} else {
316+
aggregates.add(aggr);
323317
}
324-
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
325-
groupings.add(sampleId.toAttribute());
326-
aggregates.add(sampleId.toAttribute());
327-
Aggregate aggregateWithSampledId = (Aggregate) aggregate.with(addSampleId, groupings, aggregates)
328-
.transformExpressionsOnlyUp(
329-
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(
330-
new Case(Source.EMPTY,
331-
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
332-
List.of(sampleProbabilityExpr, bucketProbabilityExpr))) : expr);
333-
aggregates = new ArrayList<>();
334-
for (int i = 0; i < aggregate.aggregates().size(); i++) {
335-
NamedExpression aggr = aggregate.aggregates().get(i);
336-
NamedExpression sampledAggr = aggregateWithSampledId.aggregates().get(i);
337-
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction aggFn) {
338-
// TODO: probably filter low non-empty bucket counts. They're inaccurate and for skew, you need >=3.
339-
aggregates.add(
340-
alias.replaceChild(
341-
new ConfidenceInterval( // TODO: move confidence level to the end
318+
}
319+
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
320+
groupings.add(sampleId.toAttribute());
321+
aggregates.add(sampleId.toAttribute());
322+
Aggregate aggregateWithSampledId = (Aggregate) aggregate.with(addSampleId, groupings, aggregates)
323+
.transformExpressionsOnlyUp(
324+
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(
325+
new Case(Source.EMPTY,
326+
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
327+
List.of(
328+
Literal.fromDouble(Source.EMPTY, sampleProbability),
329+
Literal.fromDouble(Source.EMPTY, sampleProbability / BUCKET_COUNT)
330+
)
331+
)
332+
) : expr);
333+
aggregates = new ArrayList<>();
334+
for (int i = 0; i < aggregate.aggregates().size(); i++) {
335+
NamedExpression aggr = aggregate.aggregates().get(i);
336+
NamedExpression sampledAggr = aggregateWithSampledId.aggregates().get(i);
337+
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction aggFn) {
338+
// TODO: probably filter low non-empty bucket counts. They're inaccurate and for skew, you need >=3.
339+
aggregates.add(
340+
alias.replaceChild(
341+
new ConfidenceInterval( // TODO: move confidence level to the end
342+
Source.EMPTY,
343+
new Min(
344+
Source.EMPTY,
345+
sampledAggr.toAttribute(),
346+
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1))
347+
),
348+
new Top(
342349
Source.EMPTY,
343-
new Min(
344-
Source.EMPTY,
345-
sampledAggr.toAttribute(),
346-
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1))
347-
),
348-
new Top(
349-
Source.EMPTY,
350-
sampledAggr.toAttribute(),
351-
new NotEquals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
352-
Literal.integer(Source.EMPTY, BUCKET_COUNT),
353-
Literal.keyword(Source.EMPTY, "ASC")
354-
),
350+
sampledAggr.toAttribute(),
351+
new NotEquals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
355352
Literal.integer(Source.EMPTY, BUCKET_COUNT),
356-
Literal.fromDouble(Source.EMPTY, aggFn instanceof NeedsSampleCorrection ? 0.0 : Double.NaN)
357-
)
353+
Literal.keyword(Source.EMPTY, "ASC")
354+
),
355+
Literal.integer(Source.EMPTY, BUCKET_COUNT),
356+
Literal.fromDouble(Source.EMPTY, aggFn instanceof NeedsSampleCorrection ? 0.0 : Double.NaN)
358357
)
359-
);
360-
} else {
361-
aggregates.add(aggr);
362-
}
358+
)
359+
);
360+
} else {
361+
aggregates.add(aggr);
363362
}
364-
plan = new Aggregate(
365-
Source.EMPTY,
366-
aggregateWithSampledId,
367-
aggregate.groupings().stream().map(e -> e instanceof Alias a ? a.toAttribute() : e).toList(),
368-
aggregates
369-
);
370363
}
364+
plan = new Aggregate(
365+
Source.EMPTY,
366+
aggregateWithSampledId,
367+
aggregate.groupings().stream().map(e -> e instanceof Alias a ? a.toAttribute() : e).toList(),
368+
aggregates
369+
);
371370
}
372371
return plan;
373372
});

0 commit comments

Comments
 (0)