[SPARK-55349][PYTHON] Consolidate pandas-to-Arrow conversion utilities in serializers #54125
@@ -162,6 +162,239 @@ def to_pandas(

```python
    ]


# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
# instead of per-column coercion.
def coerce_arrow_array(
    arr: "pa.Array",
    target_type: "pa.DataType",
    *,
    safecheck: bool = True,
    arrow_cast: bool = True,
) -> "pa.Array":
    """
    Coerce an Arrow Array to a target type, with optional type-mismatch enforcement.

    When ``arrow_cast`` is True (default), mismatched types are cast to the
    target type. When False, a type mismatch raises an error instead.

    Parameters
    ----------
    arr : pa.Array
        Input Arrow array
    target_type : pa.DataType
        Target Arrow type
    safecheck : bool
        Whether to use safe casting (default True)
    arrow_cast : bool
        Whether to allow casting when types don't match (default True)

    Returns
    -------
    pa.Array
    """
    from pyspark.errors import PySparkTypeError

    if arr.type == target_type:
        return arr

    if not arrow_cast:
        raise PySparkTypeError(
            "Arrow UDFs require the return type to match the expected Arrow type. "
            f"Expected: {target_type}, but got: {arr.type}."
        )

    # When safe is True, the cast will fail if there's an overflow or other
    # unsafe conversion.
    # RecordBatch.cast(...) isn't used, as the minimum PyArrow version
    # required for RecordBatch.cast(...) is v16.0.
    return arr.cast(target_type=target_type, safe=safecheck)
```
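As a quick aside on the `safe=` semantics that `coerce_arrow_array` delegates to: with `safe=True`, PyArrow refuses lossy casts such as integer overflow. Here is a minimal standalone sketch of that behavior (illustrative only, not part of this diff):

```python
# Standalone sketch of PyArrow's safe-cast behavior; not part of this PR.
import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())
print(arr.cast(pa.int32(), safe=True))   # fine: all values fit in int32

big = pa.array([2**40], type=pa.int64())
try:
    big.cast(pa.int32(), safe=True)      # overflow: the safe cast refuses
except pa.lib.ArrowInvalid as e:
    print(f"safe cast failed as expected: {e}")

print(big.cast(pa.int32(), safe=False))  # unsafe cast truncates silently
```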
```python
class PandasToArrowConversion:
    """
    Conversion utilities from pandas data to Arrow.
    """

    @classmethod
    def convert(
        cls,
        data: Union["pd.DataFrame", List[Union["pd.Series", "pd.DataFrame"]]],
        schema: StructType,
        *,
        timezone: Optional[str] = None,
        safecheck: bool = True,
        arrow_cast: bool = False,
        prefers_large_types: bool = False,
        assign_cols_by_name: bool = False,
        int_to_decimal_coercion_enabled: bool = False,
        ignore_unexpected_complex_type_values: bool = False,
        is_udtf: bool = False,
    ) -> "pa.RecordBatch":
        """
        Convert a pandas DataFrame or list of Series/DataFrames to an Arrow RecordBatch.

        Parameters
        ----------
        data : pd.DataFrame or list of pd.Series/pd.DataFrame
```
**Contributor:** In what case is the input a list of DataFrames?

**Author:** A list of DataFrames is used in stateful processing (e.g., …)
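For readers unfamiliar with the list input form, a hypothetical sketch of the list-of-Series calling convention follows; the schema and values are invented for illustration, and `PandasToArrowConversion` is the class added in this diff:

```python
# Hypothetical illustration of the list-of-Series input form; names and
# values are invented for this sketch.
import pandas as pd
from pyspark.sql.types import StructType, StructField, LongType, StringType

schema = StructType([
    StructField("id", LongType()),
    StructField("name", StringType()),
])
# One pandas Series per schema field, matched by position:
columns = [pd.Series([1, 2]), pd.Series(["a", "b"])]
batch = PandasToArrowConversion.convert(columns, schema)
```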
```python
            Input data - either a DataFrame or a list of Series/DataFrames.
        schema : StructType
            Spark schema defining the types for each column
        timezone : str, optional
            Timezone for timestamp conversion
        safecheck : bool
            Whether to use safe Arrow conversion (default True)
        arrow_cast : bool
            Whether to allow Arrow casting on type mismatch (default False)
        prefers_large_types : bool
            Whether to prefer large Arrow types (default False)
        assign_cols_by_name : bool
            Whether to reorder DataFrame columns by name to match schema (default False)
        int_to_decimal_coercion_enabled : bool
            Whether to enable int to decimal coercion (default False)
        ignore_unexpected_complex_type_values : bool
            Whether to ignore unexpected complex type values in converter (default False)
        is_udtf : bool
```
**Contributor:** Please add a TODO with a JIRA to unify it in the future.

**Author:** Added a TODO with SPARK-55502.
```python
            Whether this conversion is for a UDTF. UDTFs use broader Arrow exception
            handling to allow more type coercions (e.g., struct field casting via
            ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
            Regular UDFs only catch ArrowInvalid to preserve legacy behavior where
            e.g. string→decimal must raise an error. (default False)

        Returns
        -------
        pa.RecordBatch
        """
        import pyarrow as pa
        import pandas as pd

        from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
        from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas

        # Handle empty schema (0 columns).
        # Use a dummy column + select([]) to preserve the row count
        # (PyArrow limitation workaround).
        if not schema.fields:
            num_rows = len(data[0]) if isinstance(data, list) and data else len(data)
            return pa.RecordBatch.from_pydict({"_": [None] * num_rows}).select([])

        # Handle an empty DataFrame (0 columns) with a non-empty schema.
        # This happens when the user returns pd.DataFrame() for struct types.
        if isinstance(data, pd.DataFrame) and len(data.columns) == 0:
            arrow_type = to_arrow_type(
                schema, timezone=timezone, prefers_large_types=prefers_large_types
            )
            return pa.RecordBatch.from_struct_array(pa.array([{}] * len(data), arrow_type))

        # Normalize input: reorder DataFrame columns by schema names if needed,
        # then extract columns as a list for uniform iteration.
        if isinstance(data, list):
            columns = data
        else:
            if assign_cols_by_name and any(isinstance(c, str) for c in data.columns):
```
**Contributor:** If …

**Author:** It will fall back to position (index) based column references. This is intended, same as master.
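A minimal pandas sketch of the two behaviors discussed in this thread (column names are invented): when the DataFrame has string column labels, they are reordered to match the schema; otherwise columns are taken by position.

```python
# Illustrative-only sketch of name-based reordering vs. positional fallback;
# the DataFrames and names are invented.
import pandas as pd

schema_names = ["id", "name"]

named = pd.DataFrame({"name": ["a", "b"], "id": [1, 2]})
print(named[schema_names].columns.tolist())  # reordered by name: ['id', 'name']

positional = pd.DataFrame([[1, "a"], [2, "b"]])  # integer column labels 0, 1
print([positional.iloc[:, i].tolist() for i in range(2)])  # taken by position
```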
```python
                data = data[schema.names]
            columns = [data.iloc[:, i] for i in range(len(schema.fields))]

        def series_to_array(series: "pd.Series", ret_type: DataType, field_name: str) -> "pa.Array":
            """Convert a pandas Series to an Arrow Array (closure over conversion params).

            Uses field_name for error messages instead of series.name to avoid
            copying the Series via rename(), a ~20% overhead on the hot path.
            """
            if isinstance(series.dtype, pd.CategoricalDtype):
                series = series.astype(series.dtype.categories.dtype)

            arrow_type = to_arrow_type(
                ret_type, timezone=timezone, prefers_large_types=prefers_large_types
            )
            series = _create_converter_from_pandas(
                ret_type,
                timezone=timezone,
                error_on_duplicated_field_names=False,
                int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
                ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
            )(series)

            mask = None if hasattr(series.array, "__arrow_array__") else series.isnull()

            if is_udtf:
```
**Contributor:** Shouldn't the …

**Author:** The …
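Context for the UDTF-vs-UDF exception handling below: in PyArrow, both `ArrowInvalid` and `ArrowTypeError` subclass `ArrowException`, and they also subclass the built-in `ValueError` and `TypeError` respectively, which is what the UDF path's `except TypeError` / `except ValueError` clauses rely on. A quick standalone check (illustrative only):

```python
# Standalone check of the PyArrow exception hierarchy the code below relies on.
import pyarrow as pa

assert issubclass(pa.lib.ArrowInvalid, pa.lib.ArrowException)
assert issubclass(pa.lib.ArrowTypeError, pa.lib.ArrowException)
assert issubclass(pa.lib.ArrowInvalid, ValueError)   # caught by `except ValueError`
assert issubclass(pa.lib.ArrowTypeError, TypeError)  # caught by `except TypeError`
```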
```python
                # UDTF path: broad ArrowException catch so that both ArrowInvalid
                # AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
                try:
                    try:
                        return pa.Array.from_pandas(
                            series, mask=mask, type=arrow_type, safe=safecheck
                        )
                    except pa.lib.ArrowException:  # broad: includes ArrowTypeError
                        if arrow_cast:
                            return pa.Array.from_pandas(series, mask=mask).cast(
                                target_type=arrow_type, safe=safecheck
                            )
                        raise
                except pa.lib.ArrowException:  # convert any Arrow error to a user-friendly message
                    raise PySparkRuntimeError(
                        errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
                        messageParameters={
                            "col_name": field_name,
                            "col_type": str(series.dtype),
                            "arrow_type": str(arrow_type),
                        },
                    ) from None
            else:
                # UDF path: only ArrowInvalid triggers the cast fallback.
                # ArrowTypeError (e.g. string→decimal) must NOT be silently cast.
                try:
                    try:
                        return pa.Array.from_pandas(
                            series, mask=mask, type=arrow_type, safe=safecheck
                        )
                    except pa.lib.ArrowInvalid:  # narrow: skip ArrowTypeError
                        if arrow_cast:
                            return pa.Array.from_pandas(series, mask=mask).cast(
                                target_type=arrow_type, safe=safecheck
                            )
                        raise
                except TypeError as e:  # includes pa.lib.ArrowTypeError
                    raise PySparkTypeError(
                        f"Exception thrown when converting pandas.Series ({series.dtype}) "
                        f"with name '{field_name}' to Arrow Array ({arrow_type})."
                    ) from e
                except ValueError as e:  # includes pa.lib.ArrowInvalid
                    error_msg = (
                        f"Exception thrown when converting pandas.Series ({series.dtype}) "
                        f"with name '{field_name}' to Arrow Array ({arrow_type})."
                    )
                    if safecheck:
                        error_msg += (
                            " It can be caused by overflows or other unsafe conversions "
                            "warned by Arrow. Arrow safe type check can be disabled by using "
                            "SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
                        )
                    raise PySparkValueError(error_msg) from e

        def convert_column(
```
```python
            col: Union["pd.Series", "pd.DataFrame"], field: StructField
        ) -> "pa.Array":
            """Convert a single column (Series or DataFrame) to an Arrow Array."""
            if isinstance(col, pd.DataFrame):
                assert isinstance(field.dataType, StructType)
                nested_batch = cls.convert(
                    col,
                    field.dataType,
                    timezone=timezone,
                    safecheck=safecheck,
                    arrow_cast=arrow_cast,
                    prefers_large_types=prefers_large_types,
                    assign_cols_by_name=assign_cols_by_name,
                    int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
                    ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
                    is_udtf=is_udtf,
                )
                return ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
            return series_to_array(col, field.dataType, field.name)

        arrays = [convert_column(col, field) for col, field in zip(columns, schema.fields)]
        return pa.RecordBatch.from_arrays(arrays, schema.names)
```
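Putting the whole entry point together, a hypothetical end-to-end call might look like the following; the DataFrame contents and schema are invented, and `PandasToArrowConversion` is the class added in this diff:

```python
# Hypothetical end-to-end usage of the new entry point; data and schema are
# invented for illustration.
import pandas as pd
from pyspark.sql.types import StructType, StructField, LongType, StringType

pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
schema = StructType([
    StructField("id", LongType()),
    StructField("name", StringType()),
])

batch = PandasToArrowConversion.convert(
    pdf,
    schema,
    timezone="UTC",
    safecheck=True,
    arrow_cast=False,
    assign_cols_by_name=True,
)
print(batch.num_rows, batch.schema.names)  # 3 ['id', 'name']
```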
```python
class LocalDataToArrowConversion:
    """
    Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
    """
```
**Contributor:** We have an `ArrowArrayConversion`; should this function be in it?

**Author:** There's already a TODO to elevate this to `ArrowBatchTransformer` to operate on the full RecordBatch schema instead of per-column coercion. Moving it to `ArrowArrayConversion` first would just add an extra migration step. I'll address this when we do the batch-level coercion refactor.
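For context on that TODO, a batch-level version of `coerce_arrow_array` could look roughly like the hypothetical sketch below; the function name and placement are invented, not part of this PR. It loops over columns because, as the diff notes, `RecordBatch.cast(...)` would raise the minimum PyArrow requirement to v16.0:

```python
# Hypothetical sketch of the planned batch-level coercion; the helper name
# and placement are assumptions, not part of this PR.
import pyarrow as pa

def coerce_record_batch(
    batch: pa.RecordBatch,
    target_schema: pa.Schema,
    *,
    safecheck: bool = True,
) -> pa.RecordBatch:
    # Per-column cast keeps the minimum PyArrow requirement below v16.0,
    # where RecordBatch.cast(...) was introduced.
    arrays = [
        col if col.type == field.type else col.cast(field.type, safe=safecheck)
        for col, field in zip(batch.columns, target_schema)
    ]
    return pa.RecordBatch.from_arrays(arrays, schema=target_schema)
```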