Commits (46; changes shown from 39)
6470f2d
refactor: introduce PandasBatchTransformer.to_arrow for pandas to Arr…
Yicong-Huang Feb 3, 2026
750ac96
refactor: use PandasBatchTransformer.to_arrow in serializers
Yicong-Huang Feb 3, 2026
b26279b
refactor: remove _create_struct_array and inline to_arrow calls
Yicong-Huang Feb 4, 2026
2c6dcd1
refactor: consolidate conversion utilities into PandasToArrowConversi…
Yicong-Huang Feb 4, 2026
632b4fc
test: add more tests for dataframe_to_batch edge cases
Yicong-Huang Feb 4, 2026
9a73dcb
doc: add comment
Yicong-Huang Feb 4, 2026
df31ecf
refactor: simplify
Yicong-Huang Feb 4, 2026
549cb7e
fix: take care of comments
Yicong-Huang Feb 4, 2026
f6bb23e
fix: simplify
Yicong-Huang Feb 4, 2026
4e27fbd
fix: always return iterable (data, type)
Yicong-Huang Feb 4, 2026
a1a2f15
fix: consume generator
Yicong-Huang Feb 4, 2026
8e4e1f4
fix: use tuple
Yicong-Huang Feb 4, 2026
c253a9e
fix: unwrap in serializer
Yicong-Huang Feb 4, 2026
0122b4c
fix: wrap for all cases
Yicong-Huang Feb 4, 2026
1d6dc86
fix: missing wraps and import
Yicong-Huang Feb 5, 2026
77a4ce5
refactor: redesign api
Yicong-Huang Feb 5, 2026
d6efd1f
fix: test case catch
Yicong-Huang Feb 5, 2026
d891fb6
Merge remote-tracking branch 'upstream/master' into SPARK-55159/refac…
Yicong-Huang Feb 5, 2026
73d0318
fix: use column name in error message
Yicong-Huang Feb 5, 2026
39edd29
fix: remove struct_in_pandas
Yicong-Huang Feb 5, 2026
7879938
revert: changes on wrapping
Yicong-Huang Feb 5, 2026
1749490
refactor: simplify
Yicong-Huang Feb 6, 2026
f931071
fix: mypy
Yicong-Huang Feb 6, 2026
3e6fed7
Merge remote-tracking branch 'upstream/master' into SPARK-55159/refac…
Yicong-Huang Feb 6, 2026
73de5bf
refactor: simplify
Yicong-Huang Feb 6, 2026
c612f32
fix: error class handling
Yicong-Huang Feb 6, 2026
880b581
fix: ArrowNotImplementedError
Yicong-Huang Feb 6, 2026
e648cb0
fix: format
Yicong-Huang Feb 6, 2026
3f8f6fd
fix: error handling
Yicong-Huang Feb 6, 2026
87262a8
fix: error handling
Yicong-Huang Feb 6, 2026
d0e260d
refactor: use is_udtf
Yicong-Huang Feb 6, 2026
c4e8bab
test: update test case
Yicong-Huang Feb 6, 2026
bf56510
fix: handle comments
Yicong-Huang Feb 7, 2026
e8aaddc
fix: format
Yicong-Huang Feb 7, 2026
30ea61a
fix: mypy
Yicong-Huang Feb 8, 2026
a10f02f
merge upstream/master
Yicong-Huang Feb 8, 2026
4925b72
refactor: handle udtf separately
Yicong-Huang Feb 8, 2026
c1d8d9f
refactor: simplify coerce_arrow_array
Yicong-Huang Feb 8, 2026
4ee380a
fix: use per type error message
Yicong-Huang Feb 9, 2026
100a2b6
fix: address review comments
Yicong-Huang Feb 12, 2026
4d84f7b
merge upstream/master
Yicong-Huang Feb 12, 2026
a8d0ba4
fix: use Sequence for convert() param to fix mypy list invariance
Yicong-Huang Feb 12, 2026
842b8aa
Merge branch 'master' into SPARK-55159/refactor/consolidate-pandas-to…
Yicong-Huang Feb 13, 2026
94a060d
refactor: consolidate series_to_array into convert_column
Yicong-Huang Feb 13, 2026
20f7ea3
Merge remote-tracking branch 'upstream/master' into SPARK-55159/refac…
Yicong-Huang Feb 18, 2026
4da884f
fix: remove unused is_variant import after merge
Yicong-Huang Feb 18, 2026
233 changes: 233 additions & 0 deletions python/pyspark/sql/conversion.py
@@ -162,6 +162,239 @@ def to_pandas(
]


# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
# instead of per-column coercion.
def coerce_arrow_array(
Contributor:
We have an ArrowArrayConversion; should this function live in it?

Contributor Author:
There's already a TODO to elevate this to ArrowBatchTransformer to operate on the full RecordBatch schema instead of per-column coercion. Moving it to ArrowArrayConversion first would just add an extra migration step. I'll address this when we do the batch-level coercion refactor.

arr: "pa.Array",
target_type: "pa.DataType",
*,
safecheck: bool = True,
arrow_cast: bool = True,
) -> "pa.Array":
"""
Coerce an Arrow Array to a target type, with optional type-mismatch enforcement.

When ``arrow_cast`` is True (default), mismatched types are cast to the
target type. When False, a type mismatch raises an error instead.

Parameters
----------
arr : pa.Array
Input Arrow array
target_type : pa.DataType
Target Arrow type
safecheck : bool
Whether to use safe casting (default True)
arrow_cast : bool
Whether to allow casting when types don't match (default True)

Returns
-------
pa.Array
"""
from pyspark.errors import PySparkTypeError

if arr.type == target_type:
return arr

if not arrow_cast:
raise PySparkTypeError(
"Arrow UDFs require the return type to match the expected Arrow type. "
f"Expected: {target_type}, but got: {arr.type}."
)

# When safe is True, the cast fails on overflow or other unsafe conversions.
# RecordBatch.cast(...) isn't used because the minimum PyArrow version
# required for it is v16.0.
return arr.cast(target_type=target_type, safe=safecheck)
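
For illustration, a minimal sketch of the two modes (assumes only PyArrow and the function above):

```python
import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())

# arrow_cast=True (default): a mismatched type is cast to the target.
assert coerce_arrow_array(arr, pa.int32()).type == pa.int32()

# arrow_cast=False: a type mismatch raises PySparkTypeError instead.
# coerce_arrow_array(arr, pa.int32(), arrow_cast=False)  # raises
```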


class PandasToArrowConversion:
"""
Conversion utilities from pandas data to Arrow.
"""

@classmethod
def convert(
cls,
data: Union["pd.DataFrame", List[Union["pd.Series", "pd.DataFrame"]]],
schema: StructType,
*,
timezone: Optional[str] = None,
safecheck: bool = True,
arrow_cast: bool = False,
prefers_large_types: bool = False,
assign_cols_by_name: bool = False,
int_to_decimal_coercion_enabled: bool = False,
ignore_unexpected_complex_type_values: bool = False,
is_udtf: bool = False,
) -> "pa.RecordBatch":
"""
Convert a pandas DataFrame or list of Series/DataFrames to an Arrow RecordBatch.

Parameters
----------
data : pd.DataFrame or list of pd.Series/pd.DataFrame
Contributor:
In what case is the input a list of DataFrames?

Contributor Author:
A list of DataFrames is used in stateful processing (e.g., applyInPandasWithState), where the batch contains multiple DataFrames representing different parts of the output (count, data, and state), each wrapped as a StructArray column. Updated the docstring to clarify this.
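
For illustration, a hedged sketch of that call shape (hypothetical schema and frames, not the exact layout applyInPandasWithState uses internally):

```python
import pandas as pd
from pyspark.sql.conversion import PandasToArrowConversion
from pyspark.sql.types import LongType, StringType, StructField, StructType

# Hypothetical stand-ins for the "count" and "data" parts of the output.
schema = StructType([
    StructField("count", StructType([StructField("n", LongType())])),
    StructField("data", StructType([
        StructField("key", StringType()),
        StructField("value", LongType()),
    ])),
])
parts = [
    pd.DataFrame({"n": [2]}),
    pd.DataFrame({"key": ["a"], "value": [1]}),
]
batch = PandasToArrowConversion.convert(parts, schema)
# Each DataFrame in the list becomes one struct column of the RecordBatch.
```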

Input data - either a DataFrame or a list of Series/DataFrames.
schema : StructType
Spark schema defining the types for each column
timezone : str, optional
Timezone for timestamp conversion
safecheck : bool
Whether to use safe Arrow conversion (default True)
arrow_cast : bool
Whether to allow Arrow casting on type mismatch (default False)
prefers_large_types : bool
Whether to prefer large Arrow types (default False)
assign_cols_by_name : bool
Whether to reorder DataFrame columns by name to match schema (default False)
int_to_decimal_coercion_enabled : bool
Whether to enable int to decimal coercion (default False)
ignore_unexpected_complex_type_values : bool
Whether to ignore unexpected complex type values in converter (default False)
is_udtf : bool
Contributor:
Please add a TODO with a JIRA to unify this in the future.

Contributor Author:
Added a TODO with SPARK-55502.

Whether this conversion is for a UDTF. UDTFs use broader Arrow exception
handling to allow more type coercions (e.g., struct field casting via
ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
Regular UDFs only catch ArrowInvalid to preserve legacy behavior where
e.g. string→decimal must raise an error. (default False)

Returns
-------
pa.RecordBatch
"""
import pyarrow as pa
import pandas as pd

from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas

# Handle empty schema (0 columns)
# Use dummy column + select([]) to preserve row count (PyArrow limitation workaround)
if not schema.fields:
Contributor:
What does `not schema.fields` mean? Can `fields` be None?

Contributor Author:
schema.fields is a list, so `not schema.fields` checks for an empty list (0 columns). Changed to `len(schema.fields) == 0` for clarity.

num_rows = len(data[0]) if isinstance(data, list) and data else len(data)
return pa.RecordBatch.from_pydict({"_": [None] * num_rows}).select([])
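
A quick demonstration of the workaround in plain PyArrow:

```python
import pyarrow as pa

# Building a zero-column batch directly would lose the row count;
# the dummy column carries it through select([]).
empty = pa.RecordBatch.from_pydict({"_": [None] * 3}).select([])
assert empty.num_columns == 0
assert empty.num_rows == 3
```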

# Handle empty DataFrame (0 columns) with non-empty schema
# This happens when user returns pd.DataFrame() for struct types
if isinstance(data, pd.DataFrame) and len(data.columns) == 0:
arrow_type = to_arrow_type(
schema, timezone=timezone, prefers_large_types=prefers_large_types
)
return pa.RecordBatch.from_struct_array(pa.array([{}] * len(data), arrow_type))
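
For reference, this is how PyArrow fills the missing struct fields in that branch (plain-PyArrow illustration):

```python
import pyarrow as pa

t = pa.struct([("x", pa.int64()), ("y", pa.string())])
arr = pa.array([{}] * 2, type=t)
# Two non-null struct values whose fields are all null.
assert arr.is_valid().to_pylist() == [True, True]
assert arr.field("x").null_count == 2
```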

# Normalize input: reorder DataFrame columns by schema names if needed,
# then extract columns as a list for uniform iteration.
if isinstance(data, list):
columns = data
else:
if assign_cols_by_name and any(isinstance(c, str) for c in data.columns):
Contributor:
If assign_cols_by_name is True but the columns don't have names, what happens? Is falling back to ignoring assign_cols_by_name the expected behavior?

Contributor Author:
It falls back to position (index) based column references. This is intended, same as master.

data = data[schema.names]
columns = [data.iloc[:, i] for i in range(len(schema.fields))]
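
A small illustration of the two normalization paths (hypothetical frames):

```python
import pandas as pd

named = pd.DataFrame({"b": [2], "a": [1]})
named[["a", "b"]]       # string-labeled columns: reordered to schema order

unnamed = pd.DataFrame([[2, 1]])
# Integer-labeled columns: the any(isinstance(c, str)) guard is False,
# so the columns are taken positionally instead.
```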

def series_to_array(series: "pd.Series", ret_type: DataType, field_name: str) -> "pa.Array":
"""Convert a pandas Series to an Arrow Array (closure over conversion params).

Uses field_name for error messages instead of series.name to avoid
copying the Series via rename() — a ~20% overhead on the hot path.
"""
if isinstance(series.dtype, pd.CategoricalDtype):
series = series.astype(series.dtype.categories.dtype)

arrow_type = to_arrow_type(
ret_type, timezone=timezone, prefers_large_types=prefers_large_types
)
series = _create_converter_from_pandas(
ret_type,
timezone=timezone,
error_on_duplicated_field_names=False,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
)(series)

mask = None if hasattr(series.array, "__arrow_array__") else series.isnull()
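# Arrays implementing the __arrow_array__ protocol convert themselves to
# Arrow, and from_pandas rejects an explicit mask for them, hence None.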

if is_udtf:
Contributor:
Shouldn't the is_udtf handling be inside coerce_arrow_array?

Contributor Author:
The is_udtf special handling is in the from_pandas stage (catching the broader ArrowException instead of just ArrowInvalid), not in the .cast() stage that coerce_arrow_array handles, so it fits better in series_to_array. The flag will be eliminated via SPARK-55502 when we unify the UDTF and regular UDF conversion paths.

# UDTF path: broad ArrowException catch so that both ArrowInvalid
# AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
try:
try:
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=safecheck
)
except pa.lib.ArrowException: # broad: includes ArrowTypeError
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=safecheck
)
raise
except pa.lib.ArrowException: # convert any Arrow error to user-friendly message
raise PySparkRuntimeError(
errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
messageParameters={
"col_name": field_name,
"col_type": str(series.dtype),
"arrow_type": str(arrow_type),
},
) from None
else:
# UDF path: only ArrowInvalid triggers the cast fallback.
# ArrowTypeError (e.g. string→decimal) must NOT be silently cast.
try:
try:
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=safecheck
)
except pa.lib.ArrowInvalid: # narrow: skip ArrowTypeError
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=safecheck
)
raise
except TypeError as e: # includes pa.lib.ArrowTypeError
raise PySparkTypeError(
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
) from e
except ValueError as e: # includes pa.lib.ArrowInvalid
error_msg = (
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
)
if safecheck:
error_msg += (
" It can be caused by overflows or other unsafe conversions "
"warned by Arrow. Arrow safe type check can be disabled by using "
"SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise PySparkValueError(error_msg) from e
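
The narrow-vs-broad split leans on PyArrow's exception hierarchy; a quick sanity check of the relationships both branches rely on:

```python
import pyarrow as pa

# ArrowTypeError is a TypeError and ArrowInvalid is a ValueError; both
# derive from ArrowException, so the UDTF branch's broad catch sees both
# while the UDF branch's narrow catch skips ArrowTypeError.
assert issubclass(pa.lib.ArrowTypeError, TypeError)
assert issubclass(pa.lib.ArrowInvalid, ValueError)
assert issubclass(pa.lib.ArrowTypeError, pa.lib.ArrowException)
assert issubclass(pa.lib.ArrowInvalid, pa.lib.ArrowException)
```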

def convert_column(
Contributor:
Can we consolidate convert_column and series_to_array?

Contributor Author:
Done.

col: Union["pd.Series", "pd.DataFrame"], field: StructField
) -> "pa.Array":
"""Convert a single column (Series or DataFrame) to an Arrow Array."""
if isinstance(col, pd.DataFrame):
assert isinstance(field.dataType, StructType)
nested_batch = cls.convert(
col,
field.dataType,
timezone=timezone,
safecheck=safecheck,
arrow_cast=arrow_cast,
prefers_large_types=prefers_large_types,
assign_cols_by_name=assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
is_udtf=is_udtf,
)
return ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
Contributor:
This line takes me a few seconds to work out what it does.

Contributor Author:
Understood. Added a comment to clarify. We can revisit the wrap_struct transformer in the future.
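
In plain PyArrow terms, the wrap-then-unwrap is roughly a StructArray built from the batch's columns (assumed semantics of wrap_struct, for illustration only):

```python
import pyarrow as pa

nested = pa.RecordBatch.from_pydict({"x": [1, 2], "y": ["a", "b"]})
# Roughly what wrap_struct(nested_batch).column(0) yields:
struct_arr = pa.StructArray.from_arrays(nested.columns, names=nested.schema.names)
assert struct_arr.type == pa.struct([("x", pa.int64()), ("y", pa.string())])
```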

return series_to_array(col, field.dataType, field.name)

arrays = [convert_column(col, field) for col, field in zip(columns, schema.fields)]
return pa.RecordBatch.from_arrays(arrays, schema.names)
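
Putting it together, a minimal end-to-end sketch of the new entry point (flat schema, default options; assumes a checkout where pyspark.sql.conversion is importable):

```python
import pandas as pd
from pyspark.sql.conversion import PandasToArrowConversion
from pyspark.sql.types import LongType, StringType, StructField, StructType

schema = StructType([
    StructField("id", LongType()),
    StructField("name", StringType()),
])
pdf = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})

batch = PandasToArrowConversion.convert(pdf, schema)
assert batch.num_rows == 2
assert batch.schema.names == ["id", "name"]
```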


class LocalDataToArrowConversion:
"""
Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.