@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
702702 comparePlans(df1.queryExecution.optimizedPlan, df2.queryExecution.optimizedPlan)
703703 checkAnswer(df1, Row (3 ) :: Nil )
704704 }
705- }
706705
707- case object StrLenDefault extends ScalarFunction [Int ] {
708- override def inputTypes (): Array [DataType ] = Array (StringType )
709- override def resultType (): DataType = IntegerType
710- override def name (): String = " strlen_default"
706+ private case object StrLenDefault extends ScalarFunction [Int ] {
707+ override def inputTypes (): Array [DataType ] = Array (StringType )
708+ override def resultType (): DataType = IntegerType
709+ override def name (): String = " strlen_default"
711710
712- override def produceResult (input : InternalRow ): Int = {
713- val s = input.getString(0 )
714- s.length
711+ override def produceResult (input : InternalRow ): Int = {
712+ val s = input.getString(0 )
713+ s.length
714+ }
715715 }
716- }
717716
718- case object StrLenMagic extends ScalarFunction [Int ] {
719- override def inputTypes (): Array [DataType ] = Array (StringType )
720- override def resultType (): DataType = IntegerType
721- override def name (): String = " strlen_magic"
717+ case object StrLenMagic extends ScalarFunction [Int ] {
718+ override def inputTypes (): Array [DataType ] = Array (StringType )
719+ override def resultType (): DataType = IntegerType
720+ override def name (): String = " strlen_magic"
722721
723- def invoke (input : UTF8String ): Int = {
724- input.toString.length
722+ def invoke (input : UTF8String ): Int = {
723+ input.toString.length
724+ }
725725 }
726- }
727726
728- case object StrLenBadMagic extends ScalarFunction [Int ] {
729- override def inputTypes (): Array [DataType ] = Array (StringType )
730- override def resultType (): DataType = IntegerType
731- override def name (): String = " strlen_bad_magic"
727+ case object StrLenBadMagic extends ScalarFunction [Int ] {
728+ override def inputTypes (): Array [DataType ] = Array (StringType )
729+ override def resultType (): DataType = IntegerType
730+ override def name (): String = " strlen_bad_magic"
732731
733- def invoke (input : String ): Int = {
734- input.length
732+ def invoke (input : String ): Int = {
733+ input.length
734+ }
735735 }
736- }
737736
738- case object StrLenBadMagicWithDefault extends ScalarFunction [Int ] {
739- override def inputTypes (): Array [DataType ] = Array (StringType )
740- override def resultType (): DataType = IntegerType
741- override def name (): String = " strlen_bad_magic"
737+ case object StrLenBadMagicWithDefault extends ScalarFunction [Int ] {
738+ override def inputTypes (): Array [DataType ] = Array (StringType )
739+ override def resultType (): DataType = IntegerType
740+ override def name (): String = " strlen_bad_magic"
741+
742+ def invoke (input : String ): Int = {
743+ input.length
744+ }
742745
743- def invoke (input : String ): Int = {
744- input.length
746+ override def produceResult (input : InternalRow ): Int = {
747+ val s = input.getString(0 )
748+ s.length
749+ }
745750 }
746751
747- override def produceResult (input : InternalRow ): Int = {
748- val s = input.getString(0 )
749- s.length
752+ private case object StrLenNoImpl extends ScalarFunction [Int ] {
753+ override def inputTypes (): Array [DataType ] = Array (StringType )
754+ override def resultType (): DataType = IntegerType
755+ override def name (): String = " strlen_noimpl"
750756 }
751- }
752757
753- case object StrLenNoImpl extends ScalarFunction [Int ] {
754- override def inputTypes (): Array [DataType ] = Array (StringType )
755- override def resultType (): DataType = IntegerType
756- override def name (): String = " strlen_noimpl"
757- }
758+ // input type doesn't match arguments accepted by `UnboundFunction.bind`
759+ private case object StrLenBadInputTypes extends ScalarFunction [Int ] {
760+ override def inputTypes (): Array [DataType ] = Array (StringType , IntegerType )
761+ override def resultType (): DataType = IntegerType
762+ override def name (): String = " strlen_bad_input_types"
763+ }
758764
759- // input type doesn't match arguments accepted by `UnboundFunction.bind`
760- case object StrLenBadInputTypes extends ScalarFunction [Int ] {
761- override def inputTypes (): Array [DataType ] = Array (StringType , IntegerType )
762- override def resultType (): DataType = IntegerType
763- override def name (): String = " strlen_bad_input_types"
764- }
765+ private case object BadBoundFunction extends BoundFunction {
766+ override def inputTypes (): Array [DataType ] = Array (StringType )
767+ override def resultType (): DataType = IntegerType
768+ override def name (): String = " bad_bound_func"
769+ }
765770
766- case object BadBoundFunction extends BoundFunction {
767- override def inputTypes (): Array [DataType ] = Array (StringType )
768- override def resultType (): DataType = IntegerType
769- override def name (): String = " bad_bound_func"
770- }
771+ object UnboundDecimalAverage extends UnboundFunction {
772+ override def name (): String = " decimal_avg"
771773
772- object UnboundDecimalAverage extends UnboundFunction {
773- override def name (): String = " decimal_avg"
774+ override def bind (inputType : StructType ): BoundFunction = {
775+ if (inputType.fields.length > 1 ) {
776+ throw new UnsupportedOperationException (" Too many arguments" )
777+ }
774778
775- override def bind (inputType : StructType ): BoundFunction = {
776- if (inputType.fields.length > 1 ) {
777- throw new UnsupportedOperationException (" Too many arguments" )
779+ // put interval type here for testing purpose
780+ inputType.fields(0 ).dataType match {
781+ case _ : NumericType | _ : DayTimeIntervalType => DecimalAverage
782+ case dataType =>
783+ throw new UnsupportedOperationException (s " Unsupported input type: $dataType" )
784+ }
778785 }
779786
780- // put interval type here for testing purpose
781- inputType.fields(0 ).dataType match {
782- case _ : NumericType | _ : DayTimeIntervalType => DecimalAverage
783- case dataType =>
784- throw new UnsupportedOperationException (s " Unsupported input type: $dataType" )
785- }
787+ override def description (): String =
788+ " decimal_avg: produces an average using decimal division"
786789 }
787790
788- override def description (): String =
789- " decimal_avg: produces an average using decimal division"
790- }
791-
792- object DecimalAverage extends AggregateFunction [(Decimal , Int ), Decimal ] {
793- override def name (): String = " decimal_avg"
794- override def inputTypes (): Array [DataType ] = Array (DecimalType .SYSTEM_DEFAULT )
795- override def resultType (): DataType = DecimalType .SYSTEM_DEFAULT
791+ object DecimalAverage extends AggregateFunction [(Decimal , Int ), Decimal ] {
792+ override def name (): String = " decimal_avg"
793+ override def inputTypes (): Array [DataType ] = Array (DecimalType .SYSTEM_DEFAULT )
794+ override def resultType (): DataType = DecimalType .SYSTEM_DEFAULT
796795
797- override def newAggregationState (): (Decimal , Int ) = (Decimal .ZERO , 0 )
796+ override def newAggregationState (): (Decimal , Int ) = (Decimal .ZERO , 0 )
798797
799- override def update (state : (Decimal , Int ), input : InternalRow ): (Decimal , Int ) = {
800- if (input.isNullAt(0 )) {
801- state
802- } else {
803- val l = input.getDecimal(0 , DecimalType .SYSTEM_DEFAULT .precision,
804- DecimalType .SYSTEM_DEFAULT .scale)
805- state match {
806- case (_, d) if d == 0 =>
807- (l, 1 )
808- case (total, count) =>
809- (total + l, count + 1 )
798+ override def update (state : (Decimal , Int ), input : InternalRow ): (Decimal , Int ) = {
799+ if (input.isNullAt(0 )) {
800+ state
801+ } else {
802+ val l = input.getDecimal(0 , DecimalType .SYSTEM_DEFAULT .precision,
803+ DecimalType .SYSTEM_DEFAULT .scale)
804+ state match {
805+ case (_, d) if d == 0 =>
806+ (l, 1 )
807+ case (total, count) =>
808+ (total + l, count + 1 )
809+ }
810810 }
811811 }
812- }
813812
814- override def merge (leftState : (Decimal , Int ), rightState : (Decimal , Int )): (Decimal , Int ) = {
815- (leftState._1 + rightState._1, leftState._2 + rightState._2)
816- }
813+ override def merge (leftState : (Decimal , Int ), rightState : (Decimal , Int )): (Decimal , Int ) = {
814+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
815+ }
817816
818- override def produceResult (state : (Decimal , Int )): Decimal = state._1 / Decimal (state._2)
819- }
817+ override def produceResult (state : (Decimal , Int )): Decimal = state._1 / Decimal (state._2)
818+ }
820819
821- object NoImplAverage extends UnboundFunction {
822- override def name (): String = " no_impl_avg"
823- override def description (): String = name()
820+ object NoImplAverage extends UnboundFunction {
821+ override def name (): String = " no_impl_avg"
822+ override def description (): String = name()
824823
825- override def bind (inputType : StructType ): BoundFunction = {
826- throw SparkUnsupportedOperationException ()
824+ override def bind (inputType : StructType ): BoundFunction = {
825+ throw SparkUnsupportedOperationException ()
826+ }
827827 }
828828}
0 commit comments