Skip to content

Commit bfe60fc

Browse files
HyukjinKwon authored
cloud-fan committed
[SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning
## What changes were proposed in this pull request? Looks we intentionally set `null` for upper/lower bounds for complex types and don't use it. However, these look used in in-memory partition pruning, which ends up with incorrect results. This PR proposes to explicitly whitelist the supported types. ```scala val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol") df.cache().filter("arrayCol > array('a', 'b')").show() ``` ```scala val df = sql("select cast('a' as binary) as a") df.cache().filter("a == cast('a' as binary)").show() ``` **Before:** ``` +--------+ |arrayCol| +--------+ +--------+ ``` ``` +---+ | a| +---+ +---+ ``` **After:** ``` +--------+ |arrayCol| +--------+ | [c, d]| +--------+ ``` ``` +----+ | a| +----+ |[61]| +----+ ``` ## How was this patch tested? Unit tests were added and manually tested. Author: hyukjinkwon <[email protected]> Closes apache#21882 from HyukjinKwon/stats-filter.
1 parent 65a4bc1 commit bfe60fc

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
183183
private val stats = relation.partitionStatistics
184184
private def statsFor(a: Attribute) = stats.forAttribute(a)
185185

186+
// Currently, only use statistics from atomic types except binary type only.
187+
private object ExtractableLiteral {
188+
def unapply(expr: Expression): Option[Literal] = expr match {
189+
case lit: Literal => lit.dataType match {
190+
case BinaryType => None
191+
case _: AtomicType => Some(lit)
192+
case _ => None
193+
}
194+
case _ => None
195+
}
196+
}
197+
186198
// Returned filter predicate should return false iff it is impossible for the input expression
187199
// to evaluate to `true' based on statistics collected about this partition batch.
188200
@transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@ case class InMemoryTableScanExec(
194206
if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
195207
buildFilter(lhs) || buildFilter(rhs)
196208

197-
case EqualTo(a: AttributeReference, l: Literal) =>
209+
case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
198210
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
199-
case EqualTo(l: Literal, a: AttributeReference) =>
211+
case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
200212
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
201213

202-
case EqualNullSafe(a: AttributeReference, l: Literal) =>
214+
case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
203215
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
204-
case EqualNullSafe(l: Literal, a: AttributeReference) =>
216+
case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
205217
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
206218

207-
case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
208-
case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
219+
case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
220+
case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
209221

210-
case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
211-
case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
222+
case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
223+
statsFor(a).lowerBound <= l
224+
case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
225+
l <= statsFor(a).upperBound
212226

213-
case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
214-
case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
227+
case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
228+
case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
215229

216-
case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
217-
case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
230+
case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
231+
l <= statsFor(a).upperBound
232+
case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
233+
statsFor(a).lowerBound <= l
218234

219235
case IsNull(a: Attribute) => statsFor(a).nullCount > 0
220236
case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
221237

222238
case In(a: AttributeReference, list: Seq[Expression])
223-
if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
239+
if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
224240
list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
225241
l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
226242
}

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
2020
import org.scalatest.BeforeAndAfterEach
2121

2222
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.DataFrame
2324
import org.apache.spark.sql.internal.SQLConf
2425
import org.apache.spark.sql.test.SharedSQLContext
2526
import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
3536
private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
3637
private lazy val originalInMemoryPartitionPruning =
3738
spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
39+
private val testArrayData = (1 to 100).map { key =>
40+
Tuple1(Array.fill(key)(key))
41+
}
42+
private val testBinaryData = (1 to 100).map { key =>
43+
Tuple1(Array.fill(key)(key.toByte))
44+
}
3845

3946
override protected def beforeAll(): Unit = {
4047
super.beforeAll()
@@ -71,12 +78,22 @@ class PartitionBatchPruningSuite
7178
}, 5).toDF()
7279
pruningStringData.createOrReplaceTempView("pruningStringData")
7380
spark.catalog.cacheTable("pruningStringData")
81+
82+
val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
83+
pruningArrayData.createOrReplaceTempView("pruningArrayData")
84+
spark.catalog.cacheTable("pruningArrayData")
85+
86+
val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
87+
pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
88+
spark.catalog.cacheTable("pruningBinaryData")
7489
}
7590

7691
override protected def afterEach(): Unit = {
7792
try {
7893
spark.catalog.uncacheTable("pruningData")
7994
spark.catalog.uncacheTable("pruningStringData")
95+
spark.catalog.uncacheTable("pruningArrayData")
96+
spark.catalog.uncacheTable("pruningBinaryData")
8097
} finally {
8198
super.afterEach()
8299
}
@@ -95,6 +112,14 @@ class PartitionBatchPruningSuite
95112
checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
96113
checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
97114
checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
115+
// Do not filter on array type
116+
checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
117+
checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
118+
checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
119+
testArrayData.map(_._1))
120+
// Do not filter on binary type
121+
checkBatchPruning(
122+
"SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))
98123

99124
// IS NULL
100125
checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
@@ -131,6 +156,9 @@ class PartitionBatchPruningSuite
131156
checkBatchPruning(
132157
"SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
133158
Seq(150))
159+
// Do not filter on array type
160+
checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
161+
Seq(Array(1), Array(2, 2)))
134162

135163
// With unsupported `InSet` predicate
136164
{
@@ -161,7 +189,7 @@ class PartitionBatchPruningSuite
161189
query: String,
162190
expectedReadPartitions: Int,
163191
expectedReadBatches: Int)(
164-
expectedQueryResult: => Seq[Int]): Unit = {
192+
expectedQueryResult: => Seq[Any]): Unit = {
165193

166194
test(query) {
167195
val df = sql(query)

0 commit comments

Comments (0)