3030import org .apache .calcite .rex .RexLiteral ;
3131import com .tdunning .math .stats .MergingDigest ;
3232import org .apache .calcite .sql .SqlOperator ;
33+ import org .apache .drill .shaded .guava .com .google .common .annotations .VisibleForTesting ;
3334import org .apache .drill .shaded .guava .com .google .common .base .Preconditions ;
35+ import org .apache .drill .shaded .guava .com .google .common .collect .BoundType ;
3436import org .apache .drill .shaded .guava .com .google .common .collect .Range ;
3537
3638/**
@@ -85,6 +87,11 @@ public Double[] getBuckets() {
8587 return buckets ;
8688 }
8789
90+ @ VisibleForTesting
91+ protected void setBucketValue (int index , Double value ) {
92+ buckets [index ] = value ;
93+ }
94+
8895 /**
8996 * Get the number of buckets in the histogram
9097 * number of buckets is 1 less than the total # entries in the buckets array since last
@@ -105,7 +112,7 @@ public int getNumBuckets() {
105112 * first and last bucket may be partially covered and all other buckets in the middle are fully covered.
106113 */
107114 @ Override
108- public Double estimatedSelectivity (final RexNode columnFilter , final long totalRowCount ) {
115+ public Double estimatedSelectivity (final RexNode columnFilter , final long totalRowCount , final long ndv ) {
109116 if (numRowsPerBucket == 0 ) {
110117 return null ;
111118 }
@@ -127,7 +134,7 @@ public Double estimatedSelectivity(final RexNode columnFilter, final long totalR
127134 int unknown = unknownFilterList .size ();
128135
129136 if (valuesRange .hasLowerBound () || valuesRange .hasUpperBound ()) {
130- numSelectedRows = getSelectedRows (valuesRange );
137+ numSelectedRows = getSelectedRows (valuesRange , ndv );
131138 } else {
132139 numSelectedRows = 0 ;
133140 }
@@ -178,101 +185,143 @@ private Range<Double> getValuesRange(List<RexNode> filterList, Range<Double> ful
178185 return currentRange ;
179186 }
180187
181- private long getSelectedRows ( final Range range ) {
182- final int numBuckets = buckets . length - 1 ;
188+ @ VisibleForTesting
189+ protected long getSelectedRows ( final Range range , final long ndv ) {
183190 double startBucketFraction = 1.0 ;
184191 double endBucketFraction = 1.0 ;
185192 long numRows = 0 ;
186193 int result ;
187194 Double lowValue = null ;
188195 Double highValue = null ;
189- final int first = 0 ;
190- final int last = buckets .length - 1 ;
191- int startBucket = first ;
192- int endBucket = last ;
196+ final int firstStartPointIndex = 0 ;
197+ final int lastEndPointIndex = buckets .length - 1 ;
198+ int startBucket = firstStartPointIndex ;
199+ int endBucket = lastEndPointIndex - 1 ;
193200
194201 if (range .hasLowerBound ()) {
195202 lowValue = (Double ) range .lowerEndpoint ();
196203
197- // if low value is greater than the end point of the last bucket then none of the rows qualify
198- if (lowValue .compareTo (buckets [last ]) > 0 ) {
204+ // if low value is greater than the end point of the last bucket or if it is equal but the range is open (i.e
205+ // predicate is of type > 5 where 5 is the end point of last bucket) then none of the rows qualify
206+ result = lowValue .compareTo (buckets [lastEndPointIndex ]);
207+ if (result > 0 || result == 0 && range .lowerBoundType () == BoundType .OPEN ) {
199208 return 0 ;
200209 }
201-
202- result = lowValue .compareTo (buckets [first ]);
210+ result = lowValue .compareTo (buckets [firstStartPointIndex ]);
203211
204212 // if low value is less than or equal to the first bucket's start point then start with the first bucket and all
205213 // rows in first bucket are included
206214 if (result <= 0 ) {
207- startBucket = first ;
215+ startBucket = firstStartPointIndex ;
208216 startBucketFraction = 1.0 ;
209217 } else {
210- // Use a simplified logic where we treat > and >= the same when computing selectivity since the
211- // difference is going to be very small for reasonable sized data sets
212- startBucket = getContainingBucket (lowValue , numBuckets );
218+ startBucket = getContainingBucket (lowValue , lastEndPointIndex , true );
219+
213220 // expecting start bucket to be >= 0 since other conditions have been handled previously
214221 Preconditions .checkArgument (startBucket >= 0 , "Expected start bucket id >= 0" );
215- startBucketFraction = ((double ) (buckets [startBucket + 1 ] - lowValue )) / (buckets [startBucket + 1 ] - buckets [startBucket ]);
222+
223+ if (buckets [startBucket + 1 ].doubleValue () == buckets [startBucket ].doubleValue ()) {
224+ // if start and end points of the bucket are the same, consider entire bucket
225+ startBucketFraction = 1.0 ;
226+ } else if (range .lowerBoundType () == BoundType .CLOSED && buckets [startBucket + 1 ].doubleValue () == lowValue .doubleValue ()) {
227+ // predicate is of type >= 5.0 and 5.0 happens to be the start point of the bucket
228+ // In this case, use the overall NDV to approximate
229+ startBucketFraction = 1.0 / ndv ;
230+ } else {
231+ startBucketFraction = ((double ) (buckets [startBucket + 1 ] - lowValue )) / (buckets [startBucket + 1 ] - buckets [startBucket ]);
232+ }
216233 }
217234 }
218235
219236 if (range .hasUpperBound ()) {
220237 highValue = (Double ) range .upperEndpoint ();
221238
222- // if the high value is less than the start point of the first bucket then none of the rows qualify
223- if (highValue .compareTo (buckets [first ]) < 0 ) {
239+ // if the high value is less than the start point of the first bucket or if it is equal but the range is open (i.e
240+ // predicate is of type < 1 where 1 is the start point of the first bucket) then none of the rows qualify
241+ result = highValue .compareTo (buckets [firstStartPointIndex ]);
242+ if (result < 0 || (result == 0 && range .upperBoundType () == BoundType .OPEN )) {
224243 return 0 ;
225244 }
226245
227- result = highValue .compareTo (buckets [last ]);
246+ result = highValue .compareTo (buckets [lastEndPointIndex ]);
228247
229248 // if high value is greater than or equal to the last bucket's end point then include the last bucket and all rows in
230249 // last bucket qualify
231250 if (result >= 0 ) {
232- endBucket = last ;
251+ endBucket = lastEndPointIndex - 1 ;
233252 endBucketFraction = 1.0 ;
234253 } else {
235- // Use a simplified logic where we treat < and <= the same when computing selectivity since the
236- // difference is going to be very small for reasonable sized data sets
237- endBucket = getContainingBucket (highValue , numBuckets );
254+ endBucket = getContainingBucket (highValue , lastEndPointIndex , false );
255+
238256 // expecting end bucket to be >= 0 since other conditions have been handled previously
239257 Preconditions .checkArgument (endBucket >= 0 , "Expected end bucket id >= 0" );
240- endBucketFraction = ((double )(highValue - buckets [endBucket ])) / (buckets [endBucket + 1 ] - buckets [endBucket ]);
258+
259+ if (buckets [endBucket + 1 ].doubleValue () == buckets [endBucket ].doubleValue ()) {
260+ // if start and end points of the bucket are the same, consider entire bucket
261+ endBucketFraction = 1.0 ;
262+ } else if (range .upperBoundType () == BoundType .CLOSED && buckets [endBucket ].doubleValue () == highValue .doubleValue ()) {
263+ // predicate is of type <= 5.0 and 5.0 happens to be the start point of the bucket
264+ // In this case, use the overall NDV to approximate
265+ endBucketFraction = 1.0 /ndv ;
266+ } else {
267+ endBucketFraction = ((double ) (highValue - buckets [endBucket ])) / (buckets [endBucket + 1 ] - buckets [endBucket ]);
268+ }
241269 }
242270 }
243271
244- Preconditions .checkArgument (startBucket <= endBucket );
272+ Preconditions .checkArgument (startBucket >= 0 && startBucket + 1 <= lastEndPointIndex , "Invalid startBucket: " + startBucket );
273+ Preconditions .checkArgument (endBucket >= 0 && endBucket + 1 <= lastEndPointIndex , "Invalid endBucket: " + endBucket );
274+ Preconditions .checkArgument (startBucket <= endBucket ,
275+ "Start bucket: " + startBucket + " should be less than or equal to end bucket: " + endBucket );
245276
246- // if the endBucketId corresponds to the last endpoint, then adjust it to be one less
247- if (endBucket == last ) {
248- endBucket = last - 1 ;
249- }
250- if (startBucket == endBucket && highValue != null && lowValue != null ) {
277+ if (startBucket == endBucket ) {
251278 // if the start and end buckets are the same, interpolate based on the difference between the high and low value
252- numRows = (long ) ((highValue - lowValue ) / (buckets [endBucket + 1 ] - buckets [startBucket ]) * numRowsPerBucket );
279+ if (highValue != null && lowValue != null ) {
280+ numRows = (long ) ((highValue - lowValue ) / (buckets [startBucket + 1 ] - buckets [startBucket ]) * numRowsPerBucket );
281+ } else if (highValue != null ) {
282+ numRows = (long ) (endBucketFraction * numRowsPerBucket );
283+ } else {
284+ numRows = (long ) (startBucketFraction * numRowsPerBucket );
285+ }
253286 } else {
254- numRows = (long ) ((startBucketFraction + endBucketFraction ) * numRowsPerBucket + (endBucket - startBucket - 1 ) * numRowsPerBucket );
287+ int numIntermediateBuckets = (endBucket > startBucket + 1 ) ? (endBucket - startBucket - 1 ) : 0 ;
288+ numRows = (long ) ((startBucketFraction + endBucketFraction ) * numRowsPerBucket + numIntermediateBuckets * numRowsPerBucket );
255289 }
256290
257291 return numRows ;
258292 }
259293
260- private int getContainingBucket (final Double value , final int numBuckets ) {
294+ /**
295+ * Get the start point of the containing bucket for the supplied value. If there are multiple buckets with the
296+ * same start point, return either the first matching or last matching depending on firstMatching flag
297+ * @param value the input double value
298+ * @param lastEndPointIndex
299+ * @param firstMatching If true, return the first bucket that matches the specified criteria otherwise return the last one
300+ * @return index of either the first or last matching bucket if a match was found, otherwise return -1
301+ */
302+ private int getContainingBucket (final Double value , final int lastEndPointIndex , final boolean firstMatching ) {
261303 int i = 0 ;
262304 int containing_bucket = -1 ;
305+
263306 // check which bucket this value falls in
264- for (; i <= numBuckets ; i ++) {
307+ for (; i <= lastEndPointIndex ; i ++) {
265308 int result = buckets [i ].compareTo (value );
266309 if (result > 0 ) {
267310 containing_bucket = i - 1 ;
268311 break ;
269312 } else if (result == 0 ) {
270- containing_bucket = i ;
271- break ;
313+ // if we are already at the lastEndPointIndex, mark the containing bucket
314+ // as i-1 because the containing bucket should correspond to the start point of the bucket
315+ // (recall that lastEndPointIndex is the final end point of the last bucket)
316+ containing_bucket = (i == lastEndPointIndex ) ? i - 1 : i ;
317+ if (firstMatching ) {
318+ // break if we are only interested in the first matching bucket
319+ break ;
320+ }
272321 }
273322 }
274323 return containing_bucket ;
275- }
324+ }
276325
277326 private Double getLiteralValue (final RexNode filter ) {
278327 Double value = null ;
0 commit comments