@@ -161,9 +161,12 @@ def wrap_and_init_stream():
161161 assert isinstance (batch , pa .RecordBatch )
162162
163163 # Wrap the root struct
164- struct = pa .StructArray .from_arrays (
165- batch .columns , fields = pa .struct (list (batch .schema ))
166- )
164+ if len (batch .columns ) == 0 :
165+ struct = pa .array ([{}] * batch .num_rows )
166+ else :
167+ struct = pa .StructArray .from_arrays (
168+ batch .columns , fields = pa .struct (list (batch .schema ))
169+ )
167170 batch = pa .RecordBatch .from_arrays ([struct ], ["_0" ])
168171
169172 # Write the first record batch with initialization.
@@ -175,6 +178,16 @@ def wrap_and_init_stream():
175178 return super (ArrowStreamUDFSerializer , self ).dump_stream (wrap_and_init_stream (), stream )
176179
177180
181+ class ArrowStreamUDTFSerializer (ArrowStreamUDFSerializer ):
182+ """
182+ Same as :class:`ArrowStreamUDFSerializer`, except that `load_stream` delegates past it
183+ (via ``super(ArrowStreamUDFSerializer, self)``) so batches are loaded without the struct flattening.
185+ """
186+
187+ def load_stream (self , stream ):
188+ return super (ArrowStreamUDFSerializer , self ).load_stream (stream )
189+
190+
178191class ArrowStreamGroupUDFSerializer (ArrowStreamUDFSerializer ):
179192 """
180193 Serializes pyarrow.RecordBatch data with Arrow streaming format.
@@ -566,151 +579,6 @@ def __repr__(self):
566579 return "ArrowStreamPandasUDFSerializer"
567580
568581
569- class ArrowStreamPandasUDTFSerializer (ArrowStreamPandasUDFSerializer ):
570- """
571- Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
572- """
573-
574- def __init__ (self , timezone , safecheck ):
575- super (ArrowStreamPandasUDTFSerializer , self ).__init__ (
576- timezone = timezone ,
577- safecheck = safecheck ,
578- # The output pandas DataFrame's columns are unnamed.
579- assign_cols_by_name = False ,
580- # Set to 'False' to avoid converting struct type inputs into a pandas DataFrame.
581- df_for_struct = False ,
582- # Defines how struct type inputs are converted. If set to "row", struct type inputs
583- # are converted into Rows. Without this setting, a struct type input would be treated
584- # as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1),
585- # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1}
586- # if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
587- struct_in_pandas = "row" ,
588- # When dealing with array type inputs, Arrow converts them into numpy.ndarrays.
589- # To ensure consistency across regular and arrow-optimized UDTFs, we further
590- # convert these numpy.ndarrays into Python lists.
591- ndarray_as_list = True ,
592- # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
593- arrow_cast = True ,
594- )
595- self ._converter_map = dict ()
596-
597- def _create_batch (self , series ):
598- """
599- Create an Arrow record batch from the given pandas.Series pandas.DataFrame
600- or list of Series or DataFrame, with optional type.
601-
602- Parameters
603- ----------
604- series : pandas.Series or pandas.DataFrame or list
605- A single series or dataframe, list of series or dataframe,
606- or list of (series or dataframe, arrow_type)
607-
608- Returns
609- -------
610- pyarrow.RecordBatch
611- Arrow RecordBatch
612- """
613- import pandas as pd
614- import pyarrow as pa
615-
616- # Make input conform to [(series1, type1), (series2, type2), ...]
617- if not isinstance (series , (list , tuple )) or (
618- len (series ) == 2 and isinstance (series [1 ], pa .DataType )
619- ):
620- series = [series ]
621- series = ((s , None ) if not isinstance (s , (list , tuple )) else s for s in series )
622-
623- arrs = []
624- for s , t in series :
625- if not isinstance (s , pd .DataFrame ):
626- raise PySparkValueError (
627- "Output of an arrow-optimized Python UDTFs expects "
628- f"a pandas.DataFrame but got: { type (s )} "
629- )
630-
631- arrs .append (self ._create_struct_array (s , t ))
632-
633- return pa .RecordBatch .from_arrays (arrs , ["_%d" % i for i in range (len (arrs ))])
634-
635- def _get_or_create_converter_from_pandas (self , dt ):
636- if dt not in self ._converter_map :
637- conv = _create_converter_from_pandas (
638- dt ,
639- timezone = self ._timezone ,
640- error_on_duplicated_field_names = False ,
641- ignore_unexpected_complex_type_values = True ,
642- )
643- self ._converter_map [dt ] = conv
644- return self ._converter_map [dt ]
645-
646- def _create_array (self , series , arrow_type , spark_type = None , arrow_cast = False ):
647- """
648- Override the `_create_array` method in the superclass to create an Arrow Array
649- from a given pandas.Series and an arrow type. The difference here is that we always
650- use arrow cast when creating the arrow array. Also, the error messages are specific
651- to arrow-optimized Python UDTFs.
652-
653- Parameters
654- ----------
655- series : pandas.Series
656- A single series
657- arrow_type : pyarrow.DataType, optional
658- If None, pyarrow's inferred type will be used
659- spark_type : DataType, optional
660- If None, spark type converted from arrow_type will be used
661- arrow_cast: bool, optional
662- Whether to apply Arrow casting when the user-specified return type mismatches the
663- actual return values.
664-
665- Returns
666- -------
667- pyarrow.Array
668- """
669- import pyarrow as pa
670- import pandas as pd
671-
672- if isinstance (series .dtype , pd .CategoricalDtype ):
673- series = series .astype (series .dtypes .categories .dtype )
674-
675- if arrow_type is not None :
676- dt = spark_type or from_arrow_type (arrow_type , prefer_timestamp_ntz = True )
677- conv = self ._get_or_create_converter_from_pandas (dt )
678- series = conv (series )
679-
680- if hasattr (series .array , "__arrow_array__" ):
681- mask = None
682- else :
683- mask = series .isnull ()
684-
685- try :
686- try :
687- return pa .Array .from_pandas (
688- series , mask = mask , type = arrow_type , safe = self ._safecheck
689- )
690- except pa .lib .ArrowException :
691- if arrow_cast :
692- return pa .Array .from_pandas (series , mask = mask ).cast (
693- target_type = arrow_type , safe = self ._safecheck
694- )
695- else :
696- raise
697- except pa .lib .ArrowException :
698- # Display the most user-friendly error messages instead of showing
699- # arrow's error message. This also works better with Spark Connect
700- # where the exception messages are by default truncated.
701- raise PySparkRuntimeError (
702- errorClass = "UDTF_ARROW_TYPE_CAST_ERROR" ,
703- messageParameters = {
704- "col_name" : series .name ,
705- "col_type" : str (series .dtype ),
706- "arrow_type" : arrow_type ,
707- },
708- ) from None
709-
710- def __repr__ (self ):
711- return "ArrowStreamPandasUDTFSerializer"
712-
713-
714582class CogroupArrowUDFSerializer (ArrowStreamGroupUDFSerializer ):
715583 """
716584 Serializes pyarrow.RecordBatch data with Arrow streaming format.