@@ -585,53 +585,17 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
585
585
* Unsupported predicates are skipped.
586
586
*/
587
587
def convertFilters (table : Table , filters : Seq [Expression ]): String = {
588
- if (SQLConf .get.advancedPartitionPredicatePushdownEnabled) {
589
- convertComplexFilters(table, filters)
590
- } else {
591
- convertBasicFilters(table, filters)
592
- }
593
- }
594
-
595
-
596
- /**
597
- * An extractor that matches all binary comparison operators except null-safe equality.
598
- *
599
- * Null-safe equality is not supported by Hive metastore partition predicate pushdown
600
- */
601
- object SpecialBinaryComparison {
602
- def unapply (e : BinaryComparison ): Option [(Expression , Expression )] = e match {
603
- case _ : EqualNullSafe => None
604
- case _ => Some ((e.left, e.right))
588
/**
 * Extractor for binary comparison operators that yields the operands as a `(left, right)` pair.
 *
 * [[EqualNullSafe]] is rejected on purpose: null-safe equality is not supported by Hive
 * metastore partition predicate pushdown.
 */
object SpecialBinaryComparison {
  def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match {
    case _: EqualNullSafe => None
    case comparison => Some(comparison.left -> comparison.right)
  }
}
606
- }
607
-
608
- private def convertBasicFilters (table : Table , filters : Seq [Expression ]): String = {
609
- // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
610
- lazy val varcharKeys = table.getPartitionKeys.asScala
611
- .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME ) ||
612
- col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME ))
613
- .map(col => col.getName).toSet
614
-
615
- filters.collect {
616
- case op @ SpecialBinaryComparison (a : Attribute , Literal (v, _ : IntegralType )) =>
617
- s " ${a.name} ${op.symbol} $v"
618
- case op @ SpecialBinaryComparison (Literal (v, _ : IntegralType ), a : Attribute ) =>
619
- s " $v ${op.symbol} ${a.name}"
620
- case op @ SpecialBinaryComparison (a : Attribute , Literal (v, _ : StringType ))
621
- if ! varcharKeys.contains(a.name) =>
622
- s """ ${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}"""
623
- case op @ SpecialBinaryComparison (Literal (v, _ : StringType ), a : Attribute )
624
- if ! varcharKeys.contains(a.name) =>
625
- s """ ${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}"""
626
- }.mkString(" and " )
627
- }
628
-
629
- private def convertComplexFilters (table : Table , filters : Seq [Expression ]): String = {
630
- // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
631
- lazy val varcharKeys = table.getPartitionKeys.asScala
632
- .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME ) ||
633
- col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME ))
634
- .map(col => col.getName).toSet
635
599
636
600
object ExtractableLiteral {
637
601
def unapply (expr : Expression ): Option [String ] = expr match {
@@ -643,9 +607,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
643
607
644
608
/**
 * Extractor over a sequence of expressions. It matches only when the sequence is non-empty and
 * every element is accepted by `ExtractableLiteral`; the extracted strings keep the input order.
 */
object ExtractableLiterals {
  def unapply(exprs: Seq[Expression]): Option[Seq[String]] = {
    // Keep only the successful extractions; a shorter result means some element was rejected.
    val literals = exprs.flatMap(e => ExtractableLiteral.unapply(e))
    val allExtracted = literals.nonEmpty && literals.size == exprs.size
    if (allExtracted) Some(literals) else None
  }
}
@@ -660,40 +626,68 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
660
626
}
661
627
662
628
/**
 * Matches a non-empty set whose every value is in the domain of `valueToLiteralString`,
 * yielding the converted strings in the order produced by `values.toSeq`.
 */
def unapply(values: Set[Any]): Option[Seq[String]] = {
  // `collect` silently drops values outside the partial function's domain, so a size
  // mismatch means at least one value could not be converted.
  val literals = values.toSeq.collect(valueToLiteralString)
  if (literals.nonEmpty && literals.size == values.size) Some(literals) else None
}
669
636
}
670
637
671
- def convertInToOr (a : Attribute , values : Seq [String ]): String = {
672
- values.map(value => s " ${a.name} = $value" ).mkString(" (" , " or " , " )" )
638
/**
 * Extractor that yields an attribute's name, unless that name belongs to a partition key of
 * `table` whose Hive type starts with the varchar or char type name.
 */
object NonVarcharAttribute {
  // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
  private val varcharKeys = {
    val excludedTypePrefixes = Seq(
      serdeConstants.VARCHAR_TYPE_NAME,
      serdeConstants.CHAR_TYPE_NAME)
    table.getPartitionKeys.asScala
      .filter(key => excludedTypePrefixes.exists(prefix => key.getType.startsWith(prefix)))
      .map(_.getName)
      .toSet
  }

  def unapply(attr: Attribute): Option[String] =
    Some(attr.name).filterNot(varcharKeys.contains)
}
653
+
654
/**
 * Expands an IN-style membership test into a parenthesised chain of equality checks joined
 * by `or`, e.g. `(p = 1 or p = 2)`.
 *
 * @param name   attribute name placed on the left of every equality
 * @param values literal strings, already rendered, placed on the right
 */
def convertInToOr(name: String, values: Seq[String]): String = {
  val equalities = for (value <- values) yield s"$name = $value"
  equalities.mkString("(", " or ", ")")
}
674
657
675
- lazy val convert : PartialFunction [Expression , String ] = {
676
- case In (a : Attribute , ExtractableLiterals (values))
677
- if ! varcharKeys.contains(a.name) && values.nonEmpty =>
678
- convertInToOr(a, values)
679
- case InSet (a : Attribute , ExtractableValues (values))
680
- if ! varcharKeys.contains(a.name) && values.nonEmpty =>
681
- convertInToOr(a, values)
682
- case op @ SpecialBinaryComparison (a : Attribute , ExtractableLiteral (value))
683
- if ! varcharKeys.contains(a.name) =>
684
- s " ${a.name} ${op.symbol} $value"
685
- case op @ SpecialBinaryComparison (ExtractableLiteral (value), a : Attribute )
686
- if ! varcharKeys.contains(a.name) =>
687
- s " $value ${op.symbol} ${a.name}"
688
- case And (expr1, expr2)
689
- if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) =>
690
- (convert.lift(expr1) ++ convert.lift(expr2)).mkString(" (" , " and " , " )" )
691
- case Or (expr1, expr2)
692
- if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) =>
693
- s " ( ${convert(expr1)} or ${convert(expr2)}) "
658
// Read the flag once; it gates the In / InSet / And / Or cases below.
val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled

/**
 * Converts a single predicate into a filter string, or returns `None` when the predicate
 * is not one of the supported shapes.
 *
 * Binary comparisons are always considered; `In`, `InSet`, `And` and `Or` are considered
 * only when `useAdvanced` is set.
 */
def convert(expr: Expression): Option[String] = expr match {
  case In(NonVarcharAttribute(attrName), ExtractableLiterals(literals)) if useAdvanced =>
    Some(convertInToOr(attrName, literals))

  case InSet(NonVarcharAttribute(attrName), ExtractableValues(literals)) if useAdvanced =>
    Some(convertInToOr(attrName, literals))

  case cmp @ SpecialBinaryComparison(NonVarcharAttribute(attrName), ExtractableLiteral(lit)) =>
    Some(s"$attrName ${cmp.symbol} $lit")

  case cmp @ SpecialBinaryComparison(ExtractableLiteral(lit), NonVarcharAttribute(attrName)) =>
    Some(s"$lit ${cmp.symbol} $attrName")

  case And(lhs, rhs) if useAdvanced =>
    // Keep whichever side converts; the conjunction fails only when neither side does.
    val parts = convert(lhs).toSeq ++ convert(rhs).toSeq
    if (parts.nonEmpty) Some(parts.mkString("(", " and ", ")")) else None

  case Or(lhs, rhs) if useAdvanced =>
    // Both sides must convert, and the right side is not attempted if the left fails.
    convert(lhs).flatMap { left =>
      convert(rhs).map(right => s"($left or $right)")
    }

  case _ => None
}
695
689
696
- filters.map (convert.lift).collect { case Some (filterString) => filterString } .mkString(" and " )
690
+ filters.flatMap (convert) .mkString(" and " )
697
691
}
698
692
699
693
private def quoteStringLiteral (str : String ): String = {
0 commit comments