@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                    'Result vector from pandas_udf was not the required length'):
                df.select(raise_exception(col('id'))).collect()

-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
    def test_vectorized_udf_chained(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)
@@ -5060,6 +5049,169 @@ def test_type_annotation(self):
        df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
        self.assertEqual(df.first()[0], 0)

+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
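+        # Each UDF adds a distinct power of ten, so the expected offsets below encode which UDFs ran.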
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
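+        # e.g. f4(f3(f2(f1(v)))) adds 1 + 10 + 100 + 1000, giving v + 1111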
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed udfs in a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
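+        # f2 is a plain Python function on Column, i.e. a SQL expression rather than a UDF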
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \
+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
+

@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5639,22 @@ def dummy_pandas_udf(df):
                                                 F.col('temp0.key') == F.col('temp1.key'))
        self.assertEquals(res.count(), 5)

+    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+        import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
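+        # x.sum().sum() totals all columns: v1 sums to 45, v2 (v1+1) to 55, v3 (v1+2) to 65, so 165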
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
+        self.assertEquals(result.collect()[0]['sum'], 165)
+

@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,