Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/docs/source/migration_guide/pyspark_upgrade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Upgrading from PySpark 4.0 to 4.1

* In Spark 4.1, Arrow-optimized Python UDF supports UDT input / output instead of falling back to the regular UDF. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true``.

* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.


Upgrading from PySpark 3.5 to 4.0
---------------------------------

Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,11 @@
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
]
},
"UDTF_ARROW_TYPE_CONVERSION_ERROR": {
"message": [
"Cannot convert the output value of the input '<data>' with type '<schema>' to the specified return type of the column: '<arrow_schema>'. Please check if the data types match and try again."
]
},
"UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD": {
"message": [
"Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function implements the 'analyze' method, but its constructor has more than two arguments (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, or one 'self' argument plus another argument for the result of the 'analyze' method, and try the query again."
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool)
if isinstance(item, dict):
for i, col in enumerate(column_names):
pylist[i].append(column_convs[i](item.get(col)))
elif item is None:
for i, col in enumerate(column_names):
pylist[i].append(None)
else:
if len(item) != len(column_names):
raise PySparkValueError(
Expand Down
20 changes: 17 additions & 3 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,14 @@ def wrap_and_init_stream():
assert isinstance(batch, pa.RecordBatch)

# Wrap the root struct
struct = pa.StructArray.from_arrays(
batch.columns, fields=pa.struct(list(batch.schema))
)
if len(batch.columns) == 0:
# When batch has no column, it should still create
# an empty batch with the number of rows set.
struct = pa.array([{}] * batch.num_rows)
else:
struct = pa.StructArray.from_arrays(
batch.columns, fields=pa.struct(list(batch.schema))
)
batch = pa.RecordBatch.from_arrays([struct], ["_0"])

# Write the first record batch with initialization.
Expand All @@ -181,6 +186,15 @@ def wrap_and_init_stream():
return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)


class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
    """
    Serializer for Arrow-optimized Python UDTFs.

    Shares the write path with :class:`ArrowStreamUDFSerializer`, but on the
    read path it does not flatten incoming batches: record batches are loaded
    as-is.
    """

    def load_stream(self, stream):
        # Deliberately bypass ArrowStreamUDFSerializer.load_stream (which
        # flattens) by delegating straight to the grandparent implementation.
        return ArrowStreamSerializer.load_stream(self, stream)


class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
Serializes pyarrow.RecordBatch data with Arrow streaming format.
Expand Down
40 changes: 39 additions & 1 deletion python/pyspark/sql/tests/connect/test_parity_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import unittest

from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
from pyspark.sql.tests.test_udtf import (
BaseUDTFTestsMixin,
UDTFArrowTestsMixin,
LegacyUDTFArrowTestsMixin,
)
from pyspark.testing.connectutils import ReusedConnectTestCase

if should_test_connect:
Expand Down Expand Up @@ -88,16 +92,50 @@ def _add_file(self, path):
self.spark.addArtifacts(path, file=True)


class LegacyArrowUDTFParityTests(LegacyUDTFArrowTestsMixin, UDTFParityTests):
    """Legacy Arrow UDTF tests (pandas conversion enabled) over Spark Connect."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Turn on Arrow-optimized UDTFs and opt back into the legacy
        # pandas-based conversion path for this test class.
        for key, value in (
            ("spark.sql.execution.pythonUDTF.arrow.enabled", "true"),
            ("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true"),
        ):
            cls.spark.conf.set(key, value)

    @classmethod
    def tearDownClass(cls):
        # Restore both confs before the parent class tears down the session.
        try:
            for key in (
                "spark.sql.execution.pythonUDTF.arrow.enabled",
                "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled",
            ):
                cls.spark.conf.unset(key)
        finally:
            super().tearDownClass()

    def test_udtf_access_spark_session_connect(self):
        # Collecting a driver-side DataFrame from inside the UDTF must fail
        # with NO_ACTIVE_SESSION.
        df = self.spark.range(10)

        @udtf(returnType="x: int")
        class TestUDTF:
            def eval(self):
                df.collect()
                yield 1,

        with self.assertRaisesRegex(PythonException, "NO_ACTIVE_SESSION"):
            TestUDTF().collect()


class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
    """Arrow UDTF tests (legacy pandas conversion disabled) over Spark Connect."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Arrow-optimized UDTFs on; the legacy pandas conversion explicitly off.
        for key, value in (
            ("spark.sql.execution.pythonUDTF.arrow.enabled", "true"),
            ("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"),
        ):
            cls.spark.conf.set(key, value)

    @classmethod
    def tearDownClass(cls):
        # Unset both confs even if one unset raises, then defer to the parent.
        try:
            for key in (
                "spark.sql.execution.pythonUDTF.arrow.enabled",
                "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled",
            ):
                cls.spark.conf.unset(key)
        finally:
            super().tearDownClass()

Expand Down
Loading