from typing import (
    Any,
    Callable,
+    Iterable,
    List,
    Optional,
    Sequence,
+    Tuple,
    Union,
    cast,
    no_type_check,

if TYPE_CHECKING:
    import numpy as np
+    import pandas as pd
    import pyarrow as pa
    from py4j.java_gateway import JavaObject

    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
    from pyspark.sql import DataFrame


+def create_arrow_array_from_pandas(
+    series: "pd.Series",
+    spark_type: Optional[DataType],
+    *,
+    timezone: Optional[str] = None,
+    safecheck: bool = False,
+    prefers_large_types: bool = False,
+) -> "pa.Array":
+    """
+    Create an Arrow Array from the given pandas.Series and Spark type.
+
+    Parameters
+    ----------
+    series : pandas.Series
+        A single series
+    spark_type : DataType, optional
+        The Spark return type. If None, pyarrow's inferred type will be used.
+    timezone : str, optional
+        The timezone to use for timestamp conversions.
+    safecheck : bool, optional
+        Whether to enable safe type checking during conversion.
+    prefers_large_types : bool, optional
+        Whether to prefer large Arrow types (e.g., large_string instead of string).
+
+    Returns
+    -------
+    pyarrow.Array
+    """
+    import pyarrow as pa
+    import pandas as pd
+    from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas
+
+    if isinstance(series.dtype, pd.CategoricalDtype):
+        series = series.astype(series.dtype.categories.dtype)
+
+    # Derive arrow_type from spark_type
+    arrow_type = (
+        to_arrow_type(spark_type, timezone=timezone, prefers_large_types=prefers_large_types)
+        if spark_type is not None
+        else None
+    )
+
+    if spark_type is not None:
+        conv = _create_converter_from_pandas(
+            spark_type,
+            timezone=timezone,
+            error_on_duplicated_field_names=False,
+        )
+        series = conv(series)
+
+    if hasattr(series.array, "__arrow_array__"):
+        mask = None
+    else:
+        mask = series.isnull()
+    try:
+        return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck)
+    except TypeError as e:
+        error_msg = (
+            "Exception thrown when converting pandas.Series (%s) "
+            "with name '%s' to Arrow Array (%s)."
+        )
+        raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
+    except ValueError as e:
+        error_msg = (
+            "Exception thrown when converting pandas.Series (%s) "
+            "with name '%s' to Arrow Array (%s)."
+        )
+        if safecheck:
+            error_msg = error_msg + (
+                " It can be caused by overflows or other "
+                "unsafe conversions warned by Arrow. Arrow safe type check "
+                "can be disabled by using SQL config "
+                "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+            )
+        raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
+
+
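For illustration, a minimal usage sketch of the helper above (the import path pyspark.sql.pandas.conversion, the LongType, and the sample Series are assumptions, not part of this change):

    import pandas as pd
    from pyspark.sql.types import LongType
    from pyspark.sql.pandas.conversion import create_arrow_array_from_pandas  # assumed location

    # Floats with a missing value convert safely to a nullable int64 Arrow array.
    arr = create_arrow_array_from_pandas(pd.Series([1, 2, None]), LongType(), safecheck=True)
    # arr is a pyarrow.Array of type int64; the None becomes an Arrow null via the isnull() mask.
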
+def create_arrow_batch_from_pandas(
+    series_with_types: Iterable[Tuple["pd.Series", Optional[DataType]]],
+    *,
+    timezone: Optional[str] = None,
+    safecheck: bool = False,
+    prefers_large_types: bool = False,
+) -> "pa.RecordBatch":
+    """
+    Create an Arrow record batch from the given iterable of (series, spark_type) tuples.
+
+    Parameters
+    ----------
+    series_with_types : iterable
+        Iterable of (series, spark_type) tuples.
+    timezone : str, optional
+        The timezone to use for timestamp conversions.
+    safecheck : bool, optional
+        Whether to enable safe type checking during conversion.
+    prefers_large_types : bool, optional
+        Whether to prefer large Arrow types (e.g., large_string instead of string).
+
+    Returns
+    -------
+    pyarrow.RecordBatch
+        Arrow RecordBatch
+    """
+    import pyarrow as pa
+
+    arrs = [
+        create_arrow_array_from_pandas(
+            s,
+            spark_type,
+            timezone=timezone,
+            safecheck=safecheck,
+            prefers_large_types=prefers_large_types,
+        )
+        for s, spark_type in series_with_types
+    ]
+    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+
+
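Similarly, a hedged sketch of building a two-column batch with the new helper (the types, data, UTC timezone, and import path are illustrative assumptions):

    import pandas as pd
    from pyspark.sql.types import LongType, StringType
    from pyspark.sql.pandas.conversion import create_arrow_batch_from_pandas  # assumed location

    batch = create_arrow_batch_from_pandas(
        [(pd.Series([1, 2]), LongType()), (pd.Series(["a", "b"]), StringType())],
        timezone="UTC",
        safecheck=True,
    )
    # Columns are named positionally by the helper: batch.schema.names == ["_0", "_1"]
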
def _convert_arrow_table_to_pandas(
    arrow_table: "pa.Table",
    schema: "StructType",
@@ -807,7 +928,7 @@ def _create_from_pandas_with_arrow(

        assert isinstance(self, SparkSession)

-        from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
        from pyspark.sql.types import TimestampType
        from pyspark.sql.pandas.types import (
            from_arrow_type,
@@ -877,20 +998,20 @@ def _create_from_pandas_with_arrow(
        step = step if step > 0 else len(pdf)
        pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

-        # Create list of (columns, spark_type) for serializer dump_stream
-        arrow_data = [
-            [(c, t) for (_, c), t in zip(pdf_slice.items(), spark_types)]
+        # Create Arrow batches directly using the standalone function
+        arrow_batches = [
+            create_arrow_batch_from_pandas(
+                [(c, t) for (_, c), t in zip(pdf_slice.items(), spark_types)],
+                timezone=timezone,
+                safecheck=safecheck,
+                prefers_large_types=prefers_large_var_types,
+            )
            for pdf_slice in pdf_slices
        ]

        jsparkSession = self._jsparkSession

-        ser = ArrowStreamPandasSerializer(
-            timezone=timezone,
-            safecheck=safecheck,
-            int_to_decimal_coercion_enabled=False,
-            prefers_large_types=prefers_large_var_types,
-        )
+        ser = ArrowStreamSerializer()

        @no_type_check
        def reader_func(temp_filename):
@@ -901,7 +1022,7 @@ def create_iter_server():
            return self._jvm.ArrowIteratorServer()

        # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
+        jiter = self._sc._serialize_to_jvm(arrow_batches, ser, reader_func, create_iter_server)
        assert self._jvm is not None
        jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
        df = DataFrame(jdf, self)
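
With the batches built eagerly on the driver as above, the serializer only streams ready-made pyarrow.RecordBatch objects. A small self-contained sketch of the per-slice pairing used in this path (toy DataFrame and types; the helper's import path is an assumption):

    import pandas as pd
    import pyarrow as pa
    from pyspark.sql.types import LongType, DoubleType
    from pyspark.sql.pandas.conversion import create_arrow_batch_from_pandas  # assumed location

    pdf = pd.DataFrame({"id": [1, 2, 3], "x": [0.1, 0.2, 0.3]})
    spark_types = [LongType(), DoubleType()]
    batch = create_arrow_batch_from_pandas(
        [(col, t) for (_, col), t in zip(pdf.items(), spark_types)],
        timezone="UTC",
    )
    # Plain pyarrow can read the result back, confirming the conversion round-trips.
    assert pa.Table.from_batches([batch]).to_pandas()["_0"].tolist() == [1, 2, 3]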