Skip to content

Commit 32c8a3d

Browse files
mgaido91HyukjinKwon
authored andcommitted
[MINOR] Avoid code duplication for nullable in Higher Order function
## What changes were proposed in this pull request? Most of `HigherOrderFunction`s have the same `nullable` definition, ie. they are nullable when one of their arguments is nullable. The PR refactors it in order to avoid code duplication. ## How was this patch tested? NA Closes apache#22243 from mgaido91/MINOR_nullable_hof. Authored-by: Marco Gaido <[email protected]> Signed-off-by: hyukjinkwon <[email protected]>
1 parent bbbf814 commit 32c8a3d

File tree

1 file changed

+2
-16
lines changed

1 file changed

+2
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ object LambdaFunction {
9090
*/
9191
trait HigherOrderFunction extends Expression with ExpectsInputTypes {
9292

93+
override def nullable: Boolean = arguments.exists(_.nullable)
94+
9395
override def children: Seq[Expression] = arguments ++ functions
9496

9597
/**
@@ -217,8 +219,6 @@ case class ArrayTransform(
217219
function: Expression)
218220
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
219221

220-
override def nullable: Boolean = argument.nullable
221-
222222
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
223223

224224
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
@@ -287,8 +287,6 @@ case class MapFilter(
287287
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
288288
}
289289

290-
override def nullable: Boolean = argument.nullable
291-
292290
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
293291
val m = argumentValue.asInstanceOf[MapData]
294292
val f = functionForEval
@@ -328,8 +326,6 @@ case class ArrayFilter(
328326
function: Expression)
329327
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
330328

331-
override def nullable: Boolean = argument.nullable
332-
333329
override def dataType: DataType = argument.dataType
334330

335331
override def functionType: AbstractDataType = BooleanType
@@ -375,8 +371,6 @@ case class ArrayExists(
375371
function: Expression)
376372
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
377373

378-
override def nullable: Boolean = argument.nullable
379-
380374
override def dataType: DataType = BooleanType
381375

382376
override def functionType: AbstractDataType = BooleanType
@@ -516,8 +510,6 @@ case class TransformKeys(
516510
function: Expression)
517511
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
518512

519-
override def nullable: Boolean = argument.nullable
520-
521513
@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType
522514

523515
override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull)
@@ -568,8 +560,6 @@ case class TransformValues(
568560
function: Expression)
569561
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
570562

571-
override def nullable: Boolean = argument.nullable
572-
573563
@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType
574564

575565
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)
@@ -638,8 +628,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
638628

639629
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
640630

641-
override def nullable: Boolean = left.nullable || right.nullable
642-
643631
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)
644632

645633
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = {
@@ -810,8 +798,6 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
810798

811799
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
812800

813-
override def nullable: Boolean = left.nullable || right.nullable
814-
815801
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
816802

817803
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = {

0 commit comments

Comments
 (0)