Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 571aa27

Browse files
wzhfycloud-fan
authored and committed
[SPARK-21984][SQL] Join estimation based on equi-height histogram
## What changes were proposed in this pull request? The equi-height histogram is one of the state-of-the-art statistics for cardinality estimation; it provides better estimation accuracy and handles skewed data well. This PR improves join estimation by using equi-height histograms. It differs from the basic estimation (based on ndv) in the logic for computing the join cardinality and the new ndv after the join. The main idea is as follows: 1. find the overlapped ranges between the two histograms of the two join keys; 2. apply the formula `T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1))` in each overlapped range. ## How was this patch tested? Added new test cases. Author: Zhenhua Wang <[email protected]> Closes apache#19594 from wzhfy/join_estimation_histogram.
1 parent ab7346f commit 571aa27

File tree

3 files changed

+428
-4
lines changed

3 files changed

+428
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
1919

20+
import scala.collection.mutable.ArrayBuffer
2021
import scala.math.BigDecimal.RoundingMode
2122

2223
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
@@ -212,4 +213,172 @@ object EstimationUtils {
212213
}
213214
}
214215

216+
/**
 * Returns overlapped ranges between two histograms, in the given value range
 * [lowerBound, upperBound].
 *
 * Every bin that intersects the value range is first trimmed to that range with `trimBin`.
 * Each pair of (trimmed left bin, trimmed right bin) that intersects then yields exactly one
 * [[OverlappedRange]]. Ranges are emitted in left-bin-major order, i.e. all ranges of the
 * first left bin come first, in the order of the right bins.
 *
 * @param leftHistogram histogram of the left side.
 * @param rightHistogram histogram of the right side.
 * @param lowerBound lower bound of the value range to consider.
 * @param upperBound upper bound of the value range to consider.
 * @return the overlapped ranges, possibly empty.
 */
def getOverlappedRanges(
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    lowerBound: Double,
    upperBound: Double): Seq[OverlappedRange] = {
  val overlappedRanges = new ArrayBuffer[OverlappedRange]()
  // Only bins whose range intersects [lowerBound, upperBound] have join possibility.
  // Trimming depends on a single bin only, so it is done once per bin here instead of once
  // per (left bin, right bin) pair inside the nested loop.
  val leftBins = leftHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, leftHistogram.height, lowerBound, upperBound))
  val rightBins = rightHistogram.bins
    .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
    .map(b => trimBin(b, rightHistogram.height, lowerBound, upperBound))

  leftBins.foreach { case (left, leftHeight) =>
    rightBins.foreach { case (right, rightHeight) =>
      // Only collect overlapped ranges.
      if (left.lo <= right.hi && left.hi >= right.lo) {
        val range = if (right.lo >= left.lo && right.hi >= left.hi) {
          // Case1: the left bin is "smaller" than the right bin
          //      left.lo            right.lo     left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          if (left.hi == right.lo) {
            // The overlapped range has only one value.
            OverlappedRange(
              lo = right.lo,
              hi = right.lo,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (left.hi - right.lo) / (left.hi - left.lo)
            val rightRatio = (left.hi - right.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = right.lo,
              hi = left.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo <= left.lo && right.hi <= left.hi) {
          // Case2: the left bin is "larger" than the right bin
          //      right.lo           left.lo      right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          if (right.hi == left.lo) {
            // The overlapped range has only one value.
            OverlappedRange(
              lo = right.hi,
              hi = right.hi,
              leftNdv = 1,
              rightNdv = 1,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val leftRatio = (right.hi - left.lo) / (left.hi - left.lo)
            val rightRatio = (right.hi - left.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = left.lo,
              hi = right.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight * rightRatio
            )
          }
        } else if (right.lo >= left.lo && right.hi <= left.hi) {
          // Case3: the left bin contains the right bin
          //      left.lo            right.lo     right.hi         left.hi
          // --------+------------------+------------+----------------+------->
          if (right.hi == right.lo) {
            // The right bin is a single value strictly inside the left bin. A width-based ratio
            // would be 0 and would wrongly remove this value from the estimate, so treat the
            // left part as one distinct value with its average number of rows, as is done for
            // the single-value overlaps in Case1 and Case2.
            OverlappedRange(
              lo = right.lo,
              hi = right.hi,
              leftNdv = 1,
              rightNdv = right.ndv,
              leftNumRows = leftHeight / left.ndv,
              rightNumRows = rightHeight
            )
          } else {
            val leftRatio = (right.hi - right.lo) / (left.hi - left.lo)
            OverlappedRange(
              lo = right.lo,
              hi = right.hi,
              leftNdv = left.ndv * leftRatio,
              rightNdv = right.ndv,
              leftNumRows = leftHeight * leftRatio,
              rightNumRows = rightHeight
            )
          }
        } else {
          assert(right.lo <= left.lo && right.hi >= left.hi)
          // Case4: the right bin contains the left bin
          //      right.lo           left.lo      left.hi          right.hi
          // --------+------------------+------------+----------------+------->
          if (left.hi == left.lo) {
            // Mirror of the single-value handling in Case3: the left bin is a single value
            // strictly inside the right bin.
            OverlappedRange(
              lo = left.lo,
              hi = left.hi,
              leftNdv = left.ndv,
              rightNdv = 1,
              leftNumRows = leftHeight,
              rightNumRows = rightHeight / right.ndv
            )
          } else {
            val rightRatio = (left.hi - left.lo) / (right.hi - right.lo)
            OverlappedRange(
              lo = left.lo,
              hi = left.hi,
              leftNdv = left.ndv,
              rightNdv = right.ndv * rightRatio,
              leftNumRows = leftHeight,
              rightNumRows = rightHeight * rightRatio
            )
          }
        }
        overlappedRanges += range
      }
    }
  }
  overlappedRanges
}
325+
326+
/**
 * Restricts a histogram bin to the value range [lowerBound, upperBound].
 *
 * The returned bin covers only the part of the input bin inside the range. Its ndv and number
 * of rows are scaled by the fraction of the original bin width that remains. A trimmed bin
 * that collapses to a single value gets ndv 1 and `height / bin.ndv` rows.
 *
 * @param bin the input histogram bin.
 * @param height the number of rows of the given histogram bin inside an equi-height histogram.
 * @param lowerBound lower bound of the given range.
 * @param upperBound upper bound of the given range.
 * @return trimmed part of the given bin and its number of rows.
 */
def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double)
  : (HistogramBin, Double) = {
  val (trimmedLo, trimmedHi) = (bin.lo, bin.hi) match {
    case (_, binHi) if bin.lo <= lowerBound && binHi >= upperBound =>
      // The bin spans the whole range:
      //      bin.lo          lowerBound     upperBound      bin.hi
      (lowerBound, upperBound)
    case (_, binHi) if bin.lo <= lowerBound && binHi >= lowerBound =>
      // The bin sticks out below the range:
      //      bin.lo          lowerBound       bin.hi      upperBound
      (lowerBound, binHi)
    case (binLo, _) if binLo <= upperBound && bin.hi >= upperBound =>
      // The bin sticks out above the range:
      //    lowerBound          bin.lo       upperBound      bin.hi
      (binLo, upperBound)
    case (binLo, binHi) =>
      // The bin lies entirely inside the range:
      //    lowerBound          bin.lo         bin.hi      upperBound
      assert(binLo >= lowerBound && binHi <= upperBound)
      (binLo, binHi)
  }

  if (trimmedHi != trimmedLo) {
    assert(bin.hi != bin.lo)
    // Fraction of the original bin width that survives trimming.
    val keptFraction = (trimmedHi - trimmedLo) / (bin.hi - bin.lo)
    val trimmedNdv = math.ceil(bin.ndv * keptFraction).toLong
    (HistogramBin(trimmedLo, trimmedHi, trimmedNdv), height * keptFraction)
  } else {
    // Single-value result. An input bin with bin.hi == bin.lo also ends up here.
    (HistogramBin(trimmedLo, trimmedHi, 1), height / bin.ndv)
  }
}
365+
366+
/**
 * One overlapped range between two equi-height histograms.
 *
 * Joining two equi-height histograms may yield several such ranges. Each one comes from a part
 * of a single bin in the left histogram and a part of a single bin in the right histogram.
 *
 * @param lo lower bound of this overlapped range.
 * @param hi upper bound of this overlapped range.
 * @param leftNdv ndv in the left part.
 * @param rightNdv ndv in the right part.
 * @param leftNumRows number of rows in the left part.
 * @param rightNumRows number of rows in the right part.
 */
case class OverlappedRange(
    lo: Double,
    hi: Double,
    leftNdv: Double,
    rightNdv: Double,
    leftNumRows: Double,
    rightNumRows: Double)
215384
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
2424
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
2525
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2626
import org.apache.spark.sql.catalyst.plans._
27-
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
27+
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics}
2828
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
2929

3030

@@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging {
191191
val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
192192
if (ValueInterval.isIntersected(lInterval, rInterval)) {
193193
val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
194-
val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
195-
keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
194+
val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match {
195+
case (Some(l: Histogram), Some(r: Histogram)) =>
196+
computeByHistogram(leftKey, rightKey, l, r, newMin, newMax)
197+
case _ =>
198+
computeByNdv(leftKey, rightKey, newMin, newMax)
199+
}
200+
keyStatsAfterJoin += (
201+
// Histograms are propagated as unchanged. During future estimation, they should be
202+
// truncated by the updated max/min. In this way, only pointers of the histograms are
203+
// propagated and thus reduce memory consumption.
204+
leftKey -> joinStat.copy(histogram = leftKeyStat.histogram),
205+
rightKey -> joinStat.copy(histogram = rightKeyStat.histogram)
206+
)
196207
// Return cardinality estimated from the most selective join keys.
197208
if (card < joinCard) joinCard = card
198209
} else {
@@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging {
225236
(ceil(card), newStats)
226237
}
227238

239+
/**
 * Compute join cardinality using equi-height histograms.
 *
 * The overlapped ranges of the two histograms inside [newMin, newMax] are computed first. The
 * formula `leftNumRows * rightNumRows / max(leftNdv, rightNdv)` is then applied to every range
 * and the results are summed. The ndv after the join is the sum over the ranges of
 * `min(leftNdv, rightNdv)`.
 *
 * @param leftKey the left join key.
 * @param rightKey the right join key.
 * @param leftHistogram histogram passed as the left input of `getOverlappedRanges`.
 * @param rightHistogram histogram passed as the right input of `getOverlappedRanges`.
 * @param newMin lower bound of the value range; must be defined and numeric.
 * @param newMax upper bound of the value range; must be defined and numeric.
 * @return the estimated join cardinality and the column stat built for the join keys.
 */
private def computeByHistogram(
    leftKey: AttributeReference,
    rightKey: AttributeReference,
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    newMin: Option[Any],
    newMax: Option[Any]): (BigInt, ColumnStat) = {
  // Only numeric values have equi-height histograms.
  val rangeStart = newMin.get.toString.toDouble
  val rangeEnd = newMax.get.toString.toDouble
  val ranges = getOverlappedRanges(
    leftHistogram = leftHistogram,
    rightHistogram = rightHistogram,
    lowerBound = rangeStart,
    upperBound = rangeEnd)

  val (card, totalNdv) = ranges.zipWithIndex.foldLeft((BigDecimal(0), 0.0)) {
    case ((cardSoFar, ndvSoFar), (range, idx)) =>
      // A range ending at the same value as its predecessor has only one value, and that value
      // is already counted in the previous range, so it must not add to the ndv again.
      val repeatsPreviousHi = idx != 0 && range.hi == ranges(idx - 1).hi
      val ndvNow =
        if (repeatsPreviousHi) ndvSoFar
        else ndvSoFar + math.min(range.leftNdv, range.rightNdv)
      // Apply the formula in this overlapped range.
      val rowsInRange =
        range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv)
      (cardSoFar + rowsInRange, ndvNow)
  }

  val leftKeyStat = leftStats.attributeStats(leftKey)
  val rightKeyStat = rightStats.attributeStats(rightKey)
  val joinedMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
  val joinedAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
  val joinedStat = ColumnStat(ceil(totalNdv), newMin, newMax, 0, joinedAvgLen, joinedMaxLen)
  (ceil(card), joinedStat)
}
275+
228276
/**
229277
* Propagate or update column stats for output attributes.
230278
*/

0 commit comments

Comments
 (0)