
Commit 4b2718e

refactor(python): unify UDF/UDTF error handling by removing is_udtf flag
1 parent 1f758c2 commit 4b2718e

File tree: 3 files changed (+37 additions, -76 deletions)

python/pyspark/sql/conversion.py

Lines changed: 30 additions & 62 deletions
@@ -228,7 +228,6 @@ def convert(
         assign_cols_by_name: bool = False,
         int_to_decimal_coercion_enabled: bool = False,
         ignore_unexpected_complex_type_values: bool = False,
-        is_udtf: bool = False,
     ) -> "pa.RecordBatch":
         """
         Convert a pandas DataFrame or list of Series/DataFrames to an Arrow RecordBatch.
@@ -255,14 +254,6 @@ def convert(
             Whether to enable int to decimal coercion (default False)
         ignore_unexpected_complex_type_values : bool
             Whether to ignore unexpected complex type values in converter (default False)
-        is_udtf : bool
-            Whether this conversion is for a UDTF. UDTFs use broader Arrow exception
-            handling to allow more type coercions (e.g., struct field casting via
-            ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
-            # TODO(SPARK-55502): Unify UDTF and regular UDF conversion paths to
-            # eliminate the is_udtf flag.
-            Regular UDFs only catch ArrowInvalid to preserve legacy behavior where
-            e.g. string→decimal must raise an error. (default False)
 
         Returns
         -------
@@ -271,7 +262,7 @@ def convert(
         import pyarrow as pa
         import pandas as pd
 
-        from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
+        from pyspark.errors import PySparkTypeError, PySparkValueError
         from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas
 
         # Handle empty schema (0 columns)
@@ -318,7 +309,6 @@ def convert_column(
                     assign_cols_by_name=assign_cols_by_name,
                     int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
                     ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
-                    is_udtf=is_udtf,
                 )
                 # Wrap the nested RecordBatch as a single StructArray column
                 return ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
@@ -343,60 +333,38 @@ def convert_column(
 
             mask = None if hasattr(series.array, "__arrow_array__") else series.isnull()
 
-            if is_udtf:
-                # UDTF path: broad ArrowException catch so that both ArrowInvalid
-                # AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
+            # Unified conversion path: broad ArrowException catch so that both ArrowInvalid
+            # AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
+            try:
                 try:
-                    try:
-                        return pa.Array.from_pandas(
-                            series, mask=mask, type=arrow_type, safe=safecheck
-                        )
-                    except pa.lib.ArrowException:  # broad: includes ArrowTypeError
-                        if arrow_cast:
-                            return pa.Array.from_pandas(series, mask=mask).cast(
-                                target_type=arrow_type, safe=safecheck
-                            )
-                        raise
-                except pa.lib.ArrowException:  # convert any Arrow error to user-friendly message
-                    raise PySparkRuntimeError(
-                        errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
-                        messageParameters={
-                            "col_name": field_name,
-                            "col_type": str(series.dtype),
-                            "arrow_type": str(arrow_type),
-                        },
-                    ) from None
-            else:
-                # UDF path: only ArrowInvalid triggers the cast fallback.
-                # ArrowTypeError (e.g. string→decimal) must NOT be silently cast.
-                try:
-                    try:
-                        return pa.Array.from_pandas(
-                            series, mask=mask, type=arrow_type, safe=safecheck
-                        )
-                    except pa.lib.ArrowInvalid:  # narrow: skip ArrowTypeError
-                        if arrow_cast:
-                            return pa.Array.from_pandas(series, mask=mask).cast(
-                                target_type=arrow_type, safe=safecheck
-                            )
-                        raise
-                except TypeError as e:  # includes pa.lib.ArrowTypeError
-                    raise PySparkTypeError(
-                        f"Exception thrown when converting pandas.Series ({series.dtype}) "
-                        f"with name '{field_name}' to Arrow Array ({arrow_type})."
-                    ) from e
-                except ValueError as e:  # includes pa.lib.ArrowInvalid
-                    error_msg = (
-                        f"Exception thrown when converting pandas.Series ({series.dtype}) "
-                        f"with name '{field_name}' to Arrow Array ({arrow_type})."
+                    return pa.Array.from_pandas(
+                        series, mask=mask, type=arrow_type, safe=safecheck
                     )
-                    if safecheck:
-                        error_msg += (
-                            " It can be caused by overflows or other unsafe conversions "
-                            "warned by Arrow. Arrow safe type check can be disabled by using "
-                            "SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
+                except pa.lib.ArrowException:  # broad: includes ArrowTypeError and ArrowInvalid
+                    if arrow_cast:
+                        return pa.Array.from_pandas(series, mask=mask).cast(
+                            target_type=arrow_type, safe=safecheck
                         )
-                    raise PySparkValueError(error_msg) from e
+                    raise
+            except (TypeError, pa.lib.ArrowTypeError) as e:
+                # ArrowTypeError is a subclass of TypeError
+                raise PySparkTypeError(
+                    f"Exception thrown when converting pandas.Series ({series.dtype}) "
+                    f"with name '{field_name}' to Arrow Array ({arrow_type})."
+                ) from e
+            except (ValueError, pa.lib.ArrowInvalid) as e:
+                # ArrowInvalid is a subclass of ValueError
+                error_msg = (
+                    f"Exception thrown when converting pandas.Series ({series.dtype}) "
+                    f"with name '{field_name}' to Arrow Array ({arrow_type})."
+                )
+                if safecheck:
+                    error_msg += (
+                        " It can be caused by overflows or other unsafe conversions "
+                        "warned by Arrow. Arrow safe type check can be disabled by using "
+                        "SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
+                    )
+                raise PySparkValueError(error_msg) from e
 
         arrays = [convert_column(col, field) for col, field in zip(columns, schema.fields)]
         return pa.RecordBatch.from_arrays(arrays, schema.names)
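
The core of the change is the single cast fallback above. A minimal standalone sketch of that pattern outside Spark (pyarrow and pandas only; the helper name to_arrow_column is hypothetical):

import pandas as pd
import pyarrow as pa

def to_arrow_column(series, arrow_type, safecheck=True, arrow_cast=True):
    # pa.lib.ArrowException is the common base of ArrowInvalid and
    # ArrowTypeError, so either failure mode reaches the cast fallback.
    try:
        return pa.Array.from_pandas(series, type=arrow_type, safe=safecheck)
    except pa.lib.ArrowException:
        if arrow_cast:
            # Let Arrow infer a type first, then cast to the requested one.
            return pa.Array.from_pandas(series).cast(
                target_type=arrow_type, safe=safecheck
            )
        raise

# Direct conversion of strings to int64 raises ArrowInvalid; the fallback
# cast succeeds and yields [1, 2].
print(to_arrow_column(pd.Series(["1", "2"]), pa.int64()))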

python/pyspark/sql/pandas/serializers.py

Lines changed: 0 additions & 6 deletions
@@ -513,7 +513,6 @@ def __init__(
         int_to_decimal_coercion_enabled: bool = False,
         prefers_large_types: bool = False,
         ignore_unexpected_complex_type_values: bool = False,
-        is_udtf: bool = False,
     ):
         super().__init__(
             timezone,
@@ -528,7 +527,6 @@ def __init__(
         )
         self._assign_cols_by_name = assign_cols_by_name
         self._ignore_unexpected_complex_type_values = ignore_unexpected_complex_type_values
-        self._is_udtf = is_udtf
 
     def dump_stream(self, iterator, stream):
         """
@@ -567,7 +565,6 @@ def create_batch(
             assign_cols_by_name=self._assign_cols_by_name,
             int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
             ignore_unexpected_complex_type_values=self._ignore_unexpected_complex_type_values,
-            is_udtf=self._is_udtf,
         )
 
         batches = self._write_stream_start(
@@ -767,9 +764,6 @@ def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enab
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             # UDTF-specific: ignore unexpected complex type values in converter
             ignore_unexpected_complex_type_values=True,
-            # UDTF-specific: enables broader Arrow exception handling and
-            # converts errors to UDTF_ARROW_TYPE_CAST_ERROR
-            is_udtf=True,
         )
 
     def __repr__(self):
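
With the flag removed from the serializer as well, UDF and UDTF callers now drive the identical conversion entry point. A hedged usage sketch (the import path for PandasToArrowConversion is assumed from this commit's file layout):

import pandas as pd
from pyspark.sql.types import StructType, StructField, DoubleType
# Import path assumed from this commit's file layout:
from pyspark.sql.conversion import PandasToArrowConversion

schema = StructType([StructField("val", DoubleType())])
data = [pd.Series([1.0, 2.5, 3.0])]

# No is_udtf keyword anymore: every caller gets the same behavior.
batch = PandasToArrowConversion.convert(data, schema)
print(batch.num_rows)  # 3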

python/pyspark/sql/tests/test_conversion.py

Lines changed: 7 additions & 8 deletions
@@ -297,17 +297,17 @@ def test_convert_error_messages(self):
         # Error message should reference the schema field name, not the positional index
         self.assertIn("age", str(ctx.exception))
 
-    def test_convert_is_udtf(self):
-        """Test is_udtf=True produces PySparkRuntimeError with UDTF_ARROW_TYPE_CAST_ERROR."""
+    def test_convert_broad_exception_handling(self):
+        """Test unified conversion uses broad exception handling for better type coercion."""
         import pandas as pd
 
         schema = StructType([StructField("val", DoubleType())])
         data = [pd.Series(["not_a_number", "bad"])]
 
         # ValueError path (string -> double)
-        with self.assertRaises(PySparkRuntimeError) as ctx:
-            PandasToArrowConversion.convert(data, schema, is_udtf=True)
-        self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
+        with self.assertRaises(PySparkValueError) as ctx:
+            PandasToArrowConversion.convert(data, schema)
+        self.assertIn("val", str(ctx.exception))
 
         # TypeError path (int -> struct): ArrowTypeError inherits from TypeError.
         # ignore_unexpected_complex_type_values=True lets the bad value pass through
@@ -316,14 +316,13 @@ def test_convert_is_udtf(self):
             [StructField("x", StructType([StructField("a", IntegerType())]))]
         )
         data = [pd.Series([0, 1])]
-        with self.assertRaises(PySparkRuntimeError) as ctx:
+        with self.assertRaises(PySparkTypeError) as ctx:
             PandasToArrowConversion.convert(
                 data,
                 struct_schema,
-                is_udtf=True,
                 ignore_unexpected_complex_type_values=True,
             )
-        self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
+        self.assertIn("x", str(ctx.exception))
 
     def test_convert_prefers_large_types(self):
         """Test prefers_large_types produces large Arrow types."""
