
Commit a5849ad

BryanCutler authored and HyukjinKwon committed
[SPARK-24324][PYTHON] Pandas Grouped Map UDF should assign result columns by name
## What changes were proposed in this pull request?

Currently, a `pandas_udf` of type `PandasUDFType.GROUPED_MAP` assigns the resulting columns based on the index of the returned pandas.DataFrame. If a new DataFrame is constructed from a dict and returned, the order of the columns can be arbitrary and differ from the schema defined for the UDF. If the column types still match, no error is raised and the user sees column names and column data mixed up.

This change first tries to assign columns using the return type's field names. If a KeyError occurs, the column index is checked: if it is string based, the error is raised, since it is most likely a naming mistake; otherwise it falls back to assigning columns by position and raises a TypeError if the field types do not match.

## How was this patch tested?

Added a test that returns a new DataFrame with a column order different from the schema.

Author: Bryan Cutler <[email protected]>

Closes apache#21427 from BryanCutler/arrow-grouped-map-mixesup-cols-SPARK-24324.
1 parent 98f363b commit a5849ad
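
For illustration, here is a minimal sketch of the scenario described in the commit message. It assumes an existing SparkSession named `spark` and a toy DataFrame; none of this code is part of the patch itself.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("id long, v double, v2 double", PandasUDFType.GROUPED_MAP)
def add_v2(pdf):
    # Built from a dict, so the column order need not match the schema order.
    # With this change the result columns are matched to the schema by name,
    # so 'v2' always lands in the 'v2' field instead of whichever field
    # happens to share its position.
    return pd.DataFrame({"v2": pdf.v * 2, "id": pdf.id, "v": pdf.v})

df.groupby("id").apply(add_v2).show()
```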

File tree

11 files changed (+226, -55 lines)


docs/sql-programming-guide.md

Lines changed: 4 additions & 8 deletions
@@ -1752,14 +1752,10 @@ To use `groupBy().apply()`, the user needs to define the following:
 * A Python function that defines the computation for each group.
 * A `StructType` object or a string that defines the schema of the output `DataFrame`.
 
-The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position,
-not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their
-position matches the corresponding field in the schema.
-
-Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column
-can differ from the order that it was placed in the dictionary. It is recommended in this case to
-explicitly define the column order using the `columns` keyword, e.g.
-`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`.
+The column labels of the returned `pandas.DataFrame` must either match the field names in the
+defined output schema if specified as strings, or match the field data types by position if not
+strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame)
+on how to label columns when constructing a `pandas.DataFrame`.
 
 Note that all data for a group will be loaded into memory before the function is applied. This can
 lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for
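
As a hedged companion example for the positional fallback mentioned above (toy schema, not from the patch): when the returned columns are not labeled with strings, fields are matched by position instead of by name.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("id long, v long", PandasUDFType.GROUPED_MAP)
def passthrough(pdf):
    # The returned DataFrame has integer column labels 0 and 1, so the
    # columns are matched to the schema by position: 0 -> id, 1 -> v.
    return pd.DataFrame(list(zip(pdf.id, pdf.v)), dtype='int64')
```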

python/pyspark/sql/functions.py

Lines changed: 4 additions & 3 deletions
@@ -2584,9 +2584,10 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
     The returnType should be a :class:`StructType` describing the schema of the returned
-    `pandas.DataFrame`.
-    The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be
-    indexed so that their position matches the corresponding field in the schema.
+    `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
+    the field names in the defined returnType schema if specified as strings, or match the
+    field data types by position if not strings, e.g. integer indices.
+    The length of the returned `pandas.DataFrame` can be arbitrary.
 
     Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
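
To make the naming contract concrete, a short sketch of the failure mode; it mirrors the `column_name_typo` case added to `tests.py` below (with `pd`, `pandas_udf` and `PandasUDFType` imported as above).

```python
@pandas_udf("id long, v int", PandasUDFType.GROUPED_MAP)
def column_name_typo(pdf):
    # 'iid' is a string label that matches no schema field, so applying this
    # UDF fails with an error containing "KeyError: 'id'" instead of silently
    # putting the wrong data into the 'id' column.
    return pd.DataFrame({'iid': pdf.id, 'v': pdf.v})
```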

python/pyspark/sql/tests.py

Lines changed: 103 additions & 1 deletion
@@ -4742,7 +4742,6 @@ def test_vectorized_udf_chained(self):
 
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10)
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                 NotImplementedError,
@@ -5327,6 +5326,109 @@ def foo3(key, pdf):
         expected4 = udf3.func((), pdf)
         self.assertPandasEqual(expected4, result4)
 
+    def test_column_order(self):
+        from collections import OrderedDict
+        import pandas as pd
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        # Helper function to set column names from a list
+        def rename_pdf(pdf, names):
+            pdf.rename(columns={old: new for old, new in
+                                zip(pd_result.columns, names)}, inplace=True)
+
+        df = self.data
+        grouped_df = df.groupby('id')
+        grouped_pdf = df.toPandas().groupby('id')
+
+        # Function returns a pdf with required column names, but order could be arbitrary using dict
+        def change_col_order(pdf):
+            # Constructing a DataFrame from a dict should result in the same order,
+            # but use from_items to ensure the pdf column order is different than schema
+            return pd.DataFrame.from_items([
+                ('id', pdf.id),
+                ('u', pdf.v * 2),
+                ('v', pdf.v)])
+
+        ordered_udf = pandas_udf(
+            change_col_order,
+            'id long, v int, u int',
+            PandasUDFType.GROUPED_MAP
+        )
+
+        # The UDF result should assign columns by name from the pdf
+        result = grouped_df.apply(ordered_udf).sort('id', 'v')\
+            .select('id', 'u', 'v').toPandas()
+        pd_result = grouped_pdf.apply(change_col_order)
+        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected, result)
+
+        # Function returns a pdf with positional columns, indexed by range
+        def range_col_order(pdf):
+            # Create a DataFrame with positional columns, fix types to long
+            return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64')
+
+        range_udf = pandas_udf(
+            range_col_order,
+            'id long, u long, v long',
+            PandasUDFType.GROUPED_MAP
+        )
+
+        # The UDF result uses positional columns from the pdf
+        result = grouped_df.apply(range_udf).sort('id', 'v') \
+            .select('id', 'u', 'v').toPandas()
+        pd_result = grouped_pdf.apply(range_col_order)
+        rename_pdf(pd_result, ['id', 'u', 'v'])
+        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected, result)
+
+        # Function returns a pdf with columns indexed with integers
+        def int_index(pdf):
+            return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)]))
+
+        int_index_udf = pandas_udf(
+            int_index,
+            'id long, u int, v int',
+            PandasUDFType.GROUPED_MAP
+        )
+
+        # The UDF result should assign columns by position of integer index
+        result = grouped_df.apply(int_index_udf).sort('id', 'v') \
+            .select('id', 'u', 'v').toPandas()
+        pd_result = grouped_pdf.apply(int_index)
+        rename_pdf(pd_result, ['id', 'u', 'v'])
+        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected, result)
+
+        @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
+        def column_name_typo(pdf):
+            return pd.DataFrame({'iid': pdf.id, 'v': pdf.v})
+
+        @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
+        def invalid_positional_types(pdf):
+            return pd.DataFrame([(u'a', 1.2)])
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
+                grouped_df.apply(column_name_typo).collect()
+            with self.assertRaisesRegexp(Exception, "No cast implemented"):
+                grouped_df.apply(invalid_positional_types).collect()
+
+    def test_positional_assignment_conf(self):
+        import pandas as pd
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}):
+
+            @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
+            def foo(_):
+                return pd.DataFrame([('hi', 1)], columns=['x', 'y'])
+
+            df = self.data
+            result = df.groupBy('id').apply(foo).select('a', 'b').collect()
+            for r in result:
+                self.assertEqual(r.a, 'hi')
+                self.assertEqual(r.b, 1)
+
 
     @unittest.skipIf(
         not _have_pandas or not _have_pyarrow,

python/pyspark/worker.py

Lines changed: 38 additions & 17 deletions
@@ -38,6 +38,9 @@
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
+if sys.version >= '3':
+    basestring = str
+
 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
@@ -92,7 +95,10 @@ def verify_result_length(*a):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type, argspec):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    assign_cols_by_pos = runner_conf.get(
+        "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False)
+
     def wrapped(key_series, value_series):
         import pandas as pd
 
@@ -110,9 +116,13 @@ def wrapped(key_series, value_series):
                 "Number of columns of the returned pandas.DataFrame "
                 "doesn't match specified schema. "
                 "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
-        arrow_return_types = (to_arrow_type(field.dataType) for field in return_type)
-        return [(result[result.columns[i]], arrow_type)
-                for i, arrow_type in enumerate(arrow_return_types)]
+
+        # Assign result columns by schema name if user labeled with strings, else use position
+        if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns):
+            return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type]
+        else:
+            return [(result[result.columns[i]], to_arrow_type(field.dataType))
+                    for i, field in enumerate(return_type)]
 
     return wrapped
 
@@ -143,7 +153,7 @@ def wrapped(*series):
     return lambda *a: (wrapped(*a), arrow_return_type)
 
 
-def read_single_udf(pickleSer, infile, eval_type):
+def read_single_udf(pickleSer, infile, eval_type, runner_conf):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
@@ -163,7 +173,7 @@ def read_single_udf(pickleSer, infile, eval_type):
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = _get_argspec(row_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -175,6 +185,26 @@ def read_single_udf(pickleSer, infile, eval_type):
 
 
 def read_udfs(pickleSer, infile, eval_type):
+    runner_conf = {}
+
+    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
+
+        # Load conf used for pandas_udf evaluation
+        num_conf = read_int(infile)
+        for i in range(num_conf):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            runner_conf[k] = v
+
+        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
+        timezone = runner_conf.get("spark.sql.session.timeZone", None)
+        ser = ArrowStreamPandasSerializer(timezone)
+    else:
+        ser = BatchedSerializer(PickleSerializer(), 100)
+
     num_udfs = read_int(infile)
     udfs = {}
     call_udf = []
@@ -189,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
        udfs['f'] = udf
        split_offset = arg_offsets[0] + 1
        arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
@@ -201,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type):
         # In the special case of a single UDF this will return a single result rather
         # than a tuple of results; this is the format that the JVM side expects.
         for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
             udfs['f%d' % i] = udf
             args = ["a[%d]" % o for o in arg_offsets]
             call_udf.append("f%d(%s)" % (i, ", ".join(args)))
@@ -210,15 +240,6 @@ def read_udfs(pickleSer, infile, eval_type):
     mapper = eval(mapper_str, udfs)
     func = lambda _, it: map(mapper, it)
 
-    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
-                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
-        timezone = utf8_deserializer.loads(infile)
-        ser = ArrowStreamPandasSerializer(timezone)
-    else:
-        ser = BatchedSerializer(PickleSerializer(), 100)
-
     # profiling is not supported for UDF
     return func, None, ser, ser
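
The new branch in `wrap_grouped_map_pandas_udf` can be restated as a standalone, Spark-free sketch (simplified: no Arrow types, no error handling, Python 3 only).

```python
import pandas as pd

def assign_result_columns(result, field_names, assign_cols_by_pos=False):
    # Use the schema field names when the result has string column labels and
    # positional assignment is not forced; otherwise fall back to position.
    if not assign_cols_by_pos and any(isinstance(name, str) for name in result.columns):
        return [result[name] for name in field_names]
    return [result[result.columns[i]] for i in range(len(field_names))]

# Dict-built frame with string labels: columns are picked by name.
by_name = assign_result_columns(pd.DataFrame({'v': [1, 2], 'id': [10, 20]}), ['id', 'v'])
# Integer-labeled frame: columns are picked by position.
by_pos = assign_result_columns(pd.DataFrame([[10, 1], [20, 2]]), ['id', 'v'])
```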

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions
@@ -1161,6 +1161,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION =
+    buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition")
+      .internal()
+      .doc("When true, a grouped map Pandas UDF will assign columns from the returned " +
+        "Pandas DataFrame based on position, regardless of column label type. When false, " +
+        "columns will be looked up by name if labeled with a string and fallback to use " +
+        "position if not.")
+      .booleanConf
+      .createWithDefault(false)
+
   val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
     .internal()
     .doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -1647,6 +1657,9 @@ class SQLConf extends Serializable with Logging {
 
   def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE)
 
+  def pandasGroupedMapAssignColumnssByPosition: Boolean =
+    getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION)
+
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
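
As a usage sketch of the new internal flag (assuming a SparkSession named `spark`), the old positional behavior can be restored at runtime; the key is the one defined above.

```python
# Force positional assignment of result columns for grouped map Pandas UDFs,
# regardless of how the returned DataFrame's columns are labeled.
spark.conf.set("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", "true")
```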

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala

Lines changed: 16 additions & 0 deletions
@@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 object ArrowUtils {
@@ -120,4 +121,19 @@ object ArrowUtils {
       StructField(field.getName, dt, field.isNullable)
     })
   }
+
+  /** Return Map with conf settings to be used in ArrowPythonRunner */
+  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+    val timeZoneConf = if (conf.pandasRespectSessionTimeZone) {
+      Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
+    } else {
+      Nil
+    }
+    val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) {
+      Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true")
+    } else {
+      Nil
+    }
+    Map(timeZoneConf ++ pandasColsByPosition: _*)
+  }
 }
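
For orientation, the map built by `getPythonRunnerConfMap` is what the Python worker later reads back as `runner_conf`. A hedged sketch of its contents when both settings are enabled (the timezone value is only an example):

```python
runner_conf = {
    # Present only when pandasRespectSessionTimeZone is true.
    "spark.sql.session.timeZone": "America/Los_Angeles",
    # Present only when the positional-assignment conf is enabled.
    "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": "true",
}
```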

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala

Lines changed: 10 additions & 5 deletions
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -81,7 +82,7 @@ case class AggregateInPandasExec(
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
 
@@ -135,10 +136,14 @@ case class AggregateInPandasExec(
       }
 
       val columnarBatchIter = new ArrowPythonRunner(
-        pyFuncs, bufferSize, reuseWorker,
-        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema,
-        sessionLocalTimeZone, pandasRespectSessionTimeZone)
-        .compute(projectedRowIter, context.partitionId(), context)
+        pyFuncs,
+        bufferSize,
+        reuseWorker,
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+        argOffsets,
+        aggInputSchema,
+        sessionLocalTimeZone,
+        pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
 
       val joinedAttributes =
         groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 10 additions & 5 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
@@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
 
     val columnarBatchIter = new ArrowPythonRunner(
-      funcs, bufferSize, reuseWorker,
-      PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema,
-      sessionLocalTimeZone, pandasRespectSessionTimeZone)
-      .compute(batchIter, context.partitionId(), context)
+      funcs,
+      bufferSize,
+      reuseWorker,
+      PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+      argOffsets,
+      schema,
+      sessionLocalTimeZone,
+      pythonRunnerConf).compute(batchIter, context.partitionId(), context)
 
     new Iterator[InternalRow] {
