1919import org .elasticsearch .xpack .esql .core .expression .Literal ;
2020import org .elasticsearch .xpack .esql .core .expression .NameId ;
2121import org .elasticsearch .xpack .esql .core .expression .NamedExpression ;
22- import org .elasticsearch .xpack .esql .core .expression .function .scalar .ScalarFunction ;
2322import org .elasticsearch .xpack .esql .core .tree .Source ;
24- import org .elasticsearch .xpack .esql .core .type .DataType ;
2523import org .elasticsearch .xpack .esql .core .util .Holder ;
2624import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
2725import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
28- import org .elasticsearch .xpack .esql .expression .function .aggregate .Min ;
2926import org .elasticsearch .xpack .esql .expression .function .aggregate .Top ;
3027import org .elasticsearch .xpack .esql .expression .function .aggregate .Values ;
31- import org .elasticsearch .xpack .esql .expression .function .scalar .conditional .Case ;
3228import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
3329import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
30+ import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvContains ;
3431import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
3532import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
33+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .In ;
3634import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
3735import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
3836import org .elasticsearch .xpack .esql .plan .logical .ChangePoint ;
5654import org .elasticsearch .xpack .esql .session .Result ;
5755
5856import java .util .ArrayList ;
57+ import java .util .HashMap ;
5958import java .util .HashSet ;
6059import java .util .List ;
60+ import java .util .Map ;
6161import java .util .Set ;
6262import java .util .stream .Collectors ;
6363
@@ -120,7 +120,7 @@ public interface LogicalPlanRunner {
120120 // TODO: find a good default value, or alternative ways of setting it
121121 private static final int SAMPLE_ROW_COUNT = 100000 ;
122122
123- private static final int BUCKET_COUNT = 25 ;
123+ private static final int BUCKET_COUNT = 3 ; // 25;
124124
125125 private static final Logger logger = LogManager .getLogger (Approximate .class );
126126
@@ -301,12 +301,11 @@ private LogicalPlan approximatePlan(double sampleProbability) {
301301
302302 logger .debug ("generating approximate plan (p={})" , sampleProbability );
303303 Holder <Boolean > encounteredStats = new Holder <>(false );
304- Set <NameId > variablesWithConfidenceInterval = new HashSet <>();
305- Set <NameId > variablesWithPastConfidenceInterval = new HashSet <>();
304+ Map <NameId , List <Alias >> variablesWithConfidenceInterval = new HashMap <>();
306305
307- Alias bucketId = new Alias (
306+ Alias bucketIdField = new Alias (
308307 Source .EMPTY ,
309- ". bucket_id" ,
308+ "$$ bucket_id" ,
310309 new MvAppend (
311310 Source .EMPTY ,
312311 Literal .integer (Source .EMPTY , -1 ),
@@ -316,125 +315,134 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316315
317316 LogicalPlan approximatePlan = logicalPlan .transformUp (plan -> {
318317 if (plan instanceof LeafPlan ) {
319- return new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), plan );
318+ plan = new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), plan );
320319 } else if (encounteredStats .get () == false && plan instanceof Aggregate aggregate ) {
321320 encounteredStats .set (true );
322321
323- Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketId ));
324- List <NamedExpression > aggregates = new ArrayList <>(aggregate .aggregates ());
325- aggregates .add (bucketId .toAttribute ());
326- List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
327- groupings .add (bucketId .toAttribute ());
328-
329- Aggregate aggregateWithBucketId = (Aggregate ) aggregate .with (addBucketId , groupings , aggregates )
330- .transformExpressionsOnlyUp (
331- expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (
332- new Case (Source .EMPTY ,
333- new Equals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
334- List .of (
335- Literal .fromDouble (Source .EMPTY , sampleProbability ),
336- Literal .fromDouble (Source .EMPTY , sampleProbability / BUCKET_COUNT )
337- )
338- )
339- ) : expr );
340-
341- for (NamedExpression aggr : aggregate .aggregates ()) {
342- if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
343- variablesWithConfidenceInterval .add (alias .id ());
322+ Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketIdField ));
323+ List <NamedExpression > aggregates = new ArrayList <>();
324+ for (NamedExpression aggOrKey : aggregate .aggregates ()) {
325+ if ((aggOrKey instanceof Alias alias && alias .child () instanceof AggregateFunction ) == false ) {
326+ // This is a grouping key, not an aggregate function.
327+ aggregates .add (aggOrKey );
328+ continue ;
344329 }
330+ Alias aggAlias = (Alias ) aggOrKey ;
331+ AggregateFunction agg = (AggregateFunction ) aggAlias .child ();
332+ List <Alias > bucketedAggs = new ArrayList <>();
333+ for (int bucketId = -1 ; bucketId < BUCKET_COUNT ; bucketId ++) {
334+ AggregateFunction bucketedAgg = agg .withFilter (
335+ new MvContains (Source .EMPTY , bucketIdField .toAttribute (), Literal .integer (Source .EMPTY , bucketId )));
336+ Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
337+ ? nsc .sampleCorrection (
338+ Literal .fromDouble (Source .EMPTY , bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT )
339+ )
340+ : bucketedAgg ;
341+ Alias correctAggAlias = bucketId == -1
342+ ? aggAlias .replaceChild (correctedAgg )
343+ : new Alias (
344+ Source .EMPTY ,
345+ aggOrKey .name () + "$bucket:" + bucketId ,
346+ correctedAgg
347+ );
348+ aggregates .add (correctAggAlias );
349+ if (bucketId >= 0 ) {
350+ bucketedAggs .add (correctAggAlias );
351+ }
352+ }
353+ variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
345354 }
355+ plan = aggregate .with (addBucketId , aggregate .groupings (), aggregates );
346356
347- return aggregateWithBucketId ;
348357 } else if (encounteredStats .get ()) {
349358 System .out .println ("@@@ UPDATE variablesWithConfidenceInterval" );
350359 System .out .println ("plan = " + plan );
351- System .out .println ("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval );
360+ System .out .println ("vars = " + variablesWithConfidenceInterval );
352361 switch (plan ) {
353362 case Eval eval :
363+ List <Alias > newFields = new ArrayList <>(eval .fields ());
354364 for (Alias field : eval .fields ()) {
355- if (field .anyMatch (expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval .contains (named .id ()))) {
356- // TODO: blacklist / whitelist?
357- if (field .child () instanceof MvAppend == false && field .dataType ().isNumeric ()) {
358- variablesWithConfidenceInterval .add (field .id ());
359- } else {
360- variablesWithPastConfidenceInterval .add (field .id ());
365+ if (field .dataType ().isNumeric () == false || field .child ().anyMatch (expr -> expr instanceof MvAppend )) {
366+ continue ;
367+ }
368+ if (field .child ().anyMatch (expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval .containsKey (named .id ()))) {
369+ List <Alias > newBuckets = new ArrayList <>();
370+ for (int bucketId = 0 ; bucketId < BUCKET_COUNT ; bucketId ++) {
371+ final int finalBucketId = bucketId ;
372+ Expression newChild = field .child ().transformDown (expr -> {
373+ if (expr instanceof NamedExpression named && variablesWithConfidenceInterval .containsKey (named .id ())) {
374+ List <Alias > buckets = variablesWithConfidenceInterval .get (named .id ());
375+ return buckets .get (finalBucketId ).toAttribute ();
376+ } else {
377+ return expr ;
378+ }
379+ });
380+ Alias newField = new Alias (
381+ Source .EMPTY ,
382+ field .name () + "$bucket:" + bucketId ,
383+ newChild
384+ );
385+ newBuckets .add (newField );
361386 }
362- } else if ( field . anyMatch ( expr -> expr instanceof NamedExpression named && variablesWithPastConfidenceInterval . contains ( named .id ()))) {
363- variablesWithPastConfidenceInterval . add ( field . id () );
387+ variablesWithConfidenceInterval . put ( field .id (), newBuckets );
388+ newFields . addAll ( newBuckets );
364389 }
365390 }
391+ plan = new Eval (Source .EMPTY , eval .child (), newFields );
366392 break ;
367- case Project project :
368- List <NamedExpression > projections = new ArrayList <>(project .projections ());
369- projections .add (bucketId .toAttribute ());
370- plan = project .withProjections (projections );
371- break ;
393+ // case Project project:
394+ // List<NamedExpression> projections = new ArrayList<>(project.projections());
395+ // plan = project.withProjections(projections);
396+ // break;
372397 case Rename rename :
373398 // TODO
374399 break ;
375400 default :
376401 }
377- System .out .println ("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval );
402+ System .out .println ("vars = " + variablesWithConfidenceInterval );
378403 }
379404 return plan ;
380405 });
381406
382407 System .out .println ("### OUTPUT: " + approximatePlan .output ());
383408
384- List <NamedExpression > aggregates = new ArrayList <>();
385- List <Expression > groupings = new ArrayList <>();
386- for (Attribute attribute : approximatePlan .output ()) {
387- if (attribute .id () == bucketId .id ()) {
388- continue ;
389- }
390- if (variablesWithConfidenceInterval .contains (attribute .id ()) || variablesWithPastConfidenceInterval .contains (attribute .id ())) {
391- Alias bestEstimate = new Alias (
409+ List <Alias > confidenceIntervals = new ArrayList <>();
410+ for (Attribute output : logicalPlan .output ()) {
411+ if (variablesWithConfidenceInterval .containsKey (output .id ())) {
412+ List <Alias > buckets = variablesWithConfidenceInterval .get (output .id ());
413+ Expression appendedBuckets = buckets .getFirst ().toAttribute ();
414+ for (int i = 1 ; i < buckets .size (); i ++) {
415+ appendedBuckets = new MvAppend (Source .EMPTY , appendedBuckets , buckets .get (i ).toAttribute ());
416+ }
417+ confidenceIntervals .add (new Alias (
392418 Source .EMPTY ,
393- attribute .name (),
394- new Values (
395- Source .EMPTY ,
396- attribute ,
397- new Equals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 ))
398- )
399- );
400- aggregates .add (bestEstimate );
401- if (variablesWithConfidenceInterval .contains (attribute .id ())) {
402- aggregates .add (new Alias (
403- Source .EMPTY , "CONFIDENCE_INTERVAL(" + attribute .name () + ")" , new ConfidenceInterval (
419+ "CONFIDENCE_INTERVAL(" + output .name () + ")" ,
420+ new ConfidenceInterval (
404421 Source .EMPTY ,
405- bestEstimate .toAttribute (),
406- new Top (
407- Source .EMPTY ,
408- attribute ,
409- new NotEquals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
410- Literal .integer (Source .EMPTY , BUCKET_COUNT ),
411- Literal .keyword (Source .EMPTY , "ASC" )
412- ),
422+ output ,
423+ appendedBuckets ,
413424 Literal .integer (Source .EMPTY , BUCKET_COUNT ),
414425 Literal .fromDouble (Source .EMPTY , 0.0 )
415- // TODO: fix, 0.0 or NaN ?? TODO: remove!!
416426 )
417- ));
418- }
419- } else {
420- aggregates .add (attribute );
421- groupings .add (attribute );
427+ ));
422428 }
423429 }
424430
425- Aggregate finalAggregate = new Aggregate (
431+ approximatePlan = new Eval (
426432 Source .EMPTY ,
427433 approximatePlan ,
428- groupings ,
429- aggregates
434+ confidenceIntervals
430435 );
431436
432- if (approximatePlan instanceof Limit || approximatePlan instanceof TopN ) {
433- approximatePlan = ((UnaryPlan ) approximatePlan ).replaceChild (finalAggregate .replaceChild (((UnaryPlan ) approximatePlan ).child ()));
434- } else {
435- // Can this happen? Or is the last command always a Limit / TopN?
436- approximatePlan = finalAggregate ;
437- }
437+ Set <Attribute > dropAttributes = variablesWithConfidenceInterval .values ().stream ().flatMap (List ::stream ).map (Alias ::toAttribute ).collect (Collectors .toSet ());
438+ List <Attribute > keepAttributes = new ArrayList <>(approximatePlan .output ());
439+ keepAttributes .removeAll (dropAttributes );
440+
441+ approximatePlan = new Project (
442+ Source .EMPTY ,
443+ approximatePlan ,
444+ keepAttributes
445+ );
438446
439447 logger .info ("### AFTER APPROXIMATE:\n {}" , approximatePlan );
440448 System .out .println ("### OUTPUT: " + approximatePlan .output ());
0 commit comments