
Commit e828917

Author: Robert Kruszewski (committed)
fix tests
1 parent 8b85a42 commit e828917

File tree

2 files changed: +28 additions, -35 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 23 additions & 22 deletions
@@ -437,27 +437,28 @@ private[parquet] class ParquetFilters(
    // Parquet's type in the given file should be matched to the value's type
    // in the pushed filter in order to push down the filter to Parquet.
    def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
-     value == null || (nameToType(name) match {
-       case ParquetBooleanType => value.isInstanceOf[JBoolean]
-       case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
-       case ParquetLongType => value.isInstanceOf[JLong]
-       case ParquetFloatType => value.isInstanceOf[JFloat]
-       case ParquetDoubleType => value.isInstanceOf[JDouble]
-       case ParquetStringType => value.isInstanceOf[String]
-       case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
-       case ParquetDateType => value.isInstanceOf[Date]
-       case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
-         value.isInstanceOf[Timestamp]
-       case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
-         isDecimalMatched(value, decimalMeta)
-       case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
-         isDecimalMatched(value, decimalMeta)
-       case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
-         isDecimalMatched(value, decimalMeta)
-       case set: Set[_] =>
-         valueCanMakeFilterOn(name, set.iterator.next())
-       case _ => false
-     })
+     value == null ||
+       (value.isInstanceOf[Array[_]]
+         && canMakeFilterOn(name, value.asInstanceOf[Array[_]].apply(0))) ||
+       (nameToType(name) match {
+         case ParquetBooleanType => value.isInstanceOf[JBoolean]
+         case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+         case ParquetLongType => value.isInstanceOf[JLong]
+         case ParquetFloatType => value.isInstanceOf[JFloat]
+         case ParquetDoubleType => value.isInstanceOf[JDouble]
+         case ParquetStringType => value.isInstanceOf[String]
+         case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+         case ParquetDateType => value.isInstanceOf[Date]
+         case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+           value.isInstanceOf[Timestamp]
+         case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
+           isDecimalMatched(value, decimalMeta)
+         case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
+           isDecimalMatched(value, decimalMeta)
+         case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
+           isDecimalMatched(value, decimalMeta)
+         case _ => false
+       })
    }

    // Parquet does not allow dots in the column name because dots are used as a column path
@@ -525,7 +526,7 @@ private[parquet] class ParquetFilters(
        .map(FilterApi.not)
        .map(LogicalInverseRewriter.rewrite)

-     case sources.In(name, values) if canMakeFilterOn(name, values.head) =>
+     case sources.In(name, values) if canMakeFilterOn(name, values) =>
        makeInSet.lift(nameToType(name)).map(_(name, values.toSet))

      case sources.StringStartsWith(name, prefix)
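Taken together, the two hunks above change how an In filter is validated for pushdown: the guard on the sources.In case now passes the whole values array to canMakeFilterOn, and valueCanMakeFilterOn accepts an array by checking its first element against the column's Parquet type. The snippet below is a minimal, self-contained sketch of that guard logic, not the actual ParquetFilters code: columnTypeOf and typeMatches are hypothetical stand-ins for nameToType and the type match above, and the self-recursion stands in for the call back into canMakeFilterOn.

// Minimal sketch of the array-aware pushdown guard (hypothetical helper names).
object InPushdownGuardSketch {
  // Stand-in for the ParquetSchemaType match: does the value's runtime class match the column?
  def typeMatches(expected: Class[_], value: Any): Boolean = expected.isInstance(value)

  // Accept null, an Array whose first element is acceptable, or a scalar of the matching type.
  def valueCanMakeFilterOn(columnTypeOf: String => Class[_])(name: String, value: Any): Boolean =
    value == null ||
      (value.isInstanceOf[Array[_]] &&
        valueCanMakeFilterOn(columnTypeOf)(name, value.asInstanceOf[Array[_]].apply(0))) ||
      typeMatches(columnTypeOf(name), value)

  def main(args: Array[String]): Unit = {
    val schema = Map[String, Class[_]]("a" -> classOf[java.lang.Integer])
    println(valueCanMakeFilterOn(schema)("a", 10))                     // true: scalar boxes to Integer
    println(valueCanMakeFilterOn(schema)("a", Array[Any](10, 20, 30))) // true: first element matches
    println(valueCanMakeFilterOn(schema)("a", Array[Any]("x")))        // false: element type mismatch
  }
}

Checking only the first element mirrors the patch itself; it assumes the remaining elements share that type, which is reasonable for the arrays Spark's sources.In builds from literals of a single column.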

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 5 additions & 13 deletions
@@ -1004,32 +1004,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

      val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)

-     assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) {
+     assertResult(Some(FilterApi.userDefined(intColumn("a"), SetInFilter[Integer](Set(null))))) {
        parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null)))
      }

-     assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+     assertResult(Some(FilterApi.userDefined(intColumn("a"), SetInFilter[Integer](Set(10))))) {
        parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10)))
      }

      // Remove duplicates
-     assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+     assertResult(Some(FilterApi.userDefined(intColumn("a"), SetInFilter[Integer](Set(10))))) {
        parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10)))
      }

-     assertResult(Some(or(or(
-       FilterApi.eq(intColumn("a"), 10: Integer),
-       FilterApi.eq(intColumn("a"), 20: Integer)),
-       FilterApi.eq(intColumn("a"), 30: Integer)))
-     ) {
+     assertResult(Some(
+       FilterApi.userDefined(intColumn("a"), SetInFilter[Integer](Set(10, 20, 30))))) {
        parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30)))
      }

-     assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
-       Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined)
-     assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
-       Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty)
-
      import testImplicits._
      withTempPath { path =>
        val data = 0 to 1024