2424import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
2525import org .elasticsearch .xpack .esql .expression .function .aggregate .Min ;
2626import org .elasticsearch .xpack .esql .expression .function .aggregate .Top ;
27+ import org .elasticsearch .xpack .esql .expression .function .scalar .conditional .Case ;
2728import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
2829import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
2930import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
30- import org .elasticsearch .xpack .esql .expression .predicate .operator .arithmetic .Mul ;
3131import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
3232import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
3333import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
@@ -121,6 +121,8 @@ public interface LogicalPlanRunner {
121121 // TODO: find a good default value, or alternative ways of setting it
122122 private static final int SAMPLE_ROW_COUNT = 100000 ;
123123
124+ private static final int BUCKET_COUNT = 25 ;
125+
124126 private static final Logger logger = LogManager .getLogger (Approximate .class );
125127
126128 private final LogicalPlan logicalPlan ;
@@ -313,14 +315,15 @@ private LogicalPlan approximatePlan(double sampleProbability) {
313315 if (plan instanceof Aggregate aggregate ) {
314316 encounteredStats .set (true );
315317 Expression sampleProbabilityExpr = new Literal (Source .EMPTY , sampleProbability , DataType .DOUBLE );
318+ Expression bucketProbabilityExpr = new Literal (Source .EMPTY , sampleProbability / BUCKET_COUNT , DataType .DOUBLE );
316319 Sample sample = new Sample (Source .EMPTY , sampleProbabilityExpr , aggregate .child ());
317320 Alias sampleId = new Alias (
318321 Source .EMPTY ,
319322 ".sample_id" ,
320323 new MvAppend (
321324 Source .EMPTY ,
322325 new Literal (Source .EMPTY , -1 , DataType .INTEGER ),
323- new Random (Source .EMPTY , new Literal (Source .EMPTY , 25 , DataType .INTEGER ))
326+ new Random (Source .EMPTY , new Literal (Source .EMPTY , BUCKET_COUNT , DataType .INTEGER ))
324327 )
325328 );
326329 Eval addSampleId = new Eval (Source .EMPTY , sample , List .of (sampleId ));
@@ -337,8 +340,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
337340 aggregates .add (sampleId .toAttribute ());
338341 Aggregate aggregateWithSampledId = (Aggregate ) aggregate .with (addSampleId , groupings , aggregates )
339342 .transformExpressionsOnlyUp (
340- expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (sampleProbabilityExpr ) : expr
341- );
343+ expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (
344+ new Case (Source .EMPTY ,
345+ new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
346+ List .of (sampleProbabilityExpr , bucketProbabilityExpr ))) : expr );
342347 aggregates = new ArrayList <>();
343348 for (int i = 0 ; i < aggregate .aggregates ().size (); i ++) {
344349 NamedExpression aggr = aggregate .aggregates ().get (i );
@@ -356,13 +361,9 @@ private LogicalPlan approximatePlan(double sampleProbability) {
356361 ),
357362 new Top (
358363 Source .EMPTY ,
359- new Mul ( // TODO: make this mul a sample correction 1/buckets
360- Source .EMPTY ,
361- Literal .integer (Source .EMPTY , 25 ),
362- sampledAggr .toAttribute ()
363- ),
364+ sampledAggr .toAttribute (),
364365 new NotEquals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
365- Literal .integer (Source .EMPTY , 25 ),
366+ Literal .integer (Source .EMPTY , BUCKET_COUNT ),
366367 Literal .keyword (Source .EMPTY , "ASC" )
367368 )
368369 )
0 commit comments