@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
20
20
import java .lang .{Boolean => JBoolean , Double => JDouble , Float => JFloat , Long => JLong }
21
21
import java .math .{BigDecimal => JBigDecimal }
22
22
import java .sql .{Date , Timestamp }
23
+ import java .util .Locale
23
24
24
25
import scala .collection .JavaConverters .asScalaBufferConverter
25
26
@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
31
32
import org .apache .parquet .schema .PrimitiveType .PrimitiveTypeName
32
33
import org .apache .parquet .schema .PrimitiveType .PrimitiveTypeName ._
33
34
34
- import org .apache .spark .sql .catalyst .util .DateTimeUtils
35
+ import org .apache .spark .sql .catalyst .util .{ CaseInsensitiveMap , DateTimeUtils }
35
36
import org .apache .spark .sql .catalyst .util .DateTimeUtils .SQLDate
36
37
import org .apache .spark .sql .sources
37
38
import org .apache .spark .unsafe .types .UTF8String
@@ -44,7 +45,18 @@ private[parquet] class ParquetFilters(
44
45
pushDownTimestamp : Boolean ,
45
46
pushDownDecimal : Boolean ,
46
47
pushDownStartWith : Boolean ,
47
- pushDownInFilterThreshold : Int ) {
48
+ pushDownInFilterThreshold : Int ,
49
+ caseSensitive : Boolean ) {
50
+
51
+ /**
52
+ * Holds a single field information stored in the underlying parquet file.
53
+ *
54
+ * @param fieldName field name in parquet file
55
+ * @param fieldType field type related info in parquet file
56
+ */
57
+ private case class ParquetField (
58
+ fieldName : String ,
59
+ fieldType : ParquetSchemaType )
48
60
49
61
private case class ParquetSchemaType (
50
62
originalType : OriginalType ,
@@ -350,25 +362,38 @@ private[parquet] class ParquetFilters(
350
362
}
351
363
352
364
/**
353
- * Returns a map from name of the column to the data type, if predicate push down applies.
365
+ * Returns a map, which contains parquet field name and data type, if predicate push down applies.
354
366
*/
355
- private def getFieldMap (dataType : MessageType ): Map [String , ParquetSchemaType ] = dataType match {
356
- case m : MessageType =>
357
- // Here we don't flatten the fields in the nested schema but just look up through
358
- // root fields. Currently, accessing to nested fields does not push down filters
359
- // and it does not support to create filters for them.
360
- m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
361
- f.getName -> ParquetSchemaType (
362
- f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
363
- }.toMap
364
- case _ => Map .empty[String , ParquetSchemaType ]
367
+ private def getFieldMap (dataType : MessageType ): Map [String , ParquetField ] = {
368
+ // Here we don't flatten the fields in the nested schema but just look up through
369
+ // root fields. Currently, accessing to nested fields does not push down filters
370
+ // and it does not support to create filters for them.
371
+ val primitiveFields =
372
+ dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
373
+ f.getName -> ParquetField (f.getName,
374
+ ParquetSchemaType (f.getOriginalType,
375
+ f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
376
+ }
377
+ if (caseSensitive) {
378
+ primitiveFields.toMap
379
+ } else {
380
+ // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive
381
+ // mode, just skip pushdown for these fields, they will trigger Exception when reading,
382
+ // See: SPARK-25132.
383
+ val dedupPrimitiveFields =
384
+ primitiveFields
385
+ .groupBy(_._1.toLowerCase(Locale .ROOT ))
386
+ .filter(_._2.size == 1 )
387
+ .mapValues(_.head._2)
388
+ CaseInsensitiveMap (dedupPrimitiveFields)
389
+ }
365
390
}
366
391
367
392
/**
368
393
* Converts data sources filters to Parquet filter predicates.
369
394
*/
370
395
def createFilter (schema : MessageType , predicate : sources.Filter ): Option [FilterPredicate ] = {
371
- val nameToType = getFieldMap(schema)
396
+ val nameToParquetField = getFieldMap(schema)
372
397
373
398
// Decimal type must make sure that filter value's scale matched the file.
374
399
// If doesn't matched, which would cause data corruption.
@@ -381,7 +406,7 @@ private[parquet] class ParquetFilters(
381
406
// Parquet's type in the given file should be matched to the value's type
382
407
// in the pushed filter in order to push down the filter to Parquet.
383
408
def valueCanMakeFilterOn (name : String , value : Any ): Boolean = {
384
- value == null || (nameToType (name) match {
409
+ value == null || (nameToParquetField (name).fieldType match {
385
410
case ParquetBooleanType => value.isInstanceOf [JBoolean ]
386
411
case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf [Number ]
387
412
case ParquetLongType => value.isInstanceOf [JLong ]
@@ -408,7 +433,7 @@ private[parquet] class ParquetFilters(
408
433
// filters for the column having dots in the names. Thus, we do not push down such filters.
409
434
// See SPARK-20364.
410
435
def canMakeFilterOn (name : String , value : Any ): Boolean = {
411
- nameToType .contains(name) && ! name.contains(" ." ) && valueCanMakeFilterOn(name, value)
436
+ nameToParquetField .contains(name) && ! name.contains(" ." ) && valueCanMakeFilterOn(name, value)
412
437
}
413
438
414
439
// NOTE:
@@ -428,29 +453,39 @@ private[parquet] class ParquetFilters(
428
453
429
454
predicate match {
430
455
case sources.IsNull (name) if canMakeFilterOn(name, null ) =>
431
- makeEq.lift(nameToType(name)).map(_(name, null ))
456
+ makeEq.lift(nameToParquetField(name).fieldType)
457
+ .map(_(nameToParquetField(name).fieldName, null ))
432
458
case sources.IsNotNull (name) if canMakeFilterOn(name, null ) =>
433
- makeNotEq.lift(nameToType(name)).map(_(name, null ))
459
+ makeNotEq.lift(nameToParquetField(name).fieldType)
460
+ .map(_(nameToParquetField(name).fieldName, null ))
434
461
435
462
case sources.EqualTo (name, value) if canMakeFilterOn(name, value) =>
436
- makeEq.lift(nameToType(name)).map(_(name, value))
463
+ makeEq.lift(nameToParquetField(name).fieldType)
464
+ .map(_(nameToParquetField(name).fieldName, value))
437
465
case sources.Not (sources.EqualTo (name, value)) if canMakeFilterOn(name, value) =>
438
- makeNotEq.lift(nameToType(name)).map(_(name, value))
466
+ makeNotEq.lift(nameToParquetField(name).fieldType)
467
+ .map(_(nameToParquetField(name).fieldName, value))
439
468
440
469
case sources.EqualNullSafe (name, value) if canMakeFilterOn(name, value) =>
441
- makeEq.lift(nameToType(name)).map(_(name, value))
470
+ makeEq.lift(nameToParquetField(name).fieldType)
471
+ .map(_(nameToParquetField(name).fieldName, value))
442
472
case sources.Not (sources.EqualNullSafe (name, value)) if canMakeFilterOn(name, value) =>
443
- makeNotEq.lift(nameToType(name)).map(_(name, value))
473
+ makeNotEq.lift(nameToParquetField(name).fieldType)
474
+ .map(_(nameToParquetField(name).fieldName, value))
444
475
445
476
case sources.LessThan (name, value) if canMakeFilterOn(name, value) =>
446
- makeLt.lift(nameToType(name)).map(_(name, value))
477
+ makeLt.lift(nameToParquetField(name).fieldType)
478
+ .map(_(nameToParquetField(name).fieldName, value))
447
479
case sources.LessThanOrEqual (name, value) if canMakeFilterOn(name, value) =>
448
- makeLtEq.lift(nameToType(name)).map(_(name, value))
480
+ makeLtEq.lift(nameToParquetField(name).fieldType)
481
+ .map(_(nameToParquetField(name).fieldName, value))
449
482
450
483
case sources.GreaterThan (name, value) if canMakeFilterOn(name, value) =>
451
- makeGt.lift(nameToType(name)).map(_(name, value))
484
+ makeGt.lift(nameToParquetField(name).fieldType)
485
+ .map(_(nameToParquetField(name).fieldName, value))
452
486
case sources.GreaterThanOrEqual (name, value) if canMakeFilterOn(name, value) =>
453
- makeGtEq.lift(nameToType(name)).map(_(name, value))
487
+ makeGtEq.lift(nameToParquetField(name).fieldType)
488
+ .map(_(nameToParquetField(name).fieldName, value))
454
489
455
490
case sources.And (lhs, rhs) =>
456
491
// At here, it is not safe to just convert one side if we do not understand the
@@ -477,7 +512,8 @@ private[parquet] class ParquetFilters(
477
512
case sources.In (name, values) if canMakeFilterOn(name, values.head)
478
513
&& values.distinct.length <= pushDownInFilterThreshold =>
479
514
values.distinct.flatMap { v =>
480
- makeEq.lift(nameToType(name)).map(_(name, v))
515
+ makeEq.lift(nameToParquetField(name).fieldType)
516
+ .map(_(nameToParquetField(name).fieldName, v))
481
517
}.reduceLeftOption(FilterApi .or)
482
518
483
519
case sources.StringStartsWith (name, prefix)
0 commit comments