19
19
import org .elasticsearch .xpack .esql .core .expression .Literal ;
20
20
import org .elasticsearch .xpack .esql .core .expression .NameId ;
21
21
import org .elasticsearch .xpack .esql .core .expression .NamedExpression ;
22
- import org .elasticsearch .xpack .esql .core .expression .function .scalar .ScalarFunction ;
23
22
import org .elasticsearch .xpack .esql .core .tree .Source ;
24
- import org .elasticsearch .xpack .esql .core .type .DataType ;
25
23
import org .elasticsearch .xpack .esql .core .util .Holder ;
26
24
import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
27
25
import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
28
- import org .elasticsearch .xpack .esql .expression .function .aggregate .Min ;
29
26
import org .elasticsearch .xpack .esql .expression .function .aggregate .Top ;
30
27
import org .elasticsearch .xpack .esql .expression .function .aggregate .Values ;
31
- import org .elasticsearch .xpack .esql .expression .function .scalar .conditional .Case ;
32
28
import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
33
29
import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
30
+ import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvContains ;
34
31
import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
35
32
import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
33
+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .In ;
36
34
import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
37
35
import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
38
36
import org .elasticsearch .xpack .esql .plan .logical .ChangePoint ;
56
54
import org .elasticsearch .xpack .esql .session .Result ;
57
55
58
56
import java .util .ArrayList ;
57
+ import java .util .HashMap ;
59
58
import java .util .HashSet ;
60
59
import java .util .List ;
60
+ import java .util .Map ;
61
61
import java .util .Set ;
62
62
import java .util .stream .Collectors ;
63
63
@@ -120,7 +120,7 @@ public interface LogicalPlanRunner {
120
120
// TODO: find a good default value, or alternative ways of setting it
121
121
private static final int SAMPLE_ROW_COUNT = 100000 ;
122
122
123
- private static final int BUCKET_COUNT = 25 ;
123
+ private static final int BUCKET_COUNT = 3 ; // Previously 25. NOTE(review): presumably lowered temporarily for debugging -- confirm the intended value before merging.
124
124
125
125
private static final Logger logger = LogManager .getLogger (Approximate .class );
126
126
@@ -301,12 +301,11 @@ private LogicalPlan approximatePlan(double sampleProbability) {
301
301
302
302
logger .debug ("generating approximate plan (p={})" , sampleProbability );
303
303
Holder <Boolean > encounteredStats = new Holder <>(false );
304
- Set <NameId > variablesWithConfidenceInterval = new HashSet <>();
305
- Set <NameId > variablesWithPastConfidenceInterval = new HashSet <>();
304
+ Map <NameId , List <Alias >> variablesWithConfidenceInterval = new HashMap <>();
306
305
307
- Alias bucketId = new Alias (
306
+ Alias bucketIdField = new Alias (
308
307
Source .EMPTY ,
309
- ". bucket_id" ,
308
+ "$$ bucket_id" ,
310
309
new MvAppend (
311
310
Source .EMPTY ,
312
311
Literal .integer (Source .EMPTY , -1 ),
@@ -316,125 +315,134 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316
315
317
316
LogicalPlan approximatePlan = logicalPlan .transformUp (plan -> {
318
317
if (plan instanceof LeafPlan ) {
319
- return new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), plan );
318
+ plan = new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), plan );
320
319
} else if (encounteredStats .get () == false && plan instanceof Aggregate aggregate ) {
321
320
encounteredStats .set (true );
322
321
323
- Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketId ));
324
- List <NamedExpression > aggregates = new ArrayList <>(aggregate .aggregates ());
325
- aggregates .add (bucketId .toAttribute ());
326
- List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
327
- groupings .add (bucketId .toAttribute ());
328
-
329
- Aggregate aggregateWithBucketId = (Aggregate ) aggregate .with (addBucketId , groupings , aggregates )
330
- .transformExpressionsOnlyUp (
331
- expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (
332
- new Case (Source .EMPTY ,
333
- new Equals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
334
- List .of (
335
- Literal .fromDouble (Source .EMPTY , sampleProbability ),
336
- Literal .fromDouble (Source .EMPTY , sampleProbability / BUCKET_COUNT )
337
- )
338
- )
339
- ) : expr );
340
-
341
- for (NamedExpression aggr : aggregate .aggregates ()) {
342
- if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
343
- variablesWithConfidenceInterval .add (alias .id ());
322
+ Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketIdField ));
323
+ List <NamedExpression > aggregates = new ArrayList <>();
324
+ for (NamedExpression aggOrKey : aggregate .aggregates ()) {
325
+ if ((aggOrKey instanceof Alias alias && alias .child () instanceof AggregateFunction ) == false ) {
326
+ // This is a grouping key, not an aggregate function.
327
+ aggregates .add (aggOrKey );
328
+ continue ;
344
329
}
330
+ Alias aggAlias = (Alias ) aggOrKey ;
331
+ AggregateFunction agg = (AggregateFunction ) aggAlias .child ();
332
+ List <Alias > bucketedAggs = new ArrayList <>();
333
+ for (int bucketId = -1 ; bucketId < BUCKET_COUNT ; bucketId ++) {
334
+ AggregateFunction bucketedAgg = agg .withFilter (
335
+ new MvContains (Source .EMPTY , bucketIdField .toAttribute (), Literal .integer (Source .EMPTY , bucketId )));
336
+ Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
337
+ ? nsc .sampleCorrection (
338
+ Literal .fromDouble (Source .EMPTY , bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT )
339
+ )
340
+ : bucketedAgg ;
341
+ Alias correctAggAlias = bucketId == -1
342
+ ? aggAlias .replaceChild (correctedAgg )
343
+ : new Alias (
344
+ Source .EMPTY ,
345
+ aggOrKey .name () + "$bucket:" + bucketId ,
346
+ correctedAgg
347
+ );
348
+ aggregates .add (correctAggAlias );
349
+ if (bucketId >= 0 ) {
350
+ bucketedAggs .add (correctAggAlias );
351
+ }
352
+ }
353
+ variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
345
354
}
355
+ plan = aggregate .with (addBucketId , aggregate .groupings (), aggregates );
346
356
347
- return aggregateWithBucketId ;
348
357
} else if (encounteredStats .get ()) {
349
358
System .out .println ("@@@ UPDATE variablesWithConfidenceInterval" );
350
359
System .out .println ("plan = " + plan );
351
- System .out .println ("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval );
360
+ System .out .println ("vars = " + variablesWithConfidenceInterval );
352
361
switch (plan ) {
353
362
case Eval eval :
363
+ List <Alias > newFields = new ArrayList <>(eval .fields ());
354
364
for (Alias field : eval .fields ()) {
355
- if (field .anyMatch (expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval .contains (named .id ()))) {
356
- // TODO: blacklist / whitelist?
357
- if (field .child () instanceof MvAppend == false && field .dataType ().isNumeric ()) {
358
- variablesWithConfidenceInterval .add (field .id ());
359
- } else {
360
- variablesWithPastConfidenceInterval .add (field .id ());
365
+ if (field .dataType ().isNumeric () == false || field .child ().anyMatch (expr -> expr instanceof MvAppend )) {
366
+ continue ;
367
+ }
368
+ if (field .child ().anyMatch (expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval .containsKey (named .id ()))) {
369
+ List <Alias > newBuckets = new ArrayList <>();
370
+ for (int bucketId = 0 ; bucketId < BUCKET_COUNT ; bucketId ++) {
371
+ final int finalBucketId = bucketId ;
372
+ Expression newChild = field .child ().transformDown (expr -> {
373
+ if (expr instanceof NamedExpression named && variablesWithConfidenceInterval .containsKey (named .id ())) {
374
+ List <Alias > buckets = variablesWithConfidenceInterval .get (named .id ());
375
+ return buckets .get (finalBucketId ).toAttribute ();
376
+ } else {
377
+ return expr ;
378
+ }
379
+ });
380
+ Alias newField = new Alias (
381
+ Source .EMPTY ,
382
+ field .name () + "$bucket:" + bucketId ,
383
+ newChild
384
+ );
385
+ newBuckets .add (newField );
361
386
}
362
- } else if ( field . anyMatch ( expr -> expr instanceof NamedExpression named && variablesWithPastConfidenceInterval . contains ( named .id ()))) {
363
- variablesWithPastConfidenceInterval . add ( field . id () );
387
+ variablesWithConfidenceInterval . put ( field .id (), newBuckets );
388
+ newFields . addAll ( newBuckets );
364
389
}
365
390
}
391
+ plan = new Eval (Source .EMPTY , eval .child (), newFields );
366
392
break ;
367
- case Project project :
368
- List <NamedExpression > projections = new ArrayList <>(project .projections ());
369
- projections .add (bucketId .toAttribute ());
370
- plan = project .withProjections (projections );
371
- break ;
393
+ // case Project project:
394
+ // List<NamedExpression> projections = new ArrayList<>(project.projections());
395
+ // plan = project.withProjections(projections);
396
+ // break;
372
397
case Rename rename :
373
398
// TODO
374
399
break ;
375
400
default :
376
401
}
377
- System .out .println ("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval );
402
+ System .out .println ("vars = " + variablesWithConfidenceInterval );
378
403
}
379
404
return plan ;
380
405
});
381
406
382
407
System .out .println ("### OUTPUT: " + approximatePlan .output ());
383
408
384
- List <NamedExpression > aggregates = new ArrayList <>();
385
- List <Expression > groupings = new ArrayList <>();
386
- for (Attribute attribute : approximatePlan .output ()) {
387
- if (attribute .id () == bucketId .id ()) {
388
- continue ;
389
- }
390
- if (variablesWithConfidenceInterval .contains (attribute .id ()) || variablesWithPastConfidenceInterval .contains (attribute .id ())) {
391
- Alias bestEstimate = new Alias (
409
+ List <Alias > confidenceIntervals = new ArrayList <>();
410
+ for (Attribute output : logicalPlan .output ()) {
411
+ if (variablesWithConfidenceInterval .containsKey (output .id ())) {
412
+ List <Alias > buckets = variablesWithConfidenceInterval .get (output .id ());
413
+ Expression appendedBuckets = buckets .getFirst ().toAttribute ();
414
+ for (int i = 1 ; i < buckets .size (); i ++) {
415
+ appendedBuckets = new MvAppend (Source .EMPTY , appendedBuckets , buckets .get (i ).toAttribute ());
416
+ }
417
+ confidenceIntervals .add (new Alias (
392
418
Source .EMPTY ,
393
- attribute .name (),
394
- new Values (
395
- Source .EMPTY ,
396
- attribute ,
397
- new Equals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 ))
398
- )
399
- );
400
- aggregates .add (bestEstimate );
401
- if (variablesWithConfidenceInterval .contains (attribute .id ())) {
402
- aggregates .add (new Alias (
403
- Source .EMPTY , "CONFIDENCE_INTERVAL(" + attribute .name () + ")" , new ConfidenceInterval (
419
+ "CONFIDENCE_INTERVAL(" + output .name () + ")" ,
420
+ new ConfidenceInterval (
404
421
Source .EMPTY ,
405
- bestEstimate .toAttribute (),
406
- new Top (
407
- Source .EMPTY ,
408
- attribute ,
409
- new NotEquals (Source .EMPTY , bucketId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
410
- Literal .integer (Source .EMPTY , BUCKET_COUNT ),
411
- Literal .keyword (Source .EMPTY , "ASC" )
412
- ),
422
+ output ,
423
+ appendedBuckets ,
413
424
Literal .integer (Source .EMPTY , BUCKET_COUNT ),
414
425
Literal .fromDouble (Source .EMPTY , 0.0 )
415
- // TODO: fix, 0.0 or NaN ?? TODO: remove!!
416
426
)
417
- ));
418
- }
419
- } else {
420
- aggregates .add (attribute );
421
- groupings .add (attribute );
427
+ ));
422
428
}
423
429
}
424
430
425
- Aggregate finalAggregate = new Aggregate (
431
+ approximatePlan = new Eval (
426
432
Source .EMPTY ,
427
433
approximatePlan ,
428
- groupings ,
429
- aggregates
434
+ confidenceIntervals
430
435
);
431
436
432
- if (approximatePlan instanceof Limit || approximatePlan instanceof TopN ) {
433
- approximatePlan = ((UnaryPlan ) approximatePlan ).replaceChild (finalAggregate .replaceChild (((UnaryPlan ) approximatePlan ).child ()));
434
- } else {
435
- // Can this happen? Or is the last command always a Limit / TopN?
436
- approximatePlan = finalAggregate ;
437
- }
437
+ Set <Attribute > dropAttributes = variablesWithConfidenceInterval .values ().stream ().flatMap (List ::stream ).map (Alias ::toAttribute ).collect (Collectors .toSet ());
438
+ List <Attribute > keepAttributes = new ArrayList <>(approximatePlan .output ());
439
+ keepAttributes .removeAll (dropAttributes );
440
+
441
+ approximatePlan = new Project (
442
+ Source .EMPTY ,
443
+ approximatePlan ,
444
+ keepAttributes
445
+ );
438
446
439
447
logger .info ("### AFTER APPROXIMATE:\n {}" , approximatePlan );
440
448
System .out .println ("### OUTPUT: " + approximatePlan .output ());
0 commit comments