103 changes: 35 additions & 68 deletions python/pyspark/sql/conversion.py
@@ -21,13 +21,15 @@
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload

import pyspark
from pyspark.errors import PySparkValueError
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.pandas.types import (
_dedup_names,
_deduplicate_field_names,
_create_converter_from_pandas,
_create_converter_to_pandas,
to_arrow_schema,
_deduplicate_field_names,
_dedup_names,
from_arrow_schema,
to_arrow_schema,
to_arrow_type,
)
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
from pyspark.sql.types import (
@@ -228,7 +230,6 @@ def convert(
assign_cols_by_name: bool = False,
int_to_decimal_coercion_enabled: bool = False,
ignore_unexpected_complex_type_values: bool = False,
is_udtf: bool = False,
) -> "pa.RecordBatch":
"""
Convert a pandas DataFrame or list of Series/DataFrames to an Arrow RecordBatch.
@@ -255,14 +256,6 @@
Whether to enable int to decimal coercion (default False)
ignore_unexpected_complex_type_values : bool
Whether to ignore unexpected complex type values in converter (default False)
is_udtf : bool
Whether this conversion is for a UDTF. UDTFs use broader Arrow exception
handling to allow more type coercions (e.g., struct field casting via
ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
# TODO(SPARK-55502): Unify UDTF and regular UDF conversion paths to
# eliminate the is_udtf flag.
Regular UDFs only catch ArrowInvalid to preserve legacy behavior where
e.g. string→decimal must raise an error. (default False)

Returns
-------
@@ -271,9 +264,6 @@
import pyarrow as pa
import pandas as pd

from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas

# Handle empty schema (0 columns)
# Use dummy column + select([]) to preserve row count (PyArrow limitation workaround)
if len(schema.fields) == 0:
@@ -318,7 +308,6 @@ def convert_column(
assign_cols_by_name=assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
is_udtf=is_udtf,
)
# Wrap the nested RecordBatch as a single StructArray column
return ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
@@ -343,60 +332,38 @@ def convert_column(

mask = None if hasattr(series.array, "__arrow_array__") else series.isnull()

if is_udtf:
# UDTF path: broad ArrowException catch so that both ArrowInvalid
# AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
# Unified conversion path: broad ArrowException catch so that both ArrowInvalid
# AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
try:
try:
try:
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=safecheck
)
except pa.lib.ArrowException: # broad: includes ArrowTypeError
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=safecheck
)
raise
except pa.lib.ArrowException: # convert any Arrow error to user-friendly message
raise PySparkRuntimeError(
errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
messageParameters={
"col_name": field_name,
"col_type": str(series.dtype),
"arrow_type": str(arrow_type),
},
) from None
else:
# UDF path: only ArrowInvalid triggers the cast fallback.
# ArrowTypeError (e.g. string→decimal) must NOT be silently cast.
try:
try:
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=safecheck
)
except pa.lib.ArrowInvalid: # narrow: skip ArrowTypeError
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=safecheck
)
raise
except TypeError as e: # includes pa.lib.ArrowTypeError
raise PySparkTypeError(
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
) from e
except ValueError as e: # includes pa.lib.ArrowInvalid
error_msg = (
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=safecheck
)
if safecheck:
error_msg += (
" It can be caused by overflows or other unsafe conversions "
"warned by Arrow. Arrow safe type check can be disabled by using "
"SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
except pa.lib.ArrowException: # broad: includes ArrowTypeError and ArrowInvalid
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=safecheck
)
raise PySparkValueError(error_msg) from e
raise
except (TypeError, pa.lib.ArrowTypeError) as e:
# ArrowTypeError is a subclass of TypeError
raise PySparkTypeError(
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
) from e
except (ValueError, pa.lib.ArrowInvalid) as e:
# ArrowInvalid is a subclass of ValueError
error_msg = (
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
)
if safecheck:
error_msg += (
" It can be caused by overflows or other unsafe conversions "
"warned by Arrow. Arrow safe type check can be disabled by using "
"SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise PySparkValueError(error_msg) from e

arrays = [convert_column(col, field) for col, field in zip(columns, schema.fields)]
return pa.RecordBatch.from_arrays(arrays, schema.names)
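
For reviewers, here is a minimal standalone sketch (outside Spark, assuming only pandas and pyarrow are installed) of the control flow the unified `convert_column` now follows. It relies on the documented PyArrow exception hierarchy — `pa.lib.ArrowTypeError` subclasses `TypeError` and `pa.lib.ArrowInvalid` subclasses `ValueError` — so the two trailing except clauses split Arrow failures into type-style and value-style errors after the optional cast fallback. Names such as `convert_column_sketch` are placeholders, and plain `TypeError`/`ValueError` stand in for `PySparkTypeError`/`PySparkValueError`.

import pandas as pd
import pyarrow as pa

def convert_column_sketch(series, arrow_type, field_name,
                          safecheck=True, arrow_cast=True):
    # Sketch only: the real convert_column skips the mask for Arrow-backed
    # pandas extension arrays; here we always build it from isnull().
    mask = series.isnull()
    try:
        try:
            # Strict conversion to the target Arrow type.
            return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck)
        except pa.lib.ArrowException:  # broad: covers ArrowInvalid and ArrowTypeError
            if arrow_cast:
                # Fallback: infer the type first, then cast to the target type.
                return pa.Array.from_pandas(series, mask=mask).cast(
                    target_type=arrow_type, safe=safecheck
                )
            raise
    except (TypeError, pa.lib.ArrowTypeError) as e:  # ArrowTypeError subclasses TypeError
        raise TypeError(
            f"Exception thrown when converting pandas.Series ({series.dtype}) "
            f"with name '{field_name}' to Arrow Array ({arrow_type})."
        ) from e
    except (ValueError, pa.lib.ArrowInvalid) as e:  # ArrowInvalid subclasses ValueError
        raise ValueError(
            f"Exception thrown when converting pandas.Series ({series.dtype}) "
            f"with name '{field_name}' to Arrow Array ({arrow_type})."
        ) from e

# Example: "not_a_number" cannot be parsed as double, so the cast fallback
# also fails and the ValueError branch fires.
# convert_column_sketch(pd.Series(["not_a_number"]), pa.float64(), "val")
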
6 changes: 0 additions & 6 deletions python/pyspark/sql/pandas/serializers.py
@@ -513,7 +513,6 @@ def __init__(
int_to_decimal_coercion_enabled: bool = False,
prefers_large_types: bool = False,
ignore_unexpected_complex_type_values: bool = False,
is_udtf: bool = False,
):
super().__init__(
timezone,
@@ -528,7 +527,6 @@
)
self._assign_cols_by_name = assign_cols_by_name
self._ignore_unexpected_complex_type_values = ignore_unexpected_complex_type_values
self._is_udtf = is_udtf

def dump_stream(self, iterator, stream):
"""
@@ -567,7 +565,6 @@ def create_batch(
assign_cols_by_name=self._assign_cols_by_name,
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=self._ignore_unexpected_complex_type_values,
is_udtf=self._is_udtf,
)

batches = self._write_stream_start(
@@ -767,9 +764,6 @@ def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enab
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
# UDTF-specific: ignore unexpected complex type values in converter
ignore_unexpected_complex_type_values=True,
# UDTF-specific: enables broader Arrow exception handling and
# converts errors to UDTF_ARROW_TYPE_CAST_ERROR
is_udtf=True,
)

def __repr__(self):
28 changes: 19 additions & 9 deletions python/pyspark/sql/tests/test_conversion.py
@@ -295,19 +295,26 @@ def test_convert_error_messages(self):
with self.assertRaises((PySparkValueError, PySparkTypeError)) as ctx:
PandasToArrowConversion.convert(data, schema)
# Error message should reference the schema field name, not the positional index
self.assertIn("age", str(ctx.exception))
error_msg = str(ctx.exception)
self.assertIn("Exception thrown when converting pandas.Series", error_msg)
self.assertIn("with name 'age'", error_msg)
self.assertIn("to Arrow Array", error_msg)

def test_convert_is_udtf(self):
"""Test is_udtf=True produces PySparkRuntimeError with UDTF_ARROW_TYPE_CAST_ERROR."""
def test_convert_broad_exception_handling(self):
"""Test unified conversion uses broad exception handling for better type coercion."""
import pandas as pd

schema = StructType([StructField("val", DoubleType())])
data = [pd.Series(["not_a_number", "bad"])]

# ValueError path (string -> double)
with self.assertRaises(PySparkRuntimeError) as ctx:
PandasToArrowConversion.convert(data, schema, is_udtf=True)
self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
with self.assertRaises(PySparkValueError) as ctx:
PandasToArrowConversion.convert(data, schema)
error_msg = str(ctx.exception)
self.assertIn("Exception thrown when converting pandas.Series", error_msg)
self.assertIn("with name 'val'", error_msg)
self.assertIn("to Arrow Array", error_msg)
self.assertIn("double", error_msg)

# TypeError path (int -> struct): ArrowTypeError inherits from TypeError.
# ignore_unexpected_complex_type_values=True lets the bad value pass through
@@ -316,14 +323,17 @@ def test_convert_is_udtf(self):
[StructField("x", StructType([StructField("a", IntegerType())]))]
)
data = [pd.Series([0, 1])]
with self.assertRaises(PySparkRuntimeError) as ctx:
with self.assertRaises(PySparkTypeError) as ctx:
PandasToArrowConversion.convert(
data,
struct_schema,
is_udtf=True,
ignore_unexpected_complex_type_values=True,
)
self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
error_msg = str(ctx.exception)
self.assertIn("Exception thrown when converting pandas.Series", error_msg)
self.assertIn("with name 'x'", error_msg)
self.assertIn("to Arrow Array", error_msg)
self.assertIn("struct<a: int32>", error_msg)

def test_convert_prefers_large_types(self):
"""Test prefers_large_types produces large Arrow types."""
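
One user-visible consequence of dropping the `is_udtf` flag, sketched below: callers that previously matched the UDTF path's `PySparkRuntimeError` with error class `UDTF_ARROW_TYPE_CAST_ERROR` now see the regular conversion errors instead. The snippet mirrors the new tests above; the `pyspark.sql.conversion` import path is assumed from this PR's file layout, and the data is a deliberately mismatched string column for a double field.

import pandas as pd
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.conversion import PandasToArrowConversion  # assumed import path
from pyspark.sql.types import DoubleType, StructField, StructType

schema = StructType([StructField("val", DoubleType())])
data = [pd.Series(["not_a_number", "bad"])]

try:
    batch = PandasToArrowConversion.convert(data, schema)
except (PySparkTypeError, PySparkValueError) as e:
    # Previously the UDTF path raised PySparkRuntimeError with
    # errorClass UDTF_ARROW_TYPE_CAST_ERROR for the same inputs; both new
    # error types carry the "Exception thrown when converting pandas.Series"
    # message referencing the field name and target Arrow type.
    print(f"conversion failed: {e}")
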