2121import org .elasticsearch .xpack .esql .core .tree .Source ;
2222import org .elasticsearch .xpack .esql .core .type .DataType ;
2323import org .elasticsearch .xpack .esql .core .util .Holder ;
24+ import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
2425import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
26+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Min ;
2527import org .elasticsearch .xpack .esql .expression .function .aggregate .Top ;
28+ import org .elasticsearch .xpack .esql .expression .function .scalar .convert .ToDouble ;
29+ import org .elasticsearch .xpack .esql .expression .function .scalar .convert .ToLong ;
2630import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
2731import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
2832import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
33+ import org .elasticsearch .xpack .esql .expression .predicate .operator .arithmetic .Mul ;
34+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
35+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
2936import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
3037import org .elasticsearch .xpack .esql .plan .logical .Dissect ;
3138import org .elasticsearch .xpack .esql .plan .logical .Drop ;
@@ -115,7 +122,7 @@ public interface LogicalPlanRunner {
// Plan node classes declared as incompatible: InlineStats and LookupJoin.
// NOTE(review): presumably plans containing these commands are not approximated; the usage site is outside this view — confirm.
115122 private static final Set <Class <? extends LogicalPlan >> INCOMPATIBLE_COMMANDS = Set .of (InlineStats .class , LookupJoin .class );
116123
117124 // TODO: find a good default value, or alternative ways of setting it
// NOTE(review): this change raises the constant from 10000 to 100000. Where it is read is not visible here;
// presumably it is the target number of rows to sample — confirm against the caller before relying on it.
118- private static final int SAMPLE_ROW_COUNT = 10000 ;
125+ private static final int SAMPLE_ROW_COUNT = 100000 ;
119126
// Logger registered under the Approximate class.
120127 private static final Logger logger = LogManager .getLogger (Approximate .class );
121128
@@ -297,6 +304,9 @@ private LogicalPlan approximatePlan(double sampleProbability) {
297304 logger .debug ("using original plan (too few rows)" );
298305 return logicalPlan ;
299306 }
307+
308+ logger .info ("### BEFORE APPROXIMATE:\n {}" , logicalPlan );
309+
300310 logger .debug ("generating approximate plan (p={})" , sampleProbability );
301311 Holder <Boolean > encounteredStats = new Holder <>(false );
302312 LogicalPlan approximatePlan = logicalPlan .transformUp (plan -> {
@@ -307,29 +317,84 @@ private LogicalPlan approximatePlan(double sampleProbability) {
307317 encounteredStats .set (true );
308318 Expression sampleProbabilityExpr = new Literal (Source .EMPTY , sampleProbability , DataType .DOUBLE );
309319 Sample sample = new Sample (Source .EMPTY , sampleProbabilityExpr , aggregate .child ());
310- Alias sampleId = new Alias (Source .EMPTY , ".sample_id" ,
311- new MvAppend (Source .EMPTY , new Literal (Source .EMPTY , 0 , DataType .INTEGER ), new Random (Source .EMPTY , new Literal (Source .EMPTY , 25 , DataType .INTEGER ))));
312- Eval addSampleId = new Eval (
320+ Alias sampleId = new Alias (
313321 Source .EMPTY ,
314- sample ,
315- List .of (sampleId )
316- );
317- List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
318- groupings .add (new ReferenceAttribute (Source .EMPTY , null , ".sample_id" , DataType .INTEGER , sampleId .nullable (), sampleId .id (), sampleId .synthetic ()));
319- LogicalPlan aggregateWithSampledId = aggregate .with (addSampleId , groupings , aggregate .aggregates ()).transformExpressionsOnlyUp (
320- expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (sampleProbabilityExpr ) : expr
322+ ".sample_id" ,
323+ new MvAppend (
324+ Source .EMPTY ,
325+ new Literal (Source .EMPTY , -1 , DataType .INTEGER ),
326+ new Random (Source .EMPTY , new Literal (Source .EMPTY , 25 , DataType .INTEGER ))
327+ )
321328 );
329+ Eval addSampleId = new Eval (Source .EMPTY , sample , List .of (sampleId ));
322330 List <NamedExpression > aggregates = new ArrayList <>();
323331 for (NamedExpression aggr : aggregate .aggregates ()) {
324- // aggregates.add(new Alias(Source.EMPTY, "confidence:" + aggr.name(),
325- // new ConfidenceInterval(Source.EMPTY, new Top(aggr.))));
332+ if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
333+ aggregates .add (new Alias (Source .EMPTY , ".sampled-" + alias .name (), alias .child ()));
334+ } else {
335+ aggregates .add (aggr );
336+ }
326337 }
327- plan = new Aggregate (Source .EMPTY , aggregateWithSampledId , aggregate .groupings (), aggregates );
338+ List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
339+ groupings .add (sampleId .toAttribute ());
340+ aggregates .add (sampleId .toAttribute ());
341+ Aggregate aggregateWithSampledId = (Aggregate ) aggregate .with (addSampleId , groupings , aggregates )
342+ .transformExpressionsOnlyUp (
343+ expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (sampleProbabilityExpr ) : expr
344+ );
345+ aggregates = new ArrayList <>();
346+ for (int i = 0 ; i < aggregate .aggregates ().size (); i ++) {
347+ NamedExpression aggr = aggregate .aggregates ().get (i );
348+ NamedExpression sampledAggr = aggregateWithSampledId .aggregates ().get (i );
349+ if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
350+ aggregates .add (
351+ alias .replaceChild (
352+ new ToLong ( // TODO: cast to original type
353+ Source .EMPTY ,
354+ new ConfidenceInterval ( // TODO: move confidence level to the end
355+ Source .EMPTY ,
356+ new ToDouble (
357+ Source .EMPTY ,
358+ new Min (
359+ Source .EMPTY ,
360+ sampledAggr .toAttribute (),
361+ new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 ))
362+ )
363+ ),
364+ new ToDouble (
365+ Source .EMPTY ,
366+ new Top (
367+ Source .EMPTY ,
368+ new Mul (
369+ Source .EMPTY ,
370+ Literal .integer (Source .EMPTY , 25 ),
371+ sampledAggr .toAttribute ()
372+ ),
373+ new NotEquals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
374+ Literal .integer (Source .EMPTY , 25 ),
375+ Literal .keyword (Source .EMPTY , "ASC" )
376+ )
377+ )
378+ )
379+ )
380+ )
381+ );
382+ } else {
383+ aggregates .add (aggr );
384+ }
385+ }
386+ plan = new Aggregate (
387+ Source .EMPTY ,
388+ aggregateWithSampledId ,
389+ aggregate .groupings ().stream ().map (e -> e instanceof Alias a ? a .toAttribute () : e ).toList (),
390+ aggregates
391+ );
328392 }
329393 }
330394 return plan ;
331395 });
332396
397+ logger .info ("### AFTER APPROXIMATE:\n {}" , approximatePlan );
333398
334399 approximatePlan .setPreOptimized ();
335400 return approximatePlan ;
0 commit comments