@@ -225,8 +225,7 @@ private LogicalPlan countPlan(double sampleProbability) {
225225 } else if (encounteredStats .get () == false ) {
226226 if (plan instanceof Aggregate aggregate ) {
227227 encounteredStats .set (true );
228- Expression sampleProbabilityExpr = new Literal (Source .EMPTY , sampleProbability , DataType .DOUBLE );
229- Sample sample = new Sample (Source .EMPTY , sampleProbabilityExpr , aggregate .child ());
228+ Sample sample = new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), aggregate .child ());
230229 plan = new Aggregate (
231230 Source .EMPTY ,
232231 sample ,
@@ -296,78 +295,78 @@ private LogicalPlan approximatePlan(double sampleProbability) {
296295 Holder <Boolean > encounteredStats = new Holder <>(false );
297296 LogicalPlan approximatePlan = logicalPlan .transformUp (plan -> {
298297 if (plan instanceof LeafPlan ) {
299- encounteredStats .set (false );
300- } else if (encounteredStats .get () == false ) {
301- if (plan instanceof Aggregate aggregate ) {
302- encounteredStats .set (true );
303- Expression sampleProbabilityExpr = new Literal (Source .EMPTY , sampleProbability , DataType .DOUBLE );
304- Expression bucketProbabilityExpr = new Literal (Source .EMPTY , sampleProbability / BUCKET_COUNT , DataType .DOUBLE );
305- Sample sample = new Sample (Source .EMPTY , sampleProbabilityExpr , aggregate .child ());
306- Alias sampleId = new Alias (
298+ return new Sample (Source .EMPTY , Literal .fromDouble (Source .EMPTY , sampleProbability ), plan );
299+ } else if (encounteredStats .get () == false && plan instanceof Aggregate aggregate ) {
300+ encounteredStats .set (true );
301+ Alias sampleId = new Alias (
302+ Source .EMPTY ,
303+ ".sample_id" ,
304+ new MvAppend (
307305 Source .EMPTY ,
308- ".sample_id" ,
309- new MvAppend (
310- Source .EMPTY ,
311- new Literal (Source .EMPTY , -1 , DataType .INTEGER ),
312- new Random (Source .EMPTY , new Literal (Source .EMPTY , BUCKET_COUNT , DataType .INTEGER ))
313- )
314- );
315- Eval addSampleId = new Eval (Source .EMPTY , sample , List .of (sampleId ));
316- List <NamedExpression > aggregates = new ArrayList <>();
317- for (NamedExpression aggr : aggregate .aggregates ()) {
318- if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
319- aggregates .add (new Alias (Source .EMPTY , ".sampled-" + alias .name (), alias .child ()));
320- } else {
321- aggregates .add (aggr );
322- }
306+ Literal .integer (Source .EMPTY , -1 ),
307+ new Random (Source .EMPTY , Literal .integer (Source .EMPTY , BUCKET_COUNT ))
308+ )
309+ );
310+ Eval addSampleId = new Eval (Source .EMPTY , aggregate .child (), List .of (sampleId ));
311+ List <NamedExpression > aggregates = new ArrayList <>();
312+ for (NamedExpression aggr : aggregate .aggregates ()) {
313+ if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction ) {
314+ aggregates .add (new Alias (Source .EMPTY , ".sampled-" + alias .name (), alias .child ()));
315+ } else {
316+ aggregates .add (aggr );
323317 }
324- List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
325- groupings .add (sampleId .toAttribute ());
326- aggregates .add (sampleId .toAttribute ());
327- Aggregate aggregateWithSampledId = (Aggregate ) aggregate .with (addSampleId , groupings , aggregates )
328- .transformExpressionsOnlyUp (
329- expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (
330- new Case (Source .EMPTY ,
331- new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
332- List .of (sampleProbabilityExpr , bucketProbabilityExpr ))) : expr );
333- aggregates = new ArrayList <>();
334- for (int i = 0 ; i < aggregate .aggregates ().size (); i ++) {
335- NamedExpression aggr = aggregate .aggregates ().get (i );
336- NamedExpression sampledAggr = aggregateWithSampledId .aggregates ().get (i );
337- if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction aggFn ) {
338- // TODO: probably filter low non-empty bucket counts. They're inaccurate and for skew, you need >=3.
339- aggregates .add (
340- alias .replaceChild (
341- new ConfidenceInterval ( // TODO: move confidence level to the end
318+ }
319+ List <Expression > groupings = new ArrayList <>(aggregate .groupings ());
320+ groupings .add (sampleId .toAttribute ());
321+ aggregates .add (sampleId .toAttribute ());
322+ Aggregate aggregateWithSampledId = (Aggregate ) aggregate .with (addSampleId , groupings , aggregates )
323+ .transformExpressionsOnlyUp (
324+ expr -> expr instanceof NeedsSampleCorrection nsc ? nsc .sampleCorrection (
325+ new Case (Source .EMPTY ,
326+ new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
327+ List .of (
328+ Literal .fromDouble (Source .EMPTY , sampleProbability ),
329+ Literal .fromDouble (Source .EMPTY , sampleProbability / BUCKET_COUNT )
330+ )
331+ )
332+ ) : expr );
333+ aggregates = new ArrayList <>();
334+ for (int i = 0 ; i < aggregate .aggregates ().size (); i ++) {
335+ NamedExpression aggr = aggregate .aggregates ().get (i );
336+ NamedExpression sampledAggr = aggregateWithSampledId .aggregates ().get (i );
337+ if (aggr instanceof Alias alias && alias .child () instanceof AggregateFunction aggFn ) {
338+ // TODO: probably filter low non-empty bucket counts. They're inaccurate and for skew, you need >=3.
339+ aggregates .add (
340+ alias .replaceChild (
341+ new ConfidenceInterval ( // TODO: move confidence level to the end
342+ Source .EMPTY ,
343+ new Min (
344+ Source .EMPTY ,
345+ sampledAggr .toAttribute (),
346+ new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 ))
347+ ),
348+ new Top (
342349 Source .EMPTY ,
343- new Min (
344- Source .EMPTY ,
345- sampledAggr .toAttribute (),
346- new Equals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 ))
347- ),
348- new Top (
349- Source .EMPTY ,
350- sampledAggr .toAttribute (),
351- new NotEquals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
352- Literal .integer (Source .EMPTY , BUCKET_COUNT ),
353- Literal .keyword (Source .EMPTY , "ASC" )
354- ),
350+ sampledAggr .toAttribute (),
351+ new NotEquals (Source .EMPTY , sampleId .toAttribute (), Literal .integer (Source .EMPTY , -1 )),
355352 Literal .integer (Source .EMPTY , BUCKET_COUNT ),
356- Literal .fromDouble (Source .EMPTY , aggFn instanceof NeedsSampleCorrection ? 0.0 : Double .NaN )
357- )
353+ Literal .keyword (Source .EMPTY , "ASC" )
354+ ),
355+ Literal .integer (Source .EMPTY , BUCKET_COUNT ),
356+ Literal .fromDouble (Source .EMPTY , aggFn instanceof NeedsSampleCorrection ? 0.0 : Double .NaN )
358357 )
359- );
360- } else {
361- aggregates . add ( aggr );
362- }
358+ )
359+ );
360+ } else {
361+ aggregates . add ( aggr );
363362 }
364- plan = new Aggregate (
365- Source .EMPTY ,
366- aggregateWithSampledId ,
367- aggregate .groupings ().stream ().map (e -> e instanceof Alias a ? a .toAttribute () : e ).toList (),
368- aggregates
369- );
370363 }
364+ plan = new Aggregate (
365+ Source .EMPTY ,
366+ aggregateWithSampledId ,
367+ aggregate .groupings ().stream ().map (e -> e instanceof Alias a ? a .toAttribute () : e ).toList (),
368+ aggregates
369+ );
371370 }
372371 return plan ;
373372 });
0 commit comments